~jojo/effem

b7b3f2d7a75034fcfa42f32478346c41aef4de22 — JoJo 3 months ago 62dc666
get rid of the nested defs as localid -> globaldef thing in abase

`Cache::fetch_evaluated_at` couldn't make sense of the `Operand::Local`s
it was getting. And there was no great reason to keep them nested
quite like that. Easier to just generate anonymous globals at the top
level.

We'll still need to keep track of which defs are downstream from which
though, for when we add support for rerunning with the same cache but
modified inputs. We'll need to know which defs to invalidate etc.
That's what the new `children` field in the glob defs. Just a set of
dowstream defs.
7 files changed, 101 insertions(+), 69 deletions(-)

M src/abase.rs
M src/cache.rs
M src/check.rs
M src/eval.rs
M src/main.rs
M src/name.rs
M src/resolve.rs
M src/abase.rs => src/abase.rs +27 -33
@@ 2,6 2,7 @@

use crate::prelude::*;
use fem::{Expr as FExpr, Type};
use std::collections::HashSet;

#[derive(Clone, Debug)]
pub enum GlobDef {


@@ 13,14 14,14 @@ pub enum GlobDef {
#[derive(Clone, Debug)]
pub struct FunDef {
    // Nested definitions (including anonymous functions)
    pub defs: Vec<(LocalId, GlobDef)>,
    pub children: HashSet<DefId>,
    pub params: Vec<LocalId>,
    pub body: Block<Return>,
}

#[derive(Clone, Debug)]
pub struct GVarDef {
    pub defs: Vec<(LocalId, GlobDef)>,
    pub children: HashSet<DefId>,
    pub body: Block<Operand>,
}



@@ 63,11 64,11 @@ pub enum Binop {
    Neq,
}

#[derive(Clone, Debug)]
#[derive(Debug, Clone, Copy)]
pub enum Operand {
    Const(Const),
    // Extern Extern
    Local(DefId),
    Local(LocalId),
    Global(DefId),
}



@@ 144,9 145,10 @@ pub fn abase_def(cache: &mut Cache, rhs: &FExpr) -> Result<GlobDef> {
    AbaseDef::new(cache).run(rhs)
}

#[derive(Debug)]
struct AbaseDef<'c> {
    cache: &'c mut Cache,
    defs: Vec<(LocalId, GlobDef)>,
    children: HashSet<DefId>,
    vars: Vec<LocalId>,
    n_locals: u32,
    stms: Vec<Stm>,


@@ 154,24 156,22 @@ struct AbaseDef<'c> {

impl<'c> AbaseDef<'c> {
    fn new(cache: &'c mut Cache) -> Self {
        Self { cache, defs: vec![], vars: vec![], n_locals: 0, stms: vec![] }
        Self { cache, children: HashSet::new(), vars: vec![], n_locals: 0, stms: vec![] }
    }

    fn run(mut self, rhs: &FExpr) -> Result<GlobDef> {
        match self.abase(rhs)? {
            Expr::Operand(rand) => match rand {
                Operand::Local(id) => {
                    assert_eq!(id, 0);
                    assert_eq!(self.defs.len(), 1);
                    assert_eq!(self.vars.len(), 0);
                    Ok(self.defs.pop().unwrap().1)
                }
                Operand::Local(id) => panic!("ice: result of abasing def is a local reference {id:?}"),
                Operand::Global(id) => Ok(GlobDef::Alias(id)),
                _ => Ok(GlobDef::VarDef(GVarDef { defs: self.defs, body: Block { stms: self.stms, term: rand } })),
                _ => Ok(GlobDef::VarDef(GVarDef {
                    children: self.children,
                    body: Block { stms: self.stms, term: rand },
                })),
            },
            term => {
                let term = self.let_anon_soft(term)?;
                Ok(GlobDef::VarDef(GVarDef { defs: self.defs, body: Block { stms: self.stms, term } }))
                Ok(GlobDef::VarDef(GVarDef { children: self.children, body: Block { stms: self.stms, term } }))
            }
        }
    }


@@ 195,7 195,7 @@ impl<'c> AbaseDef<'c> {
                ref t => panic!("ice: integer expr has non-integer type {t}"),
            }))),
            F64(x) => Ok(Expr::Operand(Operand::Const(Const::F64(*x)))),
            Fun(ps, b) => Ok(Expr::Operand(Operand::Local(self.abase_fun(ps, b)?))),
            Fun(ps, b) => Ok(Expr::Operand(self.abase_fun(ps, b)?)),
            App(f, xs) => match &f.kind {
                Var(ResName { res: Res::Prim(p) }) => match p {
                    Prim::Add => self.abase_binop(xs, Binop::Add),


@@ 224,8 224,8 @@ impl<'c> AbaseDef<'c> {
                }
            },
            Var(ResName { res: Res::Def(id) }) => Ok(Expr::Operand(Operand::Global(*id))),
            Var(ResName { res: Res::Local(id) }) =>
                Ok(Expr::Operand(Operand::Local(*self.vars.get(*id as usize).unwrap_or_else(|| todo!())))),
            &Var(ResName { res: Res::Local(LocalId(id)) }) =>
                Ok(Expr::Operand(Operand::Local(*self.vars.get(id as usize).unwrap_or_else(|| todo!())))),
            Var(ResName { res: Res::Prim(_) }) => todo!(), // TODO: generate a closure around the op or smth
            Var(ResName { res: Res::Module(_) }) =>
                panic!("ice: found module id in expr context when abasing. Should've been caught by type checker."),


@@ 248,30 248,24 @@ impl<'c> AbaseDef<'c> {
    }

    // TODO: closures. capture free vars etc.
    fn abase_fun(&mut self, params: &[(PubIdent, LocalId)], body: &FExpr) -> Result<LocalId> {
        let fid = self.gen_local();

        let old_defs = std::mem::take(&mut self.defs);
    fn abase_fun(&mut self, params: &[(PubIdent, LocalId)], body: &FExpr) -> Result<Operand> {
        let old_children = std::mem::take(&mut self.children);
        let old_vars = std::mem::take(&mut self.vars);

        let params = params
            .iter()
            .map(|(_, param_id)| {
                assert_eq!(self.vars.len(), *param_id as usize);
                self.gen_local()
            })
            .collect::<Vec<_>>();
        debug_assert!(params.iter().enumerate().all(|(i, &(_, LocalId(param_id)))| param_id as usize == i));
        let params = params.iter().map(|_| self.gen_local()).collect::<Vec<_>>();
        self.vars.extend(&params);

        let body = self.enter_block(|self_| self_.abase(body).map(Return::Val))?;

        let new_defs = std::mem::replace(&mut self.defs, old_defs);
        let new_children = std::mem::replace(&mut self.children, old_children);
        self.vars = old_vars;

        let def = FunDef { defs: new_defs, params, body };
        self.defs.push((fid, GlobDef::FunDef(def)));
        let def = FunDef { children: new_children, params, body };
        let fid = self.cache.insert_base_anon(GlobDef::FunDef(def));
        self.children.insert(fid);

        Ok(fid)
        Ok(Operand::Global(fid))
    }

    fn enter_block<T>(&mut self, f: impl FnOnce(&mut Self) -> Result<T>) -> Result<Block<T>> {


@@ 284,7 278,7 @@ impl<'c> AbaseDef<'c> {
    fn gen_local(&mut self) -> LocalId {
        assert!(self.n_locals < u32::MAX);
        self.n_locals += 1;
        self.n_locals - 1
        LocalId(self.n_locals - 1)
    }

    fn let_anon_soft(&mut self, rhs: Expr) -> Result<Operand> {

M src/cache.rs => src/cache.rs +18 -4
@@ 50,7 50,11 @@ impl Cache {
    }

    pub fn fetch_evaluated_at(&mut self, fundef_query: DefId, args: &[abase::Operand]) -> Result<abase::Operand> {
        debug_assert!(args.iter().all(|arg| !matches!(arg, abase::Operand::Local(_))));
        #[cfg(debug_assertions)]
        for arg in args.iter().filter(|arg| matches!(arg, abase::Operand::Local(_))) {
            panic!("ice: Local operand {arg:?} as arg to another function");
        }

        let _ = self.fetch_evaluated(fundef_query)?;
        match self.abaseds[&fundef_query] {
            abase::GlobDef::FunDef(ref fdef) => eval::eval_fun_at(self, &fdef.clone(), args),


@@ 87,6 91,12 @@ impl Cache {
        })
    }

    pub fn insert_base_anon(&mut self, def: abase::GlobDef) -> DefId {
        let def_id = self.gen_def_id();
        self.abaseds.insert(def_id, def);
        def_id
    }

    pub fn fetch_checked(&mut self, def_query: DefId) -> Result<Fetch<&fem::Expr>> {
        if self.checkeds.contains_key(&def_query) {
            Ok(self.checkeds[&def_query].as_ref())


@@ 262,10 272,8 @@ impl Cache {
            Some(Res::Def(def_id)) => *def_id,
            Some(res) => panic!("ice: fetching/generating definition resolution for name {def_name}, but that name already resolves to {res:?}"),
            None => {
                assert!(self.n_defs < u32::MAX);
                let def_id = self.n_defs;
                let def_id = self.gen_def_id();
                let res = Res::Def(def_id);
                self.n_defs += 1;
                self.resolved_names.insert(def_name.clone(), res);
                self.resolved_names_rev.insert(res, def_name.clone());
                self.parent_modules.insert(def_id, module_id);


@@ 274,6 282,12 @@ impl Cache {
        }
    }

    fn gen_def_id(&mut self) -> DefId {
        assert!(self.n_defs < u32::MAX);
        self.n_defs += 1;
        self.n_defs - 1
    }

    pub fn fetch_def_name(&self, def_id: DefId) -> &FullName<String> {
        self.resolved_names_rev.get(&Res::Def(def_id)).expect("ice: if we have a def id, there should be a reverse")
    }

M src/check.rs => src/check.rs +2 -2
@@ 25,7 25,7 @@ impl<'c> CheckDef<'c> {
    fn infer(&mut self, expr: &RExpr) -> Result<Expr> {
        use resolve::ExprKind as Rek;
        match &expr.kind {
            &Rek::Bool(x) => Ok(Expr { typ: Type::F64, kind: ExprKind::Bool(x) }),
            &Rek::Bool(x) => Ok(Expr { typ: Type::Bool, kind: ExprKind::Bool(x) }),
            Rek::Int(x) => Ok(Expr {
                typ: if (isize::MIN as i128..=isize::MAX as i128).contains(x) {
                    Type::ISize


@@ 81,7 81,7 @@ impl<'c> CheckDef<'c> {
            }
            Rek::Var(r @ ResName { res: Res::Local(id) }) => match self.get_local(*id) {
                Some(t) => Ok(Expr { typ: t.clone(), kind: ExprKind::Var(r.clone()) }),
                None => panic!("ice: undefined local var of id {id}"),
                None => panic!("ice: undefined local var of id {id:?}"),
            },
            Rek::Var(r @ ResName { res: Res::Prim(Prim::Add | Prim::Sub | Prim::Mul | Prim::Quot | Prim::Rem) }) =>
                Ok(Expr {

M src/eval.rs => src/eval.rs +28 -28
@@ 3,23 3,22 @@ use abase::*;
use std::collections::HashMap;

pub fn eval_fun_at(cache: &mut Cache, def: &FunDef, args: &[Operand]) -> Result<Operand> {
    EvalDef::new(cache, &def.defs).run_fun(&def.params, args, &def.body)
    EvalDef::new(cache).run_fun(&def.params, args, &def.body)
}

pub fn eval_var(cache: &mut Cache, def: &GVarDef) -> Result<Operand> {
    EvalDef::new(cache, &def.defs).run_var(&def.body)
    EvalDef::new(cache).run_var(&def.body)
}

#[derive(Debug)]
struct EvalDef<'c, 'd> {
struct EvalDef<'c> {
    cache: &'c mut Cache,
    local_defs: HashMap<LocalId, &'d GlobDef>,
    regs: HashMap<LocalId, Operand>,
}

impl<'c, 'd> EvalDef<'c, 'd> {
    fn new(cache: &'c mut Cache, local_defs: &'d [(LocalId, GlobDef)]) -> Self {
        Self { cache, local_defs: local_defs.iter().map(|(id, def)| (*id, def)).collect(), regs: HashMap::new() }
impl<'c> EvalDef<'c> {
    fn new(cache: &'c mut Cache) -> Self {
        Self { cache, regs: HashMap::new() }
    }

    fn run_fun(mut self, params: &[LocalId], args: &[Operand], body: &Block<Return>) -> Result<Operand> {


@@ 36,7 35,7 @@ impl<'c, 'd> EvalDef<'c, 'd> {

    fn run_var(mut self, body: &Block<Operand>) -> Result<Operand> {
        self.eval_stms(&body.stms)?;
        self.eval_operand(&body.term)
        self.eval_operand(body.term)
    }

    fn eval(&mut self, expr: &Expr) -> Result<Operand> {


@@ 44,7 43,7 @@ impl<'c, 'd> EvalDef<'c, 'd> {
        use Binop::*;
        use Operand::*;
        match expr {
            Expr::Binop(op, x, y) => match (self.eval_operand(x)?, self.eval_operand(y)?) {
            &Expr::Binop(op, x, y) => match (self.eval_operand(x)?, self.eval_operand(y)?) {
                (Const(x), Const(y)) => Ok(Const(match op {
                    Add => (x + y).unwrap(),
                    Sub => (x - y).unwrap(),


@@ 60,15 59,15 @@ impl<'c, 'd> EvalDef<'c, 'd> {
                })),
                (x_e, y_e) => panic!("ice: cannot apply binary operation {op:?} to {x_e:?} and {y_e:?}"),
            },
            Expr::Call(f, xs) => match self.eval_operand(f)? {
            Expr::Call(f, xs) => match self.eval_operand(*f)? {
                Global(fid) => {
                    let xs = xs.iter().map(|x| self.eval_operand(x)).collect::<Result<Vec<_>>>()?;
                    let xs = xs.iter().map(|x| self.eval_operand(*x)).collect::<Result<Vec<_>>>()?;
                    self.cache.fetch_evaluated_at(fid, &xs)
                }
                fe => panic!("ice: applying non-function {fe:?} ({f:?})"),
            },
            Expr::Operand(rand) => self.eval_operand(rand),
            Expr::If(pred, conseq, alt) => match self.eval_operand(pred)? {
            &Expr::Operand(rand) => self.eval_operand(rand),
            Expr::If(pred, conseq, alt) => match self.eval_operand(*pred)? {
                Const(Bool(true)) => self.eval_stms(&conseq.stms).and_then(|()| self.eval(&conseq.term)),
                Const(Bool(false)) => self.eval_stms(&alt.stms).and_then(|()| self.eval(&alt.term)),
                pe => panic!("ice: `if` expects const bool predicate, found {pe:?}"),


@@ 93,22 92,23 @@ impl<'c, 'd> EvalDef<'c, 'd> {
        }
    }

    fn eval_operand(&mut self, rand: &Operand) -> Result<Operand> {
        match rand {
            Operand::Const(_) => Ok(rand.clone()),
            Operand::Local(id) =>
                if let Some(val) = self.regs.get(id) {
                    Ok(val.clone())
                } else if let Some(def) = self.local_defs.get(id) {
                    match def {
                        GlobDef::FunDef(_) => Ok(rand.clone()),
                        GlobDef::VarDef(def) => eval_var(self.cache, def),
                        GlobDef::Alias(target) => self.eval_operand(&Operand::Global(*target)),
                    }
                } else {
                    panic!("ice: undefined local {rand:?}\nregs: {:?}\ndefs: {:?}", self.regs, self.local_defs)
    fn eval_operand(&mut self, mut rand: Operand) -> Result<Operand> {
        let max_loops = 64;
        for _ in 0..max_loops {
            match rand {
                Operand::Const(_) => return Ok(rand),
                Operand::Local(id) =>
                    if let Some(val) = self.regs.get(&id) {
                        rand = val.clone();
                    } else {
                        panic!("ice: undefined local {rand:?}\nregs: {:?}", self.regs)
                    },
                Operand::Global(id) => match *self.cache.fetch_evaluated(id)? {
                    Operand::Global(id2) if id == id2 => return Ok(Operand::Global(id)),
                    rand2 => rand = rand2,
                },
            Operand::Global(id) => self.cache.fetch_evaluated(*id).cloned(),
            }
        }
        panic!("ice: eval_operand looped {max_loops} times. This is probably a compiler bug.")
    }
}

M src/main.rs => src/main.rs +22 -0
@@ 244,4 244,26 @@ mod test {
        ";
        assert_matches!(run_tmp(src), Ok(Operand::Const(Const::Int(3628800))))
    }

    #[test]
    fn test_mut_rec() {
        let src = "
            (def main (even? 20))
            (def even? (of (Fun [Int] Bool)
                           (fun [n] (if (= n 0) True (odd? (- n 1))))))
            (def odd? (of (Fun [Int] _)
                          (fun [n] (if (= n 1) True (even? (- n 1))))))
        ";
        assert_matches!(run_tmp(src), Ok(Operand::Const(Const::Bool(true))))
    }

    #[test]
    fn test_twice_mono() {
        let src = "
            (def twice (of (Fun [(Fun [Int] Int) Int] Int)
                           (fun [f x] (f (f x)))))
            (def main (twice (fun [x] (+ x 1)) 8))
        ";
        assert_matches!(run_tmp(src), Ok(Operand::Const(Const::Int(10))))
    }
}

M src/name.rs => src/name.rs +3 -1
@@ 41,7 41,9 @@ impl Res {
}

pub type DefId = u32;
pub type LocalId = u32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct LocalId(pub u32);

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]

M src/resolve.rs => src/resolve.rs +1 -1
@@ 107,6 107,6 @@ impl<'c> Resolver<'c> {
    fn gen_local_id(&mut self) -> LocalId {
        assert!(self.local_count < u32::MAX);
        self.local_count += 1;
        self.local_count - 1
        LocalId(self.local_count - 1)
    }
}