~jojo/effem

626aefb6349872b096c80637546d9b1522a1d471 — JoJo 3 months ago 0b998ce
abase & eval lambdas (not capturing closures yet though)
4 files changed, 148 insertions(+), 79 deletions(-)

M src/abase.rs
M src/cache.rs
M src/eval.rs
M src/main.rs
M src/abase.rs => src/abase.rs +64 -24
@@ 5,52 5,53 @@ use crate::fem::Expr as FExpr;
use crate::name::*;
use anyhow::{anyhow, Result};

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum GlobDef {
    FunDef(FunDef),
    VarDef(GVarDef),
    Alias(DefId),
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct FunDef {
    // Nested definitions (including anonymous functions)
    pub defs: Vec<(LocalId, GlobDef)>,
    pub params: Vec<()>,
    pub params: Vec<LocalId>,
    pub body: Block<Return>,
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct GVarDef {
    pub defs: Vec<(LocalId, GlobDef)>,
    pub body: Block<Operand>,
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub struct Block<Term> {
    pub stms: Vec<Stm>,
    pub term: Term,
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Stm {
    Let { lhs: LocalId, rhs: Expr },
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Return {
    Val(Expr),
    Void,
    // Void,
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Expr {
    Add(Operand, Operand),
    Mul(Operand, Operand),
    Call(Operand, Vec<Operand>),
    Operand(Operand),
}

#[derive(Clone)]
#[derive(Clone, Debug)]
pub enum Operand {
    Const(Const),
    // Extern Extern


@@ 93,7 94,7 @@ impl<'c> AbaseDef<'c> {
                _ => Ok(GlobDef::VarDef(GVarDef { defs: self.defs, body: Block { stms: self.stms, term: rand } })),
            },
            term => {
                let term = self.emit_let_anon(term)?;
                let term = self.let_anon_soft(term)?;
                Ok(GlobDef::VarDef(GVarDef { defs: self.defs, body: Block { stms: self.stms, term } }))
            }
        }


@@ 103,25 104,36 @@ impl<'c> AbaseDef<'c> {
        match *expr {
            FExpr::F64(x) => Ok(Expr::Operand(Operand::Const(Const::F64(x)))),
            FExpr::Fun(ref ps, ref b) => Ok(Expr::Operand(Operand::Local(self.abase_fun(ps, b)?))),
            FExpr::App(ref f, ref xs) => Ok(match **f {
            FExpr::App(ref f, ref xs) => match **f {
                FExpr::Var(ResName { res: Res::Prim(p) }) => match p {
                    Prim::Add => {
                        let x = self.abase(&xs[0])?;
                        let x = self.emit_let_anon(x)?;
                        let x = self.let_anon_soft(x)?;
                        let y = self.abase(&xs[1])?;
                        let y = self.emit_let_anon(y)?;
                        Expr::Add(x, y)
                        let y = self.let_anon_soft(y)?;
                        Ok(Expr::Add(x, y))
                    }
                    Prim::Mul => {
                        let x = self.abase(&xs[0])?;
                        let x = self.emit_let_anon(x)?;
                        let x = self.let_anon_soft(x)?;
                        let y = self.abase(&xs[1])?;
                        let y = self.emit_let_anon(y)?;
                        Expr::Mul(x, y)
                        let y = self.let_anon_soft(y)?;
                        Ok(Expr::Mul(x, y))
                    }
                },
                _ => todo!(),
            }),
                _ => {
                    let f = self.abase(f)?;
                    let f = self.let_anon_soft(f)?;
                    let xs = xs
                        .iter()
                        .map(|x| {
                            let x = self.abase(x)?;
                            self.let_anon_soft(x)
                        })
                        .collect::<Result<Vec<Operand>>>()?;
                    Ok(Expr::Call(f, xs))
                }
            },
            FExpr::Var(ResName { res: Res::Def(id) }) => {
                let _ = self.cache.fetch_base(id)?;
                Ok(Expr::Operand(Operand::Global(id)))


@@ 133,10 145,32 @@ impl<'c> AbaseDef<'c> {
        }
    }

    // TODO: closures. capture free vars etc.
    fn abase_fun(&mut self, params: &[(PubIdent, LocalId)], body: &FExpr) -> Result<LocalId> {
        // TODO: closures. capture free vars etc.

        todo!()
        let fid = self.gen_local();

        let old_defs = std::mem::replace(&mut self.defs, vec![]);
        let old_vars = std::mem::replace(&mut self.vars, vec![]);
        let old_stms = std::mem::replace(&mut self.stms, vec![]);

        let params = params
            .iter()
            .map(|(_, param_id)| {
                assert_eq!(self.vars.len(), *param_id as usize);
                self.gen_local()
            })
            .collect::<Vec<_>>();
        self.vars.extend(&params);
        let term = self.abase(body)?;

        let new_defs = std::mem::replace(&mut self.defs, old_defs);
        self.vars = old_vars;
        let new_stms = std::mem::replace(&mut self.stms, old_stms);

        let def = FunDef { defs: new_defs, params, body: Block { stms: new_stms, term: Return::Val(term) } };
        self.defs.push((fid, GlobDef::FunDef(def)));

        Ok(fid)
    }

    fn gen_local(&mut self) -> LocalId {


@@ 145,7 179,13 @@ impl<'c> AbaseDef<'c> {
        self.n_locals - 1
    }

    fn emit_let_anon(&mut self, rhs: Expr) -> Result<Operand> {
    fn let_anon_soft(&mut self, rhs: Expr) -> Result<Operand> {
        match rhs {
            Expr::Operand(rand) => Ok(rand),
            _ => self.let_anon_hard(rhs),
        }
    }
    fn let_anon_hard(&mut self, rhs: Expr) -> Result<Operand> {
        let lhs = self.gen_local();
        self.stms.push(Stm::Let { lhs, rhs });
        Ok(Operand::Local(lhs))

M src/cache.rs => src/cache.rs +19 -3
@@ 18,7 18,7 @@ pub struct Cache {
    resolveds: HashMap<DefId, RExpr>,
    // desugareds: HashMap<Query, fem::Expr>,
    abaseds: HashMap<DefId, abase::GlobDef>,
    evaluateds: HashMap<DefId, eval::Val>,
    evaluateds: HashMap<DefId, abase::Operand>,
}

impl Cache {


@@ 36,12 36,28 @@ impl Cache {
        }
    }

    pub fn fetch_evaluated(&mut self, def_query: DefId) -> anyhow::Result<&eval::Val> {
    pub fn fetch_evaluated_at(
        &mut self,
        fundef_query: DefId,
        args: &[abase::Operand],
    ) -> anyhow::Result<abase::Operand> {
        let _ = self.fetch_evaluated(fundef_query)?;
        match self.abaseds[&fundef_query] {
            abase::GlobDef::FunDef(ref fdef) => eval::eval_fun_at(self, &fdef.clone(), args),
            ref def => panic!("ice: {def:?} is not a function def. Cannot evaluate at arguments."),
        }
    }

    pub fn fetch_evaluated(&mut self, def_query: DefId) -> anyhow::Result<&abase::Operand> {
        if self.evaluateds.contains_key(&def_query) {
            Ok(self.evaluateds.get(&def_query).unwrap())
        } else {
            let base = self.fetch_base(def_query)?.clone();
            let val = eval::eval_def(self, &base)?;
            let val = match base {
                abase::GlobDef::VarDef(def) => eval::eval_var(self, &def)?,
                abase::GlobDef::Alias(target) => self.fetch_evaluated(target)?.clone(),
                abase::GlobDef::FunDef(_) => abase::Operand::Global(def_query),
            };
            self.evaluateds.insert(def_query, val);
            Ok(&self.evaluateds[&def_query])
        }

M src/eval.rs => src/eval.rs +56 -46
@@ 1,63 1,62 @@
use crate::abase::{Const, Expr, FunDef, GVarDef, GlobDef, Operand, Stm};
use crate::abase::*;
use crate::{cache::Cache, name::*};
use anyhow::Result;
use std::collections::HashMap;

// #[derive(Debug, Clone, PartialEq)]
// pub enum Val {
//     F64(f64),
// }
pub type Val = Const;

impl std::ops::Add for Val {
    type Output = Val;
    fn add(self, other: Val) -> Val {
        match (self, other) {
            (Val::F64(x), Val::F64(y)) => Val::F64(x + y),
        }
    }
}
impl std::ops::Mul for Val {
    type Output = Val;
    fn mul(self, other: Val) -> Val {
        match (self, other) {
            (Val::F64(x), Val::F64(y)) => Val::F64(x * y),
        }
    }
pub fn eval_fun_at(cache: &mut Cache, def: &FunDef, args: &[Operand]) -> Result<Operand> {
    println!("eval_fun_at: args: {args:?}\ndef: {def:?}\n");
    EvalDef::new(cache, &def.defs).run_fun(&def.params, args, &def.body)
}

pub fn eval_def(cache: &mut Cache, def: &GlobDef) -> Result<Val> {
    EvalDef::new(cache).run(def)
pub fn eval_var(cache: &mut Cache, def: &GVarDef) -> Result<Operand> {
    println!("eval_var: def: {def:?}\n");
    EvalDef::new(cache, &def.defs).run_var(&def.body)
}

struct EvalDef<'c> {
struct EvalDef<'c, 'd> {
    cache: &'c mut Cache,
    regs: HashMap<LocalId, Val>,
    local_defs: HashMap<LocalId, &'d GlobDef>,
    regs: HashMap<LocalId, Operand>,
}

impl<'c> EvalDef<'c> {
    fn new(cache: &'c mut Cache) -> Self {
        Self { cache, regs: HashMap::new() }
impl<'c, 'd> EvalDef<'c, 'd> {
    fn new(cache: &'c mut Cache, local_defs: &'d [(LocalId, GlobDef)]) -> Self {
        Self { cache, local_defs: local_defs.iter().map(|(id, def)| (*id, def)).collect(), regs: HashMap::new() }
    }

    fn run(mut self, def: &GlobDef) -> Result<Val> {
        match def {
            GlobDef::FunDef(FunDef { defs, params, body }) => todo!(),
            GlobDef::VarDef(GVarDef { defs, body }) if defs.is_empty() => {
                for stm in &body.stms {
                    self.eval_stm(stm)?;
                }
                self.eval_operand(&body.term)
            }
            GlobDef::VarDef(GVarDef { defs, body }) => todo!(),
            GlobDef::Alias(target) => self.cache.fetch_evaluated(*target).cloned(),
    fn run_fun(mut self, params: &[LocalId], args: &[Operand], body: &Block<Return>) -> Result<Operand> {
        for (id, arg) in params.iter().zip(args) {
            self.regs.insert(*id, arg.clone());
        }
        for stm in &body.stms {
            self.eval_stm(stm)?;
        }
        match &body.term {
            Return::Val(tail) => self.eval(tail),
        }
    }

    fn run_var(mut self, body: &Block<Operand>) -> Result<Operand> {
        for stm in &body.stms {
            self.eval_stm(stm)?;
        }
        self.eval_operand(&body.term)
    }

    fn eval(&mut self, expr: &Expr) -> Result<Val> {
    fn eval(&mut self, expr: &Expr) -> Result<Operand> {
        match expr {
            Expr::Add(x, y) => Ok(self.eval_operand(x)? + self.eval_operand(y)?),
            Expr::Mul(x, y) => Ok(self.eval_operand(x)? * self.eval_operand(y)?),
            Expr::Add(x, y) => match (self.eval_operand(x)?, self.eval_operand(y)?) {
                (Operand::Const(Const::F64(x)), Operand::Const(Const::F64(y))) => Ok(Operand::Const(Const::F64(x + y))),
                (xe, ye) => panic!("ice: cannot add {xe:?} + {ye:?} ({x:?} + {y:?})"),
            },
            Expr::Mul(x, y) => match (self.eval_operand(x)?, self.eval_operand(y)?) {
                (Operand::Const(Const::F64(x)), Operand::Const(Const::F64(y))) => Ok(Operand::Const(Const::F64(x * y))),
                (xe, ye) => panic!("ice: cannot add {xe:?} * {ye:?} ({x:?} * {y:?})"),
            },
            Expr::Call(f, xs) => match self.eval_operand(f)? {
                Operand::Global(fid) => self.cache.fetch_evaluated_at(fid, xs),
                fe => panic!("ice: applying non-function {fe:?} ({f:?})"),
            },
            Expr::Operand(rand) => self.eval_operand(rand),
        }
    }


@@ 72,10 71,21 @@ impl<'c> EvalDef<'c> {
        }
    }

    fn eval_operand(&mut self, rand: &Operand) -> Result<Val> {
    fn eval_operand(&mut self, rand: &Operand) -> Result<Operand> {
        match rand {
            Operand::Const(x) => Ok(x.clone()),
            Operand::Local(id) => Ok(self.regs.get(id).unwrap_or_else(|| todo!()).clone()),
            Operand::Const(_) => Ok(rand.clone()),
            Operand::Local(id) =>
                if let Some(val) = self.regs.get(id) {
                    Ok(val.clone())
                } else if let Some(def) = self.local_defs.get(id) {
                    match def {
                        GlobDef::FunDef(_) => Ok(rand.clone()),
                        GlobDef::VarDef(def) => eval_var(self.cache, def),
                        GlobDef::Alias(target) => self.eval_operand(&Operand::Global(*target)),
                    }
                } else {
                    panic!("ice: undefined local {rand:?}\nregs: {:?}\ndefs: {:?}", self.regs, self.local_defs)
                },
            Operand::Global(id) => self.cache.fetch_evaluated(*id).cloned(),
        }
    }

M src/main.rs => src/main.rs +9 -6
@@ 40,9 40,10 @@ fn main() {
#[cfg(test)]
mod test {
    use super::*;
    use abase::*;
    use eval::*;

    fn run_tmp(main_src: &str) -> anyhow::Result<Val> {
    fn run_tmp(main_src: &str) -> anyhow::Result<Operand> {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("main.fm"), main_src).unwrap();
        let mut cache = Cache::for_package_in(dir.path());


@@ 56,17 57,19 @@ mod test {

    #[test]
    fn test_main_literal() {
        assert!(matches!(run_tmp("(def main 123.456)"), Ok(Val::F64(x)) if x == 123.456))
        assert!(matches!(run_tmp("(def main 123.456)"), Ok(Operand::Const(Const::F64(x))) if x == 123.456))
    }

    #[test]
    fn test_main_arithm1() {
        assert!(matches!(run_tmp("(def main (+.prim. 1.0 2.0))"), Ok(Val::F64(x)) if x == 3.0))
        assert!(matches!(run_tmp("(def main (+.prim. 1.0 2.0))"), Ok(Operand::Const(Const::F64(x))) if x == 3.0))
    }

    #[test]
    fn test_main_arithm2() {
        assert!(matches!(run_tmp("(def main (+.prim. (*.prim. 13.0 100.0) 37.0))"), Ok(Val::F64(x)) if x == 1337.0))
        assert!(
            matches!(run_tmp("(def main (+.prim. (*.prim. 13.0 100.0) 37.0))"), Ok(Operand::Const(Const::F64(x))) if x == 1337.0)
        )
    }

    #[test]


@@ 75,7 78,7 @@ mod test {
            (def x 1300.0)
            (def main (+ x 37.0))
        ";
        assert!(matches!(run_tmp(src), Ok(Val::F64(x)) if x == 1337.0))
        assert!(matches!(run_tmp(src), Ok(Operand::Const(Const::F64(x))) if x == 1337.0))
    }

    #[test]


@@ 84,6 87,6 @@ mod test {
            (def main (double 21.0))
            (def double (fun [x] (* x 2.0)))
        ";
        assert!(matches!(run_tmp(src), Ok(Val::F64(x)) if x == 42.0))
        assert!(matches!(run_tmp(src), Ok(Operand::Const(Const::F64(x))) if x == 42.0))
    }
}