~jojo/kapreolo

c74b5fd287acfb0fac24e2e57162d0ffb4fef9a5 — JoJo 8 months ago b48ae15 main
parametric type synonyms
7 files changed, 119 insertions(+), 70 deletions(-)

M src/abase.rs
M src/cache.rs
M src/check.rs
M src/diag.rs
M src/main.rs
M src/parse.rs
M src/resolve.rs
M src/abase.rs => src/abase.rs +24 -24
@@ 563,17 563,17 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
    }

    fn size_of(&mut self, t: &kapo::Type) -> Operand {
        use kapo::Type::*;
        use kapo::Type;
        match t {
            Bool => Const::NSize(1).into(),
            &Int { width, .. } => Const::NSize(width as usize / 8).into(),
            ISize | NSize => Const::SizeofPtr.into(),
            F64 => Const::NSize(8).into(),
            Type::Bool => Const::NSize(1).into(),
            &Type::Int { width, .. } => Const::NSize(width as usize / 8).into(),
            Type::ISize | Type::NSize => Const::SizeofPtr.into(),
            Type::F64 => Const::NSize(8).into(),
            // TODO: Change this once `Fun` represents a closure
            Fun(..) => Const::SizeofPtr.into(),
            PolyFun(..) => todo!(),
            Tuple(us) => self.size_of_tuple(us.iter()).0,
            TVar(tv) => {
            Type::Fun(..) => Const::SizeofPtr.into(),
            Type::PolyFun(..) => todo!(),
            Type::Tuple(us) => self.size_of_tuple(us.iter()).0,
            Type::TVar(tv) => {
                let i = self.tvars.binary_search(tv).unwrap();
                self.abase_vwt_field_size(self.vwt_param(i))
            }


@@ 581,14 581,14 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
    }

    fn align_of(&mut self, t: &kapo::Type) -> Operand {
        use kapo::Type::*;
        use kapo::Type;
        match t {
            Tuple(_us) => todo!(), // us.iter().map(|u| self.align_of(u)).fold(Align::byte(), |acc, a| acc.max(&a)),
            TVar(tv) => {
            Type::Tuple(_us) => todo!(), // us.iter().map(|u| self.align_of(u)).fold(Align::byte(), |acc, a| acc.max(&a)),
            Type::TVar(tv) => {
                let i = self.tvars.binary_search(tv).unwrap();
                self.abase_vwt_field_align(self.vwt_param(i))
            }
            Fun(..) | PolyFun(..) => Const::SizeofPtr.into(),
            Type::Fun(..) | Type::PolyFun(..) => Const::SizeofPtr.into(),
            _ => self.size_of(t),
        }
    }


@@ 678,12 678,12 @@ fn abase_type(t: &kapo::Type) -> Pass {
}

fn abase_int(x: i128, typ: &kapo::Type) -> Const {
    use kapo::Type::*;
    use kapo::Type;
    match *typ {
        Int { width, signed: true } => Const::Int { width, val: x as i64 },
        Int { width, signed: false } => Const::Nat { width, val: x as u64 },
        ISize => Const::ISize(x as isize),
        NSize => Const::NSize(x as usize),
        Type::Int { width, signed: true } => Const::Int { width, val: x as i64 },
        Type::Int { width, signed: false } => Const::Nat { width, val: x as u64 },
        Type::ISize => Const::ISize(x as isize),
        Type::NSize => Const::NSize(x as usize),
        ref t => panic!("ice: integer expr doesn't have signed integer type. found {}", t.describe()),
    }
}


@@ 996,13 996,13 @@ impl Converge for ConvVarMerge {
}

fn abase_tag(tag: u64, typ: &kapo::Type) -> Option<Const> {
    use kapo::Type::*;
    use kapo::Type;
    match typ {
        Bool => Some(Const::Bool(tag != 0)),
        Int { signed: true, .. } | ISize => Some(abase_int(tag as i64 as i128, typ)),
        Int { .. } | NSize => Some(abase_int(tag as i128, typ)),
        Tuple(_) => None,
        F64 | TVar(_) | Fun(_, _) | PolyFun(..) => todo!(),
        Type::Bool => Some(Const::Bool(tag != 0)),
        Type::Int { signed: true, .. } | Type::ISize => Some(abase_int(tag as i64 as i128, typ)),
        Type::Int { .. } | Type::NSize => Some(abase_int(tag as i128, typ)),
        Type::Tuple(_) => None,
        Type::F64 | Type::TVar(_) | Type::Fun(_, _) | Type::PolyFun(..) => todo!(),
    }
}


M src/cache.rs => src/cache.rs +7 -5
@@ 22,7 22,7 @@ pub struct Cache {
    resolved_names: HashMap<FullName<String>, Res>,
    resolved_names_rev: HashMap<Res, FullName<String>>,
    module_local_resolved_names: HashMap<ModuleId, HashMap<String, Res>>,
    syns: HashMap<SynId, kapo::Type>,
    syns: HashMap<SynId, (Vec<kapo::TVarRef>, kapo::Type)>,
    resolveds: HashMap<DefId, resolve::Def>,
    sigs: HashMap<DefId, Fetch<kapo::Type, Option<kapo::Type>>>,
    checkeds: HashMap<DefId, Fetch<check::Def, ()>>,


@@ 309,17 309,19 @@ impl Cache {
        }
    }

    pub fn fetch_synonymous_type(&mut self, syn_id: SynId) -> Result<&kapo::Type> {
    pub fn fetch_synonymous_type(&mut self, syn_id: SynId) -> Result<(&[kapo::TVarRef], &kapo::Type)> {
        if self.syns.contains_key(&syn_id) {
            Ok(&self.syns[&syn_id])
            let (params, t) = &self.syns[&syn_id];
            Ok((params.as_slice(), t))
        } else {
            let module_id = *self.syn_parent_modules.get(&syn_id).expect("ice: a syn ID should have a parent module");
            let syn_ident = self.get_syn_name(syn_id).expect("ice: a syn ID should have a reverse").last().clone();
            let module = self.fetch_parsed_module(module_id)?;
            let parsed = module.defsyns[&syn_ident].clone();
            let resolved = resolve::resolve_defsyn(self, module_id, &parsed.2)?;
            let resolved = resolve::resolve_defsyn(self, module_id, &parsed.2, &parsed.1)?;
            self.syns.insert(syn_id, resolved);
            Ok(&self.syns[&syn_id])
            let (params, t) = &self.syns[&syn_id];
            Ok((params.as_slice(), t))
        }
    }


M src/check.rs => src/check.rs +14 -8
@@ 105,7 105,7 @@ pub struct Def {
}

#[derive(Debug, Clone)]
enum TVar {
pub enum TVar {
    Univ,  // Bound in a universal quantification
    Exist, // Existential type. As of yet unknown.
    Subst(Type),


@@ 399,7 399,16 @@ impl<'c, 'r> Checker<'c, 'r> {
                    self.check(b, &expected)?;
                }
            }
            _ => {
            Expr::Tuple(ref xs) => match &expected {
                Type::Tuple(ts) if ts.len() == xs.len() => {
                    for (&x, t) in zip(xs, ts) {
                        self.check(x, t)?;
                    }
                }
                Type::Tuple(ts) => return err(ArityMisTuple { loc, expected: ts.len(), found: xs.len() }),
                _ => return err(ExpectedFoundTuple(loc, expected.clone())),
            },
            Expr::Bool(_) | Expr::F64(_) | Expr::Var(_) | Expr::App(..) => {
                let t_expr = self.infer_ignore_annot(eref)?.clone();
                return self.check_subtype(loc, &t_expr, &expected);
            }


@@ 418,10 427,7 @@ impl<'c, 'r> Checker<'c, 'r> {
            Type::ISize if isize::MIN as i128 <= x && x <= isize::MAX as i128 => Ok(()),
            Type::NSize if 0 <= x && x as u128 <= usize::MAX as u128 => Ok(()),
            Type::Int { .. } | Type::ISize | Type::NSize => err(ExpectedFoundOutOfRange(loc, expected.clone())),
            _ => {
                let t_expr = self.infer_int(x);
                err(ExpectedFound(loc, expected.clone(), t_expr.clone()))
            }
            _ => err(ExpectedFoundInteger(loc, expected.clone())),
        }
    }



@@ 586,7 592,7 @@ fn instantiate_univ(t: &Type) -> Cow<Type> {
    }
}

fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
pub fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
    match t {
        Type::Bool | Type::Int { .. } | Type::ISize | Type::NSize | Type::F64 => t.clone(),
        &Type::TVar(tv) => {


@@ 608,7 614,7 @@ fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
    }
}

fn subst_type_once<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
pub fn subst_type_once<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
    match t {
        Type::Bool | Type::Int { .. } | Type::ISize | Type::NSize | Type::F64 => t.clone(),
        &Type::TVar(tv) => {

M src/diag.rs => src/diag.rs +6 -0
@@ 85,8 85,11 @@ pub enum TypeErr {
    ExpectedFoundLogicBinop(Loc, KType),
    ExpectedFoundUnop(Loc, KType),
    ExpectedFoundOutOfRange(Loc, KType),
    ExpectedFoundInteger(Loc, KType),
    ExpectedFoundTuple(Loc, KType),
    ArityMisLambda { loc: Loc, check: usize, lit: usize },
    ArityMisApp { loc: Loc, params: usize, args: usize },
    ArityMisTuple { loc: Loc, expected: usize, found: usize },
    // InferLambda(Loc),
    // InferLambdaParam(Loc),
    AppNonFun(Loc, KType),


@@ 108,8 111,11 @@ impl TypeErr {
            ExpectedFoundLogicBinop(loc, e) => loc.error(cache, format!("Expected {e:?}, found a logic binary operator.")),
            ExpectedFoundUnop(loc, e) => loc.error(cache, format!("Expected {e:?}, found a unary operator.")),
            ExpectedFoundOutOfRange(loc, e) => loc.error(cache, format!("Expected {e:?}, found integer out of range for that type.")),
            ExpectedFoundInteger(loc, e) => loc.error(cache, format!("Expected {e:?}, found integer.")),
            ExpectedFoundTuple(loc, e) => loc.error(cache, format!("Expected {e:?}, found tuple.")),
            ArityMisLambda { loc, check, lit } => loc.error(cache, format!("Arity mismatch in function parameters. Expectectations from the outside are that there be {check} parameters, but the function here instead has {lit}.")),
            ArityMisApp { loc, params, args } => loc.error(cache, format!("Arity mismatch in function application. Function specifies {params} parameters, but {args} arguments were given.")),
            ArityMisTuple { loc, expected, found } => loc.error(cache, format!("Arity mismatch in tuple. Expected tuple of {expected} elements, found tuple of {found} elements.")),
            // InferLambda(loc) => loc.error(cache, "Can't infer type of lambdas without any additional type information (yet). Please surround the lambda in a type annotation to have the type checked instead of inferred."),
            // InferLambdaParam(loc) => loc.error(cache, "Can't infer type of lambda parameter (yet). Maybe surround the lambda in a type annotation to have the type checked instead of inferred."),
            AppNonFun(loc, t) => loc.error(cache, format!("Expected a function to apply, found a {}.", t.describe())),

M src/main.rs => src/main.rs +15 -11
@@ 663,17 663,21 @@ mod test {
        assert_matches!(run_tmp(src), Ok(111))
    }

    // #[test]
    // fn test_type_synonym2() {
    //     let src = "
    //         (defsyn Byte [] N8)
    //         (defsyn Pair [a b] [a b])
    //         (defsyn Be16 [] (Pair Byte Byte))
    //         (def read-n16-be (of (Fun [Be16] N16) (fun [x] (match x [[hi lo] (+ (* (of N16 (cast lo)) 8) (of N16 (cast hi)))]))))
    //         (def main (as Int (read-n16-be [2 16])))
    //     ";
    //     assert_matches!(run_tmp(src), Ok(528))
    // }
    #[test]
    fn test_type_synonym2() {
        let src = "
            (defsyn Byte [] N8)
            (defsyn Pair [a b] [a b])
            (defsyn Be16 [] (Pair Byte Byte))
            (def read-n16-be
              (of (Fun [Be16] N16)
                  (fun [x] (match x
                             [[hi lo] (+ (* (of N16 (cast hi)) 256)
                                         (of N16 (cast lo)))]))))
            (def main (of Int (cast (read-n16-be [2 16]))))
        ";
        assert_matches!(run_tmp(src), Ok(528))
    }

    // #[test]
    // fn test_nominal_type() {

M src/parse.rs => src/parse.rs +2 -2
@@ 110,7 110,7 @@ pub enum Type {
    Fun(Vec<Type>, Box<Type>),
    PolyFun(Vec<IdentSpan>, Vec<Type>, Box<Type>),
    Const(Name),
    App(Box<Type>, Vec<Type>),
    App(Name, Vec<Type>),
    TVar(IdentSpan),
    Hole(Loc),
    Tuple(Vec<Type>),


@@ 311,7 311,7 @@ fn ptyp<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
            };
            Ok((inp, t))
        },
        map(pair(typ, rest(typ)), |(f, xs)| Type::App(Box::new(f), xs)),
        map(pair(name, rest(typ)), |(f, xs)| Type::App(f, xs)),
    )(inp)
}


M src/resolve.rs => src/resolve.rs +51 -20
@@ 32,9 32,14 @@ pub fn resolve_def(cache: &mut Cache, module: ModuleId, def_body: &parse::Expr) 
    Ok(Def { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, root })
}

pub fn resolve_defsyn(cache: &mut Cache, module: ModuleId, syn_typ: &parse::Type) -> Result<Type> {
pub fn resolve_defsyn(
    cache: &mut Cache,
    module: ModuleId,
    syn_typ: &parse::Type,
    params: &[parse::IdentSpan],
) -> Result<(Vec<TVarRef>, Type)> {
    let mut resolver = Resolver::new(cache, module);
    resolver.resolve_type(syn_typ)
    resolver.resolve_parametrized_type(params, |resolver, tvrefs| resolver.resolve_type(syn_typ).map(|t| (tvrefs, t)))
}

struct Resolver<'c> {


@@ 154,28 159,33 @@ impl<'c> Resolver<'c> {
                ps.iter().map(|p| self.resolve_type(p)).collect::<Result<_>>()?,
                Box::new(self.resolve_type(r)?),
            )),
            parse::Type::PolyFun(tvs, ps, r) => {
                let ss: Vec<(Loc, StrRef)> = tvs
                    .iter()
                    .map(|&span| {
                        let s = self.pub_ident(span)?;
                        let sref = self.str_arena.add(&s.s);
                        Ok((s.loc, sref))
                    })
                    .collect::<Result<_>>()?;
                let tvrefs = ss.iter().map(|&(loc, sref)| self.tvars.add(TVar::Univ(loc, sref))).collect::<Vec<_>>();
                let scope = zip(ss.iter().map(|&(_, sref)| self.str_arena[sref].to_string()), tvrefs.clone()).collect();
                self.tvar_scopes.push(scope);
                let ps_res = ps.iter().map(|p| self.resolve_type(p)).collect::<Result<_>>()?;
                let r_res = Box::new(self.resolve_type(r)?);
                self.tvar_scopes.pop();
            parse::Type::PolyFun(tvs, ps, r) => self.resolve_parametrized_type(tvs, |self_, tvrefs| {
                let ps_res = ps.iter().map(|p| self_.resolve_type(p)).collect::<Result<_>>()?;
                let r_res = Box::new(self_.resolve_type(r)?);
                Ok(Type::PolyFun(tvrefs, ps_res, r_res))
            }
            }),
            parse::Type::Const(cname) => match self.resolve_parsed_name(cname)? {
                Res::Syn(id) => self.cache.fetch_synonymous_type(id).cloned(),
                Res::Syn(id) => match self.cache.fetch_synonymous_type(id)? {
                    ([], t) => Ok(t.clone()),
                    (_params, _) => todo!(), // Err
                },
                _ => todo!(),
            },
            parse::Type::App(f, xs) => match self.resolve_parsed_name(f)? {
                Res::Syn(id) => {
                    let xs =
                        xs.iter().map(|x| self.resolve_type(x).map(check::TVar::Subst)).collect::<Result<Vec<_>>>()?;
                    let (params, t) = self.cache.fetch_synonymous_type(id)?;
                    if xs.len() != params.len() {
                        todo!() // return Err
                    }
                    let lookup = |tv: kapo::TVarRef| {
                        params.iter().position(|p| *p == tv).map(|i| &xs[i]).unwrap_or(&check::TVar::Exist)
                    };
                    Ok(check::subst_type_once(&lookup, t))
                }
                _ => todo!(),
            },
            parse::Type::App(..) => todo!(),
            &parse::Type::TVar(span) => {
                let s = self.cache.fetch_source(self.file).map(|src| span.substr_in(src))?;
                self.tvar_scopes


@@ 189,6 199,27 @@ impl<'c> Resolver<'c> {
        }
    }

    fn resolve_parametrized_type<T>(
        &mut self,
        tvars: &[parse::IdentSpan],
        f: impl FnOnce(&mut Self, Vec<TVarRef>) -> Result<T>,
    ) -> Result<T> {
        let ss: Vec<(Loc, StrRef)> = tvars
            .iter()
            .map(|&span| {
                let s = self.pub_ident(span)?;
                let sref = self.str_arena.add(&s.s);
                Ok((s.loc, sref))
            })
            .collect::<Result<_>>()?;
        let tvrefs = ss.iter().map(|&(loc, sref)| self.tvars.add(TVar::Univ(loc, sref))).collect::<Vec<_>>();
        let scope = zip(ss.iter().map(|&(_, sref)| self.str_arena[sref].to_string()), tvrefs.clone()).collect();
        self.tvar_scopes.push(scope);
        let result = f(self, tvrefs)?;
        self.tvar_scopes.pop();
        Ok(result)
    }

    fn resolve_pat(&mut self, pat: &parse::Pat) -> Result<(HashMap<String, Res>, PatRef)> {
        let mut vars = HashMap::new();
        let pref = self.resolve_pat_mut(&mut vars, pat)?;