~jojo/kapreolo

b48ae150a9211fc84ec827b2503cd9228dc2848f — JoJo 8 months ago cb5bf2e
nullary type synonyms
8 files changed, 201 insertions(+), 65 deletions(-)

M src/abase.rs
M src/cache.rs
M src/check.rs
M src/kapo.rs
M src/main.rs
M src/name.rs
M src/parse.rs
M src/resolve.rs
M src/abase.rs => src/abase.rs +2 -2
@@ 220,8 220,8 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
                Pass::Indirect => converge.indirect(self.vars[&vref].clone(), self),
            },
            Var(Res::Prim(_)) => todo!(), // TODO: generate a closure around the op or smth
            Var(Res::Module(_)) => {
                panic!("ice: found module id in expr context when abasing. Should've been caught by type checker.")
            Var(res @ (Res::Module(_) | Res::Syn(_))) => {
                panic!("ice: found res {res:?} in expr context when abasing. Should've been handled earlier.")
            }
            If(pred, conseq, alt) => {
                let pred_a = self.abase_reg_expr(pred)?;

M src/cache.rs => src/cache.rs +88 -20
@@ 11,15 11,18 @@ use std::rc::Rc;
pub struct Cache {
    root_dir: PathBuf,
    n_defs: u32,
    n_syns: u32,
    n_modules: u32,
    sources: HashMap<FileId, String>,
    file_paths: HashMap<FileId, PathBuf>,
    module_files: HashMap<ModuleId, FileId>,
    parsed_modules: HashMap<ModuleId, HashMap<String, (Loc, PExpr)>>,
    parsed_modules: HashMap<ModuleId, ParsedModule>,
    parent_modules: HashMap<DefId, ModuleId>,
    syn_parent_modules: HashMap<SynId, ModuleId>,
    resolved_names: HashMap<FullName<String>, Res>,
    resolved_names_rev: HashMap<Res, FullName<String>>,
    module_local_resolved_names: HashMap<ModuleId, HashMap<String, Res>>,
    syns: HashMap<SynId, kapo::Type>,
    resolveds: HashMap<DefId, resolve::Def>,
    sigs: HashMap<DefId, Fetch<kapo::Type, Option<kapo::Type>>>,
    checkeds: HashMap<DefId, Fetch<check::Def, ()>>,


@@ 37,6 40,7 @@ impl Cache {
        Self {
            root_dir: root_dir.to_owned(),
            n_defs: 0,
            n_syns: 0,
            n_modules: 0,
            file_paths: HashMap::new(),
            sources: HashMap::new(),


@@ 45,7 49,9 @@ impl Cache {
            resolved_names: HashMap::new(),
            resolved_names_rev: HashMap::new(),
            parent_modules: HashMap::new(),
            syn_parent_modules: HashMap::new(),
            module_local_resolved_names: HashMap::new(),
            syns: HashMap::new(),
            resolveds: HashMap::new(),
            sigs: HashMap::new(),
            checkeds: HashMap::new(),


@@ 292,18 298,31 @@ impl Cache {
        if self.resolveds.contains_key(&def_id) {
            Ok(&self.resolveds[&def_id])
        } else {
            let module_id =
                *self.parent_modules.get(&def_id).expect("ice: if there's a def id, there should be a parent module");
            let module_id = *self.parent_modules.get(&def_id).expect("ice: a def should have a parent module");
            let def_ident =
                self.get_def_name(def_id).expect("ice: if we have a def id, there should be a reverse").last().clone();
            let module = self.fetch_parsed_module(module_id)?;
            let parsed = module[&def_ident].clone();
            let parsed = module.defs[&def_ident].clone();
            let resolved = resolve::resolve_def(self, module_id, &parsed.1)?;
            self.resolveds.insert(def_id, resolved);
            Ok(&self.resolveds[&def_id])
        }
    }

    /// Returns the type that the synonym `syn_id` stands for.
    ///
    /// The result is memoised in `self.syns`: on the first request the synonym's
    /// parsed body is looked up in its parent module, resolved, and stored; later
    /// requests return the stored type directly. Only the body of the `defsyn`
    /// is resolved here; its parameter list is not consulted.
    ///
    /// # Errors
    /// Propagates any error from fetching the parsed module or resolving the body.
    ///
    /// # Panics
    /// Panics (ice) if `syn_id` has no recorded parent module or no reverse name.
    pub fn fetch_synonymous_type(&mut self, syn_id: SynId) -> Result<&kapo::Type> {
        if !self.syns.contains_key(&syn_id) {
            let module_id = *self.syn_parent_modules.get(&syn_id).expect("ice: a syn ID should have a parent module");
            let syn_ident = self.get_syn_name(syn_id).expect("ice: a syn ID should have a reverse").last().clone();
            // Copy the body out so the borrow of the parsed module ends before resolving.
            let body = self.fetch_parsed_module(module_id)?.defsyns[&syn_ident].2.clone();
            let resolved = resolve::resolve_defsyn(self, module_id, &body)?;
            self.syns.insert(syn_id, resolved);
        }
        Ok(&self.syns[&syn_id])
    }

    pub fn fetch_module_local_resolved_names(&mut self, module_id: ModuleId) -> Result<&HashMap<String, Res>> {
        static PRELUDE: &[(&str, Res)] = &[
            ("+", Res::Prim(Prim::Add)),


@@ 323,16 342,20 @@ impl Cache {
        if self.module_local_resolved_names.contains_key(&module_id) {
            Ok(&self.module_local_resolved_names[&module_id])
        } else {
            let def_idents: Vec<String> = self.fetch_parsed_module(module_id)?.keys().cloned().collect();
            let ids = PRELUDE
                .iter()
                .map(|(ident, res)| (ident.to_string(), *res))
                .chain(
                    def_idents
                        .into_iter()
                        .map(|def_ident| (def_ident.clone(), Res::Def(self.gen_def_resolution(def_ident, module_id)))),
                )
                .collect();
            let module = self.fetch_parsed_module(module_id)?;
            let def_idents: Vec<String> = module.defs.keys().cloned().collect();
            let defsyn_idents: Vec<String> = module.defsyns.keys().cloned().collect();
            let mut ids: HashMap<String, Res> = PRELUDE.iter().map(|(ident, res)| (ident.to_string(), *res)).collect();
            ids.extend(
                def_idents
                    .into_iter()
                    .map(|def_ident| (def_ident.clone(), Res::Def(self.gen_def_resolution(def_ident, module_id)))),
            );
            ids.extend(
                defsyn_idents
                    .into_iter()
                    .map(|ident| (ident.clone(), Res::Syn(self.gen_defsyn_resolution(ident, module_id)))),
            );
            self.module_local_resolved_names.insert(module_id, ids);
            Ok(&self.module_local_resolved_names[&module_id])
        }


@@ 370,7 393,10 @@ impl Cache {
            match self.fetch_global_resolved_name(&module_name)? {
                Res::Module(module_id) => {
                    let module = self.fetch_parsed_module(module_id)?;
                    if module.contains_key(&def_ident.s) {
                    if module.defs.contains_key(&def_ident.s) {
                        return Ok(Res::Def(self.gen_def_resolution(def_ident.s, module_id)));
                    }
                    // A name declared with `defsyn` is a type synonym, not a definition:
                    // it must resolve to `Res::Syn` through the synonym ID generator.
                    // (Routing it through `gen_def_resolution` would trip that function's
                    // `defs.contains_key` debug assertion and register the name as a def.)
                    if module.defsyns.contains_key(&def_ident.s) {
                        return Ok(Res::Syn(self.gen_defsyn_resolution(def_ident.s, module_id)));
                    }
                }


@@ 412,7 438,7 @@ impl Cache {
    }

    fn gen_def_resolution(&mut self, def_ident: String, module_id: ModuleId) -> DefId {
        debug_assert!(self.fetch_parsed_module(module_id).unwrap().contains_key(&def_ident));
        debug_assert!(self.fetch_parsed_module(module_id).unwrap().defs.contains_key(&def_ident));
        let def_name = self.get_module_name(module_id).clone().with(def_ident);
        match self.resolved_names.get(&def_name) {
            Some(Res::Def(def_id)) => *def_id,


@@ 430,23 456,53 @@ impl Cache {
        }
    }

    /// Returns the synonym ID for `syn_ident` in module `module_id`, allocating
    /// and registering a new one if the full name has not been resolved before.
    ///
    /// On first allocation the forward map, the reverse map, and the
    /// synonym-to-parent-module map are all updated.
    ///
    /// # Panics
    /// Panics (ice) if the full name already resolves to something other than a synonym.
    fn gen_defsyn_resolution(&mut self, syn_ident: String, module_id: ModuleId) -> SynId {
        debug_assert!(self.fetch_parsed_module(module_id).unwrap().defsyns.contains_key(&syn_ident));
        let full_name = self.get_module_name(module_id).clone().with(syn_ident);

        // Already known: either reuse the synonym ID or report the clash.
        if let Some(existing) = self.resolved_names.get(&full_name) {
            return match existing {
                Res::Syn(id) => *id,
                other => panic!(
                    "ice: fetching/generating synonym resolution for name {}, but that name already resolves to {:?}",
                    full_name, other
                ),
            };
        }

        let fresh_id = self.gen_syn_id();
        let res = Res::Syn(fresh_id);
        self.resolved_names.insert(full_name.clone(), res);
        self.resolved_names_rev.insert(res, full_name);
        self.syn_parent_modules.insert(fresh_id, module_id);
        fresh_id
    }

    /// Hands out the next unused definition ID and advances the `n_defs` counter.
    ///
    /// # Panics
    /// Panics if the counter has already reached `u32::MAX`.
    fn gen_def_id(&mut self) -> DefId {
        assert!(self.n_defs < u32::MAX);
        let fresh = self.n_defs;
        self.n_defs = fresh + 1;
        fresh
    }

    fn gen_syn_id(&mut self) -> DefId {
        assert!(self.n_syns < u32::MAX);
        self.n_syns += 1;
        self.n_syns - 1
    }

    /// Looks up the full name registered for definition `def_id` in the reverse
    /// name map; `None` if no name has been registered for it.
    pub fn get_def_name(&self, def_id: DefId) -> Option<&FullName<String>> {
        let key = Res::Def(def_id);
        self.resolved_names_rev.get(&key)
    }

    /// Looks up the full name registered for type synonym `syn_id` in the reverse
    /// name map; `None` if no name has been registered for it.
    pub fn get_syn_name(&self, syn_id: SynId) -> Option<&FullName<String>> {
        let key = Res::Syn(syn_id);
        self.resolved_names_rev.get(&key)
    }

    /// Returns the full name registered for module `module_id`.
    ///
    /// # Panics
    /// Panics (ice) if the reverse name map has no entry for the module.
    pub fn get_module_name(&self, module_id: ModuleId) -> &FullName<String> {
        let key = Res::Module(module_id);
        self.resolved_names_rev.get(&key).expect("ice: if we have a module id, there should be a reverse")
    }

    pub fn fetch_parsed_module(&mut self, module_id: ModuleId) -> Result<&HashMap<String, (Loc, PExpr)>> {
    fn fetch_parsed_module(&mut self, module_id: ModuleId) -> Result<&ParsedModule> {
        if !self.parsed_modules.contains_key(&module_id) {
            // TODO: nested modules in the same file?
            let file_id = *self


@@ 456,14 512,20 @@ impl Cache {
            let src = self.fetch_source(file_id)?;
            let tokens = lex::lex(Some(file_id), src)?;
            let parsed = parse::parse(Some(file_id), &tokens)?;
            let mut dedup: HashMap<String, (Loc, PExpr)> = HashMap::new();
            let mut module = ParsedModule { defs: HashMap::new(), defsyns: HashMap::new() };
            for (lhs, rhs) in parsed.defs {
                let (lhs, loc) = (lhs.substr_in(src), Loc::new(Some(file_id), lhs.start));
                if dedup.insert(lhs.to_string(), (loc, rhs)).is_some() {
                if module.defs.insert(lhs.to_string(), (loc, rhs)).is_some() {
                    return Err(Error::Resolve(ResolveErr::DupDef(loc)));
                }
            }
            self.parsed_modules.insert(module_id, dedup);
            for (name, params, body) in parsed.defsyns {
                let (name, loc) = (name.substr_in(src), Loc::new(Some(file_id), name.start));
                if module.defsyns.insert(name.to_string(), (loc, params, body)).is_some() {
                    return Err(Error::Resolve(ResolveErr::DupDef(loc)));
                }
            }
            self.parsed_modules.insert(module_id, module);
        }
        Ok(&self.parsed_modules[&module_id])
    }


@@ 490,6 552,12 @@ impl Cache {
    }
}

/// The top-level items of one parsed module, indexed by item name.
#[derive(Debug)]
struct ParsedModule {
    /// `def` items: name -> (location of the name, body expression).
    defs: HashMap<String, (Loc, PExpr)>,
    /// `defsyn` items: name -> (location of the name, parameter identifiers, synonym body type).
    defsyns: HashMap<String, (Loc, Vec<parse::IdentSpan>, parse::Type)>,
}

#[derive(Debug, Clone)]
pub enum Fetch<T, E> {
    Done(T),

M src/check.rs => src/check.rs +1 -0
@@ 276,6 276,7 @@ impl<'c, 'r> Checker<'c, 'r> {
            }
            Expr::Var(Res::Prim(Prim::Cast)) => Type::Fun(vec![Type::ISize], Box::new(Type::ISize)),
            Expr::Var(Res::Module(id)) => return err(ModuleIsNotExpr(self.expr_locs[eref], id)),
            Expr::Var(Res::Syn(id)) => panic!("ice: Res::Syn ({id}) in check"),
            Expr::If(p, c, a) => {
                self.check(p, &Type::Bool)?;
                let t_c = self.infer(c)?.clone();

M src/kapo.rs => src/kapo.rs +1 -0
@@ 165,6 165,7 @@ pub enum Res {
    Var(VarRef),
    Prim(Prim),
    Module(ModuleId),
    Syn(SynId),
}

impl Res {

M src/main.rs => src/main.rs +14 -4
@@ 653,16 653,26 @@ mod test {
        assert_matches!(run_tmp(src), Ok(-128))
    }

    // A parameterless synonym (`MyInt` for `Int`) used in a function signature:
    // `inc` is annotated with `MyInt` and applied to 110; the run is expected to yield 111.
    #[test]
    fn test_type_synonym1() {
        let src = "
            (defsyn MyInt [] Int)
            (def inc (of (Fun [MyInt] MyInt) (fun [x] (+ 1 x))))
            (def main (of Int (inc 110)))
        ";
        assert_matches!(run_tmp(src), Ok(111))
    }

    // #[test]
    // fn test_type_synonym() {
    // fn test_type_synonym2() {
    //     let src = "
    //         (defsyn Byte [] N8)
    //         (defsyn Pair [a b] [a b])
    //         (defsyn Be16 [] (Pair Byte Byte))
    //         (def read-n16-be (of (Fun [Be16] N16) (fun [x] (match x [[hi lo] (+ (* (as N16 lo) 8) (as N16 hi))]))))
    //         (def main (read-n16-be [2 16]))
    //         (def read-n16-be (of (Fun [Be16] N16) (fun [x] (match x [[hi lo] (+ (* (of N16 (cast lo)) 8) (of N16 (cast hi)))]))))
    //         (def main (as Int (read-n16-be [2 16])))
    //     ";
    //     assert_matches!(run_tmp(src), Ok(Operand::Const(Const::N16(528))))
    //     assert_matches!(run_tmp(src), Ok(528))
    // }

    // #[test]

M src/name.rs => src/name.rs +1 -0
@@ 5,6 5,7 @@ use std::hash::Hash;
use std::path::Path;

pub type DefId = u32;
pub type SynId = u32;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(transparent)]

M src/parse.rs => src/parse.rs +62 -22
@@ 7,6 7,7 @@ use std::path::Path;
/// The parse result for one module: its top-level items, grouped by kind.
#[derive(Debug)]
pub struct Module {
    /// `def` items as (name span, body expression).
    pub defs: Vec<(IdentSpan, Expr)>,
    /// `defsyn` items as (name span, parameter identifier spans, body type).
    pub defsyns: Vec<(IdentSpan, Vec<IdentSpan>, Type)>,
}

// TODO: Parse directly to indexed flat tree.


@@ 108,6 109,8 @@ pub enum Type {
    F64,
    Fun(Vec<Type>, Box<Type>),
    PolyFun(Vec<IdentSpan>, Vec<Type>, Box<Type>),
    Const(Name),
    App(Box<Type>, Vec<Type>),
    TVar(IdentSpan),
    Hole(Loc),
    Tuple(Vec<Type>),


@@ 116,9 119,16 @@ pub enum Type {
pub fn parse(file: Option<FileId>, tokens: &[Token]) -> std::result::Result<Module, ParseErr> {
    let inp = Inp { tokens, dist: 0, context: file.map(|file| Loc::FileGeneral { file }).unwrap_or(Loc::AnonGeneral) };
    let go = move || {
        let (inp, defs) = many0(def)(inp)?;
        let (inp, toplevels) = many0(item)(inp)?;
        end(inp)?;
        Ok(Module { defs })
        let (mut defs, mut defsyns) = (vec![], vec![]);
        for item in toplevels {
            match item {
                Item::Def(x) => defs.push(x),
                Item::Defsyn(x) => defsyns.push(x),
            }
        }
        Ok(Module { defs, defsyns })
    };
    go()
}


@@ 183,14 193,29 @@ impl<'s, 't> Inp<'s, 't> {
    }
}

fn def<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, (IdentSpan, Expr)> {
    parens(|inp| {
        let (inp, _) = special_form("def")(inp)?;
        let (inp, lhs) = ident(inp)?;
        let (inp, rhs) = expr(inp)?;
        let (inp, ()) = end(inp)?;
        Ok((inp, (lhs, rhs)))
    })(inp)
/// One parsed top-level form of a module.
enum Item {
    /// A `def` form: (name span, body expression).
    Def((IdentSpan, Expr)),
    /// A `defsyn` form: (name span, parameter identifier spans, body type).
    Defsyn((IdentSpan, Vec<IdentSpan>, Type)),
}

/// Parses one parenthesised top-level form: either `(def NAME EXPR)` or
/// `(defsyn NAME [PARAMS...] TYPE)`. The `def` alternative is tried first.
fn item<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Item> {
    /// Contents of a `def` form: the keyword, a name, and a body expression.
    fn def_form<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Item> {
        let (inp, _) = special_form("def")(inp)?;
        let (inp, name) = ident(inp)?;
        let (inp, body) = expr(inp)?;
        let (inp, ()) = end(inp)?;
        Ok((inp, Item::Def((name, body))))
    }

    /// Contents of a `defsyn` form: the keyword, a name, bracketed parameters, and a type.
    fn defsyn_form<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Item> {
        let (inp, _) = special_form("defsyn")(inp)?;
        let (inp, name) = ident(inp)?;
        let (inp, params) = brackets(rest(ident))(inp)?;
        let (inp, body) = typ(inp)?;
        let (inp, ()) = end(inp)?;
        Ok((inp, Item::Defsyn((name, params, body))))
    }

    parens(alt2(def_form, defsyn_form))(inp)
}

fn expr<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Expr> {


@@ 263,25 288,31 @@ fn typ<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
        ),
        alt2(give(literally("Int"), Type::ISize), give(literally("Nat"), Type::NSize)),
        alt2(give(literally("F64"), Type::F64), give(literally("Bool"), Type::Bool)),
        alt2(
        alt3(
            |inp @ Inp { context, .. }| map(literally("_"), |s| Type::Hole(context.with_offset(s.start)))(inp),
            map(ident, Type::TVar),
            map(small_ident, Type::TVar),
            map(name, Type::Const),
        ),
        alt2(map(brackets(rest(typ)), Type::Tuple), parens(ptyp)),
    )(inp)
}

fn ptyp<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
    let (inp, _) = special_form("Fun")(inp)?;
    let (inp, tvars) = opt(preceded(keyword("for"), brackets(rest(ident))))(inp)?;
    let (inp, params) = brackets(rest(typ))(inp)?;
    let (inp, ret) = typ(inp)?;
    let (inp, ()) = end(inp)?;
    let t = match tvars {
        Some(tvars) => Type::PolyFun(tvars, params, Box::new(ret)),
        None => Type::Fun(params, Box::new(ret)),
    };
    Ok((inp, t))
    alt2(
        |inp| {
            let (inp, _) = special_form("Fun")(inp)?;
            let (inp, tvars) = opt(preceded(keyword("for"), brackets(rest(ident))))(inp)?;
            let (inp, params) = brackets(rest(typ))(inp)?;
            let (inp, ret) = typ(inp)?;
            let (inp, ()) = end(inp)?;
            let t = match tvars {
                Some(tvars) => Type::PolyFun(tvars, params, Box::new(ret)),
                None => Type::Fun(params, Box::new(ret)),
            };
            Ok((inp, t))
        },
        map(pair(typ, rest(typ)), |(f, xs)| Type::App(Box::new(f), xs)),
    )(inp)
}

fn special_form<'s: 't, 't>(id: &'static str) -> impl Parser<'s, 't, IdentSpan> {


@@ 338,6 369,15 @@ fn ident<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, IdentSpan> {
    })
}

/// Parses an identifier token whose first character is lowercase and returns
/// its span; any other token is rejected with the expectation
/// `*small identifier*`.
fn small_ident<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, IdentSpan> {
    const EXPECTED: &str = "*small identifier*";
    inp.filter_next(|t| match t.tok {
        Tok::Ident(x) => {
            if x.starts_with(char::is_lowercase) {
                Ok(IdentSpan { start: t.offset, len: x.len() as u32 })
            } else {
                Err(EXPECTED)
            }
        }
        _ => Err(EXPECTED),
    })
}

fn lit_float<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, f64> {
    inp.filter_next(|t| match t.tok {
        Tok::Float(x) => Ok(x),

M src/resolve.rs => src/resolve.rs +32 -17
@@ 25,8 25,16 @@ pub enum TVar {
    Hole(Loc),
}

pub fn resolve_def(cache: &mut Cache, parent_module: ModuleId, def_body: &parse::Expr) -> Result<Def> {
    Resolver::resolve_def(cache, parent_module, def_body)
pub fn resolve_def(cache: &mut Cache, module: ModuleId, def_body: &parse::Expr) -> Result<Def> {
    let mut resolver = Resolver::new(cache, module);
    let root = resolver.resolve(def_body)?;
    let Resolver { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, .. } = resolver;
    Ok(Def { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, root })
}

/// Resolves the parsed body type of a type synonym declared in `module`,
/// using a fresh resolver over `cache`.
///
/// # Errors
/// Returns whatever error type resolution reports.
pub fn resolve_defsyn(cache: &mut Cache, module: ModuleId, syn_typ: &parse::Type) -> Result<Type> {
    Resolver::new(cache, module).resolve_type(syn_typ)
}

struct Resolver<'c> {


@@ 46,14 54,13 @@ struct Resolver<'c> {
}

impl<'c> Resolver<'c> {
    fn resolve_def(cache: &'c mut Cache, module: ModuleId, body: &parse::Expr) -> Result<Def> {
        let module_top_level = cache.fetch_module_local_resolved_names(module)?.clone();
    fn new(cache: &'c mut Cache, module: ModuleId) -> Self {
        let file = cache.get_module_file(module);
        let mut resolver = Self {
        Resolver {
            cache,
            file,
            module,
            scopes: vec![module_top_level],
            scopes: vec![],
            exprs: Arena::new(),
            expr_locs: OutOfBand::new(),
            expr_annots: SparseOutOfBand::new(),


@@ 63,10 70,7 @@ impl<'c> Resolver<'c> {
            tvars: Arena::new(),
            pats: Arena::new(),
            pat_locs: OutOfBand::new(),
        };
        let root = resolver.resolve(body)?;
        let Self { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, .. } = resolver;
        Ok(Def { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, root })
        }
    }

    fn resolve(&mut self, expr: &parse::Expr) -> Result<ExprRef> {


@@ 121,18 125,24 @@ impl<'c> Resolver<'c> {
            self.cache.fetch_global_resolved_name(&name)
        } else {
            match name.split_first() {
                (first, None) => self
                    .scopes
                    .iter()
                    .rev()
                    .find_map(|scope| scope.get(first.as_str()))
                    .cloned()
                    .ok_or_else(|| Error::Resolve(UndefInMod(first.clone(), self.module))),
                (first, None) => self.resolve_ident(&first),
                (_first, _rest) => todo!(),
            }
        }
    }

    fn resolve_ident(&mut self, ident: &PubIdent) -> Result<Res> {
        if let Some(res) = self.scopes.iter().rev().find_map(|scope| scope.get(ident.as_str())) {
            Ok(*res)
        } else {
            self.cache
                .fetch_module_local_resolved_names(self.module)?
                .get(ident.as_str())
                .cloned()
                .ok_or_else(|| Error::Resolve(UndefInMod(ident.clone(), self.module)))
        }
    }

    fn resolve_type(&mut self, tpar: &parse::Type) -> Result<Type> {
        match tpar {
            parse::Type::Bool => Ok(Type::Bool),


@@ 161,6 171,11 @@ impl<'c> Resolver<'c> {
                self.tvar_scopes.pop();
                Ok(Type::PolyFun(tvrefs, ps_res, r_res))
            }
            parse::Type::Const(cname) => match self.resolve_parsed_name(cname)? {
                Res::Syn(id) => self.cache.fetch_synonymous_type(id).cloned(),
                _ => todo!(),
            },
            parse::Type::App(..) => todo!(),
            &parse::Type::TVar(span) => {
                let s = self.cache.fetch_source(self.file).map(|src| span.substr_in(src))?;
                self.tvar_scopes