~jojo/effem

ff5b60401a8fe04b50606a583b2ea44a7c9594ec — JoJo 3 months ago 626aefb
basic bidirectional typing (infer / check)
7 files changed, 207 insertions(+), 31 deletions(-)

M src/abase.rs
M src/cache.rs
A src/check.rs
M src/fem.rs
M src/main.rs
M src/parse.rs
M src/resolve.rs
M src/abase.rs => src/abase.rs +15 -15
@@ 101,11 101,12 @@ impl<'c> AbaseDef<'c> {
    }

    fn abase(&mut self, expr: &FExpr) -> Result<Expr> {
        match *expr {
            FExpr::F64(x) => Ok(Expr::Operand(Operand::Const(Const::F64(x)))),
            FExpr::Fun(ref ps, ref b) => Ok(Expr::Operand(Operand::Local(self.abase_fun(ps, b)?))),
            FExpr::App(ref f, ref xs) => match **f {
                FExpr::Var(ResName { res: Res::Prim(p) }) => match p {
        use crate::fem::ExprKind::*;
        match &expr.kind {
            F64(x) => Ok(Expr::Operand(Operand::Const(Const::F64(*x)))),
            Fun(ps, b) => Ok(Expr::Operand(Operand::Local(self.abase_fun(ps, b)?))),
            App(f, xs) => match &f.kind {
                Var(ResName { res: Res::Prim(p) }) => match p {
                    Prim::Add => {
                        let x = self.abase(&xs[0])?;
                        let x = self.let_anon_soft(x)?;


@@ 134,14 135,13 @@ impl<'c> AbaseDef<'c> {
                    Ok(Expr::Call(f, xs))
                }
            },
            FExpr::Var(ResName { res: Res::Def(id) }) => {
                let _ = self.cache.fetch_base(id)?;
                Ok(Expr::Operand(Operand::Global(id)))
            Var(ResName { res: Res::Def(id) }) => {
                let _ = self.cache.fetch_base(*id)?;
                Ok(Expr::Operand(Operand::Global(*id)))
            }
            FExpr::Var(ResName { res: Res::Local(id) }) =>
                Ok(Expr::Operand(Operand::Local(self.vars.get(id as usize).unwrap_or_else(|| todo!()).clone()))),
            FExpr::Var(ResName { res: Res::Prim(prim) }) =>
                Err(anyhow!("error: can't abase primitive `{prim}` in isolation")),
            Var(ResName { res: Res::Local(id) }) =>
                Ok(Expr::Operand(Operand::Local(*self.vars.get(*id as usize).unwrap_or_else(|| todo!())))),
            Var(ResName { res: Res::Prim(prim) }) => Err(anyhow!("error: can't abase primitive `{prim}` in isolation")),
        }
    }



@@ 149,9 149,9 @@ impl<'c> AbaseDef<'c> {
    fn abase_fun(&mut self, params: &[(PubIdent, LocalId)], body: &FExpr) -> Result<LocalId> {
        let fid = self.gen_local();

        let old_defs = std::mem::replace(&mut self.defs, vec![]);
        let old_vars = std::mem::replace(&mut self.vars, vec![]);
        let old_stms = std::mem::replace(&mut self.stms, vec![]);
        let old_defs = std::mem::take(&mut self.defs);
        let old_vars = std::mem::take(&mut self.vars);
        let old_stms = std::mem::take(&mut self.stms);

        let params = params
            .iter()

M src/cache.rs => src/cache.rs +12 -3
@@ 1,7 1,7 @@
use crate::name::*;
use crate::parse::{self, Expr as PExpr, Module as PModule};
use crate::resolve::{self, Expr as RExpr};
use crate::{abase, eval, fem, lex};
use crate::{abase, check, eval, fem, lex};
use anyhow::{anyhow, Context};
use std::borrow::Borrow;
use std::collections::hash_map::{Entry, HashMap};


@@ 17,6 17,7 @@ pub struct Cache {
    module_local_resolved_names: HashMap<FullName, HashMap<PubIdent, Res>>,
    resolveds: HashMap<DefId, RExpr>,
    // desugareds: HashMap<Query, fem::Expr>,
    checkeds: HashMap<DefId, fem::Expr>,
    abaseds: HashMap<DefId, abase::GlobDef>,
    evaluateds: HashMap<DefId, abase::Operand>,
}


@@ 31,6 32,7 @@ impl Cache {
            resolved_names_rev: HashMap::new(),
            module_local_resolved_names: HashMap::new(),
            resolveds: HashMap::new(),
            checkeds: HashMap::new(),
            abaseds: HashMap::new(),
            evaluateds: HashMap::new(),
        }


@@ 75,10 77,17 @@ impl Cache {
    }

    pub fn fetch_checked(&mut self, def_query: DefId) -> anyhow::Result<&fem::Expr> {
        self.fetch_desugared(def_query)
        if self.checkeds.contains_key(&def_query) {
            Ok(&self.checkeds[&def_query])
        } else {
            let desugared: RExpr = self.fetch_desugared(def_query)?.clone();
            let checked = check::check_def(self, &desugared)?;
            self.checkeds.insert(def_query, checked);
            Ok(&self.checkeds[&def_query])
        }
    }

    pub fn fetch_desugared(&mut self, def_query: DefId) -> anyhow::Result<&fem::Expr> {
    pub fn fetch_desugared(&mut self, def_query: DefId) -> anyhow::Result<&RExpr> {
        self.fetch_resolved(def_query)
        // Ok(if self.desugareds.contains_key(def_query) {
        //     &self.desugareds[def_query]

A src/check.rs => src/check.rs +95 -0
@@ 0,0 1,95 @@
use crate::cache::Cache;
use crate::fem::*;
use crate::name::*;
use crate::resolve::Expr as RExpr;
use anyhow::{anyhow, Result};
use std::collections::HashMap;

/// Type-checks the resolved body of one definition and returns its typed form.
///
/// `cache` is handed to the checker so that references to other definitions
/// can be looked up while `body` is being checked.
pub fn check_def(cache: &mut Cache, body: &RExpr) -> Result<Expr> {
    let checker = CheckDef::new(cache);
    checker.run(body)
}

/// State carried while type-checking a single definition body.
struct CheckDef<'c> {
    /// Cache used to fetch the checked form (and thereby the type) of other definitions.
    cache: &'c mut Cache,
    /// Stack of scopes mapping local ids to their types; lookups search the
    /// most recently pushed scope first.
    scopes: Vec<HashMap<LocalId, Type>>,
}

impl<'c> CheckDef<'c> {
    /// Creates a checker that uses `cache` for cross-definition lookups and
    /// starts with an empty stack of local scopes.
    fn new(cache: &'c mut Cache) -> Self {
        Self { cache, scopes: vec![] }
    }

    /// Type-checks a definition body by inferring its type, consuming the checker.
    fn run(mut self, body: &RExpr) -> Result<Expr> {
        self.infer(body)
    }

    /// Inference mode: works out the type of `expr` and returns the typed expression.
    ///
    /// # Errors
    /// Returns an error when the expression is a lambda with one or more
    /// parameters (its type cannot be inferred), when a non-function is
    /// applied, when an application's argument count differs from the
    /// function's parameter count, or when any sub-expression fails to check.
    ///
    /// # Panics
    /// Panics if a local variable id is not present in any scope; this is
    /// treated as an internal compiler error.
    fn infer(&mut self, expr: &RExpr) -> Result<Expr> {
        match expr {
            RExpr::F64(x) => Ok(Expr { typ: Type::F64, kind: ExprKind::F64(*x) }),
            // With no parameters there is nothing to guess: the function type
            // follows directly from the inferred body type.
            RExpr::Fun(ps, b) if ps.is_empty() => {
                let bi = self.infer(b)?;
                Ok(Expr { typ: Type::Fun(vec![], Box::new(bi.typ.clone())), kind: ExprKind::Fun(vec![], Box::new(bi)) })
            }
            RExpr::Fun(_, _) => Err(anyhow!("Can't infer type of lambdas with >0 parameters (yet). Please surround the lambda in a type annotation to have the type checked instead of inferred.")),
            RExpr::App(f, args) => {
                let fi = self.infer(f)?;
                if let Type::Fun(ps, r) = &fi.typ {
                    if ps.len() == args.len() {
                        // Arguments are checked against the parameter types of the callee.
                        let argsc = ps.iter().zip(args).map(|(p, a)| self.check(a, p)).collect::<Result<Vec<_>>>()?;
                        Ok(Expr { typ: (**r).clone(), kind: ExprKind::App(Box::new(fi), argsc) })
                    } else {
                        Err(anyhow!("Arity mismatch in function application. Function specifies {} parameters, but {} arguments were given.", ps.len(), args.len()))
                    }
                } else {
                    Err(anyhow!("Expected a function to apply, found a {:?}", fi.typ))
                }
            }
            // A global's type is the type of its checked definition, fetched via the cache.
            // NOTE(review): a definition that refers to itself presumably re-enters
            // `fetch_checked` before its own result is cached - confirm how recursion is handled.
            RExpr::Var(r @ ResName { res: Res::Def(id) }) => {
                let typ = self.cache.fetch_checked(*id)?.typ.clone();
                Ok(Expr { typ, kind: ExprKind::Var(r.clone()) })
            }
            RExpr::Var(r @ ResName { res: Res::Local(id) }) => match self.get_local(*id) {
                Some(t) => Ok(Expr { typ: t.clone(), kind: ExprKind::Var(r.clone()) }),
                None => panic!("ice: undefined local var of id {id}"),
            },
            // Both arithmetic primitives take two floats and return a float.
            RExpr::Var(r @ ResName { res: Res::Prim(Prim::Add | Prim::Mul) }) => Ok(Expr {
                typ: Type::Fun(vec![Type::F64, Type::F64], Box::new(Type::F64)),
                kind: ExprKind::Var(r.clone()),
            }),
            // An annotation switches from inference to checking against the annotated type.
            RExpr::Annot(e, t) => self.check(e, t),
        }
    }

    /// Checking mode: verifies that `expr` has type `expected` and returns the typed expression.
    ///
    /// Lambdas take their parameter types from `expected`; annotations must
    /// equal `expected`; every other expression is inferred and its type
    /// compared with `expected`.
    ///
    /// # Errors
    /// Returns an error on any type or arity mismatch, or when a
    /// sub-expression fails to check.
    fn check(&mut self, expr: &RExpr, expected: &Type) -> Result<Expr> {
        match expr {
            RExpr::Fun(ps, b) =>
                if let Type::Fun(tps, tr) = expected {
                    if ps.len() != tps.len() {
                        return Err(anyhow!("Arity mismatch in function parameters. Expectations from the outside are that there be {} parameters, but the function here instead has {}.", tps.len(), ps.len()));
                    }
                    self.scopes.push(ps.iter().zip(tps).map(|((_, id), t)| (*id, t.clone())).collect());
                    // Pop the parameter scope before propagating any error so a
                    // failed body check cannot leave a stale scope on the stack.
                    let bc = self.check(b, tr);
                    self.scopes.pop();
                    let bc = bc?;
                    Ok(Expr { typ: expected.clone(), kind: ExprKind::Fun(ps.clone(), Box::new(bc)) })
                } else {
                    Err(anyhow!("Expected {expected:?}, found a function"))
                },
            RExpr::Annot(e, t) =>
                if t == expected {
                    self.check(e, t)
                } else {
                    Err(anyhow!("Expected {expected:?}, found {t:?} (according to annotation)."))
                },
            // Fallback: infer the type, then require it to equal the expected one.
            _ => {
                let expri = self.infer(expr)?;
                if &expri.typ == expected {
                    Ok(expri)
                } else {
                    Err(anyhow!("Expected {expected:?}, found {:?}", expri.typ))
                }
            }
        }
    }

    /// Looks up the type of local `id`, searching scopes from the most
    /// recently pushed one outwards; `None` if no scope defines it.
    fn get_local(&self, id: LocalId) -> Option<&Type> {
        self.scopes.iter().rev().find_map(|locals| locals.get(&id))
    }
}

M src/fem.rs => src/fem.rs +16 -1
@@ 1,3 1,18 @@
//! Effem core IR

pub use crate::resolve::Expr;
use crate::name::{LocalId, PubIdent, ResName};
pub use crate::parse::Type;

/// An expression of the core IR: a node together with its type.
#[derive(Debug, Clone, PartialEq)]
pub struct Expr {
    /// The type of this expression.
    pub typ: Type,
    /// Which kind of expression this is, including its sub-expressions.
    pub kind: ExprKind,
}

/// The shape of a core IR expression.
#[derive(Debug, Clone, PartialEq)]
pub enum ExprKind {
    /// A floating-point literal.
    F64(f64),
    /// A lambda: its parameters (public identifier paired with local id) and its body.
    Fun(Vec<(PubIdent, LocalId)>, Box<Expr>),
    /// A function application: the callee followed by its arguments.
    App(Box<Expr>, Vec<Expr>),
    /// A reference to a resolved name.
    Var(ResName),
}

M src/main.rs => src/main.rs +16 -6
@@ 6,6 6,7 @@

mod abase;
mod cache;
mod check;
mod desugar;
mod diag;
mod eval;


@@ 57,18 58,18 @@ mod test {

    #[test]
    fn test_main_literal() {
        assert!(matches!(run_tmp("(def main 123.456)"), Ok(Operand::Const(Const::F64(x))) if x == 123.456))
        assert!(matches!(run_tmp("(def main 123.456)").unwrap(), Operand::Const(Const::F64(x)) if x == 123.456))
    }

    #[test]
    fn test_main_arithm1() {
        assert!(matches!(run_tmp("(def main (+.prim. 1.0 2.0))"), Ok(Operand::Const(Const::F64(x))) if x == 3.0))
        assert!(matches!(run_tmp("(def main (+.prim. 1.0 2.0))").unwrap(), Operand::Const(Const::F64(x)) if x == 3.0))
    }

    #[test]
    fn test_main_arithm2() {
        assert!(
            matches!(run_tmp("(def main (+.prim. (*.prim. 13.0 100.0) 37.0))"), Ok(Operand::Const(Const::F64(x))) if x == 1337.0)
            matches!(run_tmp("(def main (+.prim. (*.prim. 13.0 100.0) 37.0))").unwrap(), Operand::Const(Const::F64(x)) if x == 1337.0)
        )
    }



@@ 78,15 79,24 @@ mod test {
            (def x 1300.0)
            (def main (+ x 37.0))
        ";
        assert!(matches!(run_tmp(src), Ok(Operand::Const(Const::F64(x))) if x == 1337.0))
        assert!(matches!(run_tmp(src).unwrap(), Operand::Const(Const::F64(x)) if x == 1337.0))
    }

    #[test]
    fn test_app_lambda() {
        let src = "
            (def main (double 21.0))
            (def double (fun [x] (* x 2.0)))
            (def double (of (Fun [F64] F64)
                            (fun [x] (* x 2.0))))
        ";
        assert!(matches!(run_tmp(src), Ok(Operand::Const(Const::F64(x))) if x == 42.0))
        assert!(matches!(run_tmp(src).unwrap(), Operand::Const(Const::F64(x)) if x == 42.0))
    }

    /// Passing a lambda where `+` takes a number must make the run report an error.
    #[test]
    fn test_type_err_lambda_f64() {
        let src = "
            (def main (+ 1.0 (fun [x] (* x 2.0))))
        ";
        let outcome = run_tmp(src);
        assert!(outcome.is_err());
    }
}

M src/parse.rs => src/parse.rs +45 -1
@@ 28,6 28,13 @@ pub enum Expr {
    Fun(Vec<PrivIdent>, Box<Expr>),
    App(Box<Expr>, Vec<Expr>),
    Var(ParsedName),
    Annot(Box<Expr>, Type),
}

/// A type as written in a source-level annotation.
#[derive(Debug, Clone, PartialEq)]
pub enum Type {
    /// The type of floating-point values, written `F64`.
    F64,
    /// A function type: parameter types and return type, written `(Fun [params...] ret)`.
    Fun(Vec<Type>, Box<Type>),
}

pub fn parse<'s>(file: Option<Rc<Path>>, src: &'s str, tokens: &[Token<'s>]) -> anyhow::Result<Module> {


@@ 126,7 133,7 @@ fn expr<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Expr> {
}

fn pexpr<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Expr> {
    alt2(
    alt3(
        |inp| {
            let (inp, _) = special_form("fun")(inp)?;
            let (inp, params) = brackets(rest(ident))(inp)?;


@@ 134,11 141,32 @@ fn pexpr<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Expr> {
            let (inp, ()) = end(inp)?;
            Ok((inp, Expr::Fun(params, Box::new(body))))
        },
        |inp| {
            let (inp, _) = special_form("of")(inp)?;
            let (inp, (t, e)) = pair(typ, expr)(inp)?;
            Ok((inp, Expr::Annot(Box::new(e), t)))
        },
        map(pair(expr, rest(expr)), |(f, xs)| Expr::App(Box::new(f), xs)),
    )(inp)
}

/// Parses a type: either the bare identifier `F64` or a parenthesised type form.
fn typ<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
    let float_type = give(literally("F64"), Type::F64);
    let compound_type = parens(ptyp);
    alt2(float_type, compound_type)(inp)
}

/// Parses the inside of a parenthesised type: `Fun`, a bracketed list of
/// parameter types, a return type, and then the end of the form.
fn ptyp<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
    let (after_head, _) = special_form("Fun")(inp)?;
    let (after_params, param_types) = brackets(rest(typ))(after_head)?;
    let (after_ret, ret_type) = typ(after_params)?;
    let (remaining, ()) = end(after_ret)?;
    let fun_type = Type::Fun(param_types, Box::new(ret_type));
    Ok((remaining, fun_type))
}

/// Parser for the identifier that heads a special form (e.g. `fun`, `of`, `Fun`).
/// Delegates directly to `literally`.
fn special_form<'s: 't, 't>(id: &'static str) -> impl Parser<'s, 't, PrivIdent> {
    literally(id)
}

fn literally<'s: 't, 't>(id: &'static str) -> impl Parser<'s, 't, PrivIdent> {
    move |inp| {
        inp.filter_next(|t| match t.tok {
            Tok::Ident(x) if x == id => Ok(PrivIdent { start: t.offset, len: x.len() as u32 }),


@@ 330,6 358,10 @@ fn map<'s: 't, 't, A, B>(mut f: impl Parser<'s, 't, A>, mut g: impl FnMut(A) -> 
    move |inp| f(inp).map(|(inp, x)| (inp, g(x)))
}

/// Runs parser `f`, discards whatever it produced, and yields a clone of `x` instead.
/// A failure of `f` is passed through unchanged.
fn give<'s: 't, 't, A, B: Clone>(mut f: impl Parser<'s, 't, A>, x: B) -> impl Parser<'s, 't, B> {
    move |inp| {
        let (remaining, _discarded) = f(inp)?;
        Ok((remaining, x.clone()))
    }
}

#[cfg(test)]
mod test {
    use super::*;


@@ 359,4 391,16 @@ mod test {
                        && matches!(&xs[..], [Expr::Var(x1), Expr::Var(x2)] if x1.segments[0].start == 23 && x2.segments[0].start == 25))
        )
    }

    /// An `(of F64 123.0)` form must parse to an annotation node that wraps
    /// the literal and carries the `F64` type.
    #[test]
    fn test_expr_annotation1() {
        let src = "(def foo (of F64 123.0))";
        let tokens = lex(None, src).unwrap();
        let module = parse(None, src, &tokens).unwrap();
        assert_matches!(
            &module.defs["foo"],
            Expr::Annot(inner, annot)
            if matches!(annot, Type::F64)
            && matches!(&**inner, Expr::F64(x) if *x == 123.0)
        )
    }
}

M src/resolve.rs => src/resolve.rs +8 -5
@@ 1,3 1,4 @@
pub use crate::parse::Type;
use crate::{cache::*, name::*, parse};
use anyhow::{anyhow, Result};
use std::borrow::Cow;


@@ 15,6 16,7 @@ pub enum Expr {
    Fun(Vec<(PubIdent, LocalId)>, Box<Expr>),
    App(Box<Expr>, Vec<Expr>),
    Var(ResName),
    Annot(Box<Expr>, Type),
}

pub fn resolve_def(cache: &mut Cache, def_name: &FullName, def_body: &parse::Expr) -> Result<Expr> {


@@ 40,9 42,9 @@ impl<'c> Resolver<'c> {

    fn resolve(&mut self, e: &parse::Expr) -> Result<Expr> {
        use parse::Expr::*;
        Ok(match *e {
            F64(x) => Expr::F64(x),
            Fun(ref params, ref body) => {
        Ok(match e {
            F64(x) => Expr::F64(*x),
            Fun(params, body) => {
                let rparams =
                    params.iter().map(|p| Ok((self.pub_ident(p)?, self.gen_local_id()))).collect::<Result<Vec<_>>>()?;
                self.scopes.push(rparams.iter().map(|(p, id)| (p.clone(), Res::Local(*id))).collect());


@@ 50,9 52,9 @@ impl<'c> Resolver<'c> {
                self.scopes.pop();
                Expr::Fun(rparams, Box::new(rbody))
            }
            App(ref f, ref args) =>
            App(f, args) =>
                Expr::App(Box::new(self.resolve(f)?), args.iter().map(|a| self.resolve(a)).collect::<Result<_>>()?),
            Var(ref v) => {
            Var(v) => {
                // To begin with, our lang is very primitive, but there is still some symbol scoping to consider.
                // If main.fm defines foo and foo', and misc.fm defines bar, then
                // - `foo` can refer to `foo'` as just `foo'`, but must refer to `bar` by the full `bar.misc.pkg.`


@@ 60,6 62,7 @@ impl<'c> Resolver<'c> {
                // So `resolve` must be aware of the module it's in from the start
                Expr::Var(ResName { res: self.resolve_parsed_name(v)? })
            }
            Annot(e, t) => Expr::Annot(Box::new(self.resolve(e)?), t.clone()),
        })
    }