~jojo/kapreolo

f0b5398cf45507de4a305013062d0de986589d22 — JoJo 2 months ago ff31d0c
Collapse chain of decision tree tests to switch
3 files changed, 67 insertions(+), 41 deletions(-)

M src/abase.rs
M src/base.rs
M src/eval.rs
M src/abase.rs => src/abase.rs +44 -14
@@ 177,20 177,50 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
    ) -> Result<Flow> {
        use pattern_match::*;
        match tree {
            // TODO: switch if multiple Test with same access in a row
            Decision::Test(accessor, typ, tag, yes, no) => {
                let expected_tag = Operand::Const(abase_tag(tag, &typ).unwrap());
            Decision::Test(accessor, typ, tag, yes, mut no) => {
                let found_tag = self.abase_accessor(accessors, access, accessor)?;
                let tags_eq =
                    self.let_anon_hard(Flow::Produce(Operation::Binop(Binop::Eq, found_tag, expected_tag)))?;
                let yes = self.enter_block(|self_| {
                    let mut access = access.clone(); // Bindings from this branch won't be available in the no-branch
                    self_.abase_decision_tree(accessors, &mut access, rhs_instances, rhs_param_blocks, *yes)
                })?;
                let no = self.enter_block(|self_| {
                    self_.abase_decision_tree(accessors, access, rhs_instances, rhs_param_blocks, *no)
                })?;
                Ok(Flow::Diverge(Diverge::If(tags_eq, yes, no)))
                // Collapse chain of tests on same accessor to a single switch
                let mut cases = vec![];
                loop {
                    match *no {
                        Decision::Test(accessor1, _, tag1, yes1, no1) if accessor1 == accessor => {
                            cases.push((tag1, yes1));
                            no = no1;
                        }
                        _ => break,
                    }
                }
                if cases.is_empty() {
                    let expected_tag = Operand::Const(abase_tag(tag, &typ).unwrap());
                    let tags_eq =
                        self.let_anon_hard(Flow::Produce(Operation::Binop(Binop::Eq, found_tag, expected_tag)))?;
                    let yes = self.enter_block(|self_| {
                        let mut access = access.clone(); // Bindings from this branch won't be available in the no-branch
                        self_.abase_decision_tree(accessors, &mut access, rhs_instances, rhs_param_blocks, *yes)
                    })?;
                    let no = self.enter_block(|self_| {
                        self_.abase_decision_tree(accessors, access, rhs_instances, rhs_param_blocks, *no)
                    })?;
                    Ok(Flow::Diverge(Diverge::If(tags_eq, yes, no)))
                } else {
                    cases.push((tag, yes));
                    cases.sort_by_key(|&(tag, _)| tag);
                    let cases = cases
                        .into_iter()
                        .map(|(lhs, rhs)| {
                            let lhs = abase_tag(lhs, &typ).unwrap();
                            let rhs = self.enter_block(|self_| {
                                let mut access = access.clone(); // Bindings from this branch won't be available in any other branches
                                self_.abase_decision_tree(accessors, &mut access, rhs_instances, rhs_param_blocks, *rhs)
                            })?;
                            Ok((lhs, rhs))
                        })
                        .collect::<Result<Vec<(Const, BlockRef)>>>()?;
                    let default = self.enter_block(|self_| {
                        self_.abase_decision_tree(accessors, access, rhs_instances, rhs_param_blocks, *no)
                    })?;
                    Ok(Flow::Diverge(Diverge::Switch { obj: found_tag, cases, default }))
                }
            }
            Decision::Accept(vars, rhs) if rhs_instances[&rhs] > 1 => {
                if !rhs_param_blocks.contains_key(&rhs) {


@@ 204,7 234,7 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
                    rhs_param_blocks.insert(rhs, blk);
                }
                let blk = rhs_param_blocks[&rhs];
                Ok(Flow::CallBlock(
                Ok(Flow::Sequence(
                    blk,
                    vars.iter().map(|&(_, a)| self.abase_accessor(accessors, access, a)).collect::<Result<_>>()?,
                ))

M src/base.rs => src/base.rs +14 -18
@@ 58,7 58,7 @@ pub enum Stm {
#[derive(Debug, Clone)]
pub enum Flow {
    Diverge(Diverge),
    CallBlock(BlockRef, Vec<Operand>),
    Sequence(BlockRef, Vec<Operand>),
    // Void,
    Produce(Operation),
    // Continue(Vec<Operand>) // only valid in loops


@@ 67,7 67,7 @@ pub enum Flow {
#[derive(Debug, Clone)]
pub enum Diverge {
    If(Operand, BlockRef, BlockRef),
    // Switch { obj: Operand, cases: Vec<(Const, BlockRef)>, default: Option<BlockRef> },
    Switch { obj: Operand, cases: Vec<(Const, BlockRef)>, default: BlockRef },
}

#[derive(Clone, Debug)]


@@ 405,7 405,7 @@ impl Pretty for Flow {
        match self {
            Flow::Produce(ration) => ration.pretty(pr),
            Flow::Diverge(div) => div.pretty(pr),
            Flow::CallBlock(blk, args) => {
            Flow::Sequence(blk, args) => {
                write!(pr, "(goto {blk} [")?;
                pr.write_sep(args, " ")?;
                pr.write_str("])")


@@ 443,21 443,17 @@ impl Pretty for Diverge {
    fn pretty(&self, pr: &mut Prettier<Body>) -> FResult {
        match self {
            Diverge::If(pred, conseq, alt) => write!(pr, "(if {pred} {conseq} {alt})"),
            // Diverge::Switch { obj, cases, default } => {
            //     write!(pr, "(switch {obj}")?;
            //     pr.newline_indent()?;
            //     pr.add_indent(2, |pr| {
            //         for (lhs, rhs) in cases {
            //             write!(pr, "[{lhs} {rhs}]")?;
            //             pr.newline_indent()?;
            //         }
            //         match default {
            //             Some(default) => write!(pr, ":default {default}")?,
            //             None => pr.write_str(";; default not needed")?,
            //         }
            //         pr.write_char(')')
            //     })
            // }
            Diverge::Switch { obj, cases, default } => {
                write!(pr, "(switch {obj}")?;
                pr.newline_indent()?;
                pr.add_indent(2, |pr| {
                    for (lhs, rhs) in cases {
                        write!(pr, "[{lhs} {rhs}]")?;
                        pr.newline_indent()?;
                    }
                    write!(pr, ":default {default})")
                })
            }
        }
    }
}

M src/eval.rs => src/eval.rs +9 -9
@@ 67,15 67,15 @@ impl<'c, 'd, 'a> EvalDef<'c, 'd, 'a> {
                Const(Bool(false)) => self.eval_block_ref(*alt, &[]),
                pe => panic!("ice: `if` expects const bool predicate, found {pe:?}"),
            },
            // &Flow::Diverge(Diverge::Switch { obj, ref cases, default }) => match self.eval_operand(obj)? {
            //     Const(obj) =>
            //         match cases.binary_search_by(|(lhs, _)| lhs.partial_cmp(&obj).unwrap_or(Ordering::Less)) {
            //             Ok(i) => self.eval_block_ref(cases[i].1),
            //             Err(_) => self.eval_block_ref(default.unwrap()),
            //         },
            //     _ => panic!("ice: couldn't evaluate `switch` object operand {obj} to Const"),
            // },
            Flow::CallBlock(blk, args) => self.eval_block_ref(*blk, args),
            Flow::Diverge(Diverge::Switch { obj, cases, default }) => match self.eval_operand(obj.clone())? {
                Const(obj) =>
                    match cases.binary_search_by(|(lhs, _)| lhs.partial_cmp(&obj).unwrap_or(Ordering::Less)) {
                        Ok(i) => self.eval_block_ref(cases[i].1, &[]),
                        Err(_) => self.eval_block_ref(*default, &[]),
                    },
                _ => panic!("ice: couldn't evaluate `switch` object operand {obj} to Const"),
            },
            Flow::Sequence(blk, args) => self.eval_block_ref(*blk, args),
        }
    }