M src/abase.rs => src/abase.rs +24 -24
@@ 563,17 563,17 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
}
fn size_of(&mut self, t: &kapo::Type) -> Operand {
- use kapo::Type::*;
+ use kapo::Type;
match t {
- Bool => Const::NSize(1).into(),
- &Int { width, .. } => Const::NSize(width as usize / 8).into(),
- ISize | NSize => Const::SizeofPtr.into(),
- F64 => Const::NSize(8).into(),
+ Type::Bool => Const::NSize(1).into(),
+ &Type::Int { width, .. } => Const::NSize(width as usize / 8).into(),
+ Type::ISize | Type::NSize => Const::SizeofPtr.into(),
+ Type::F64 => Const::NSize(8).into(),
// TODO: Change this once `Fun` represents a closure
- Fun(..) => Const::SizeofPtr.into(),
- PolyFun(..) => todo!(),
- Tuple(us) => self.size_of_tuple(us.iter()).0,
- TVar(tv) => {
+ Type::Fun(..) => Const::SizeofPtr.into(),
+ Type::PolyFun(..) => todo!(),
+ Type::Tuple(us) => self.size_of_tuple(us.iter()).0,
+ Type::TVar(tv) => {
let i = self.tvars.binary_search(tv).unwrap();
self.abase_vwt_field_size(self.vwt_param(i))
}
@@ 581,14 581,14 @@ impl<'c, 'k> AbaseDef<'c, 'k> {
}
fn align_of(&mut self, t: &kapo::Type) -> Operand {
- use kapo::Type::*;
+ use kapo::Type;
match t {
- Tuple(_us) => todo!(), // us.iter().map(|u| self.align_of(u)).fold(Align::byte(), |acc, a| acc.max(&a)),
- TVar(tv) => {
+ Type::Tuple(_us) => todo!(), // us.iter().map(|u| self.align_of(u)).fold(Align::byte(), |acc, a| acc.max(&a)),
+ Type::TVar(tv) => {
let i = self.tvars.binary_search(tv).unwrap();
self.abase_vwt_field_align(self.vwt_param(i))
}
- Fun(..) | PolyFun(..) => Const::SizeofPtr.into(),
+ Type::Fun(..) | Type::PolyFun(..) => Const::SizeofPtr.into(),
_ => self.size_of(t),
}
}
@@ 678,12 678,12 @@ fn abase_type(t: &kapo::Type) -> Pass {
}
fn abase_int(x: i128, typ: &kapo::Type) -> Const {
- use kapo::Type::*;
+ use kapo::Type;
match *typ {
- Int { width, signed: true } => Const::Int { width, val: x as i64 },
- Int { width, signed: false } => Const::Nat { width, val: x as u64 },
- ISize => Const::ISize(x as isize),
- NSize => Const::NSize(x as usize),
+ Type::Int { width, signed: true } => Const::Int { width, val: x as i64 },
+ Type::Int { width, signed: false } => Const::Nat { width, val: x as u64 },
+ Type::ISize => Const::ISize(x as isize),
+ Type::NSize => Const::NSize(x as usize),
ref t => panic!("ice: integer expr doesn't have signed integer type. found {}", t.describe()),
}
}
@@ 996,13 996,13 @@ impl Converge for ConvVarMerge {
}
fn abase_tag(tag: u64, typ: &kapo::Type) -> Option<Const> {
- use kapo::Type::*;
+ use kapo::Type;
match typ {
- Bool => Some(Const::Bool(tag != 0)),
- Int { signed: true, .. } | ISize => Some(abase_int(tag as i64 as i128, typ)),
- Int { .. } | NSize => Some(abase_int(tag as i128, typ)),
- Tuple(_) => None,
- F64 | TVar(_) | Fun(_, _) | PolyFun(..) => todo!(),
+ Type::Bool => Some(Const::Bool(tag != 0)),
+ Type::Int { signed: true, .. } | Type::ISize => Some(abase_int(tag as i64 as i128, typ)),
+ Type::Int { .. } | Type::NSize => Some(abase_int(tag as i128, typ)),
+ Type::Tuple(_) => None,
+ Type::F64 | Type::TVar(_) | Type::Fun(_, _) | Type::PolyFun(..) => todo!(),
}
}
M src/cache.rs => src/cache.rs +7 -5
@@ 22,7 22,7 @@ pub struct Cache {
resolved_names: HashMap<FullName<String>, Res>,
resolved_names_rev: HashMap<Res, FullName<String>>,
module_local_resolved_names: HashMap<ModuleId, HashMap<String, Res>>,
- syns: HashMap<SynId, kapo::Type>,
+ syns: HashMap<SynId, (Vec<kapo::TVarRef>, kapo::Type)>,
resolveds: HashMap<DefId, resolve::Def>,
sigs: HashMap<DefId, Fetch<kapo::Type, Option<kapo::Type>>>,
checkeds: HashMap<DefId, Fetch<check::Def, ()>>,
@@ 309,17 309,19 @@ impl Cache {
}
}
- pub fn fetch_synonymous_type(&mut self, syn_id: SynId) -> Result<&kapo::Type> {
+ pub fn fetch_synonymous_type(&mut self, syn_id: SynId) -> Result<(&[kapo::TVarRef], &kapo::Type)> {
if self.syns.contains_key(&syn_id) {
- Ok(&self.syns[&syn_id])
+ let (params, t) = &self.syns[&syn_id];
+ Ok((params.as_slice(), t))
} else {
let module_id = *self.syn_parent_modules.get(&syn_id).expect("ice: a syn ID should have a parent module");
let syn_ident = self.get_syn_name(syn_id).expect("ice: a syn ID should have a reverse").last().clone();
let module = self.fetch_parsed_module(module_id)?;
let parsed = module.defsyns[&syn_ident].clone();
- let resolved = resolve::resolve_defsyn(self, module_id, &parsed.2)?;
+ let resolved = resolve::resolve_defsyn(self, module_id, &parsed.2, &parsed.1)?;
self.syns.insert(syn_id, resolved);
- Ok(&self.syns[&syn_id])
+ let (params, t) = &self.syns[&syn_id];
+ Ok((params.as_slice(), t))
}
}
M src/check.rs => src/check.rs +14 -8
@@ 105,7 105,7 @@ pub struct Def {
}
#[derive(Debug, Clone)]
-enum TVar {
+pub enum TVar {
Univ, // Bound in a universal quantification
Exist, // Existential type. As of yet unknown.
Subst(Type),
@@ 399,7 399,16 @@ impl<'c, 'r> Checker<'c, 'r> {
self.check(b, &expected)?;
}
}
- _ => {
+ Expr::Tuple(ref xs) => match &expected {
+ Type::Tuple(ts) if ts.len() == xs.len() => {
+ for (&x, t) in zip(xs, ts) {
+ self.check(x, t)?;
+ }
+ }
+ Type::Tuple(ts) => return err(ArityMisTuple { loc, expected: ts.len(), found: xs.len() }),
+ _ => return err(ExpectedFoundTuple(loc, expected.clone())),
+ },
+ Expr::Bool(_) | Expr::F64(_) | Expr::Var(_) | Expr::App(..) => {
let t_expr = self.infer_ignore_annot(eref)?.clone();
return self.check_subtype(loc, &t_expr, &expected);
}
@@ 418,10 427,7 @@ impl<'c, 'r> Checker<'c, 'r> {
Type::ISize if isize::MIN as i128 <= x && x <= isize::MAX as i128 => Ok(()),
Type::NSize if 0 <= x && x as u128 <= usize::MAX as u128 => Ok(()),
Type::Int { .. } | Type::ISize | Type::NSize => err(ExpectedFoundOutOfRange(loc, expected.clone())),
- _ => {
- let t_expr = self.infer_int(x);
- err(ExpectedFound(loc, expected.clone(), t_expr.clone()))
- }
+ _ => err(ExpectedFoundInteger(loc, expected.clone())),
}
}
@@ 586,7 592,7 @@ fn instantiate_univ(t: &Type) -> Cow<Type> {
}
}
-fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
+pub fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
match t {
Type::Bool | Type::Int { .. } | Type::ISize | Type::NSize | Type::F64 => t.clone(),
&Type::TVar(tv) => {
@@ 608,7 614,7 @@ fn subst_type_rec<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
}
}
-fn subst_type_once<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
+pub fn subst_type_once<'a>(tvars: &impl Fn(TVarRef) -> &'a TVar, t: &Type) -> Type {
match t {
Type::Bool | Type::Int { .. } | Type::ISize | Type::NSize | Type::F64 => t.clone(),
&Type::TVar(tv) => {
M src/diag.rs => src/diag.rs +6 -0
@@ 85,8 85,11 @@ pub enum TypeErr {
ExpectedFoundLogicBinop(Loc, KType),
ExpectedFoundUnop(Loc, KType),
ExpectedFoundOutOfRange(Loc, KType),
+ ExpectedFoundInteger(Loc, KType),
+ ExpectedFoundTuple(Loc, KType),
ArityMisLambda { loc: Loc, check: usize, lit: usize },
ArityMisApp { loc: Loc, params: usize, args: usize },
+ ArityMisTuple { loc: Loc, expected: usize, found: usize },
// InferLambda(Loc),
// InferLambdaParam(Loc),
AppNonFun(Loc, KType),
@@ 108,8 111,11 @@ impl TypeErr {
ExpectedFoundLogicBinop(loc, e) => loc.error(cache, format!("Expected {e:?}, found a logic binary operator.")),
ExpectedFoundUnop(loc, e) => loc.error(cache, format!("Expected {e:?}, found a unary operator.")),
ExpectedFoundOutOfRange(loc, e) => loc.error(cache, format!("Expected {e:?}, found integer out of range for that type.")),
+ ExpectedFoundInteger(loc, e) => loc.error(cache, format!("Expected {e:?}, found integer.")),
+ ExpectedFoundTuple(loc, e) => loc.error(cache, format!("Expected {e:?}, found tuple.")),
ArityMisLambda { loc, check, lit } => loc.error(cache, format!("Arity mismatch in function parameters. Expectectations from the outside are that there be {check} parameters, but the function here instead has {lit}.")),
ArityMisApp { loc, params, args } => loc.error(cache, format!("Arity mismatch in function application. Function specifies {params} parameters, but {args} arguments were given.")),
+ ArityMisTuple { loc, expected, found } => loc.error(cache, format!("Arity mismatch in tuple. Expected tuple of {expected} elements, found tuple of {found} elements.")),
// InferLambda(loc) => loc.error(cache, "Can't infer type of lambdas without any additional type information (yet). Please surround the lambda in a type annotation to have the type checked instead of inferred."),
// InferLambdaParam(loc) => loc.error(cache, "Can't infer type of lambda parameter (yet). Maybe surround the lambda in a type annotation to have the type checked instead of inferred."),
AppNonFun(loc, t) => loc.error(cache, format!("Expected a function to apply, found a {}.", t.describe())),
M src/main.rs => src/main.rs +15 -11
@@ 663,17 663,21 @@ mod test {
assert_matches!(run_tmp(src), Ok(111))
}
- // #[test]
- // fn test_type_synonym2() {
- // let src = "
- // (defsyn Byte [] N8)
- // (defsyn Pair [a b] [a b])
- // (defsyn Be16 [] (Pair Byte Byte))
- // (def read-n16-be (of (Fun [Be16] N16) (fun [x] (match x [[hi lo] (+ (* (of N16 (cast lo)) 8) (of N16 (cast hi)))]))))
- // (def main (as Int (read-n16-be [2 16])))
- // ";
- // assert_matches!(run_tmp(src), Ok(528))
- // }
+ #[test]
+ fn test_type_synonym2() {
+ let src = "
+ (defsyn Byte [] N8)
+ (defsyn Pair [a b] [a b])
+ (defsyn Be16 [] (Pair Byte Byte))
+ (def read-n16-be
+ (of (Fun [Be16] N16)
+ (fun [x] (match x
+ [[hi lo] (+ (* (of N16 (cast hi)) 256)
+ (of N16 (cast lo)))]))))
+ (def main (of Int (cast (read-n16-be [2 16]))))
+ ";
+ assert_matches!(run_tmp(src), Ok(528))
+ }
// #[test]
// fn test_nominal_type() {
M src/parse.rs => src/parse.rs +2 -2
@@ 110,7 110,7 @@ pub enum Type {
Fun(Vec<Type>, Box<Type>),
PolyFun(Vec<IdentSpan>, Vec<Type>, Box<Type>),
Const(Name),
- App(Box<Type>, Vec<Type>),
+ App(Name, Vec<Type>),
TVar(IdentSpan),
Hole(Loc),
Tuple(Vec<Type>),
@@ 311,7 311,7 @@ fn ptyp<'s: 't, 't>(inp: Inp<'s, 't>) -> Res<'s, 't, Type> {
};
Ok((inp, t))
},
- map(pair(typ, rest(typ)), |(f, xs)| Type::App(Box::new(f), xs)),
+ map(pair(name, rest(typ)), |(f, xs)| Type::App(f, xs)),
)(inp)
}
M src/resolve.rs => src/resolve.rs +51 -20
@@ 32,9 32,14 @@ pub fn resolve_def(cache: &mut Cache, module: ModuleId, def_body: &parse::Expr)
Ok(Def { exprs, expr_locs, expr_annots, var_locs, str_arena, tvars, pats, pat_locs, root })
}
-pub fn resolve_defsyn(cache: &mut Cache, module: ModuleId, syn_typ: &parse::Type) -> Result<Type> {
+pub fn resolve_defsyn(
+ cache: &mut Cache,
+ module: ModuleId,
+ syn_typ: &parse::Type,
+ params: &[parse::IdentSpan],
+) -> Result<(Vec<TVarRef>, Type)> {
let mut resolver = Resolver::new(cache, module);
- resolver.resolve_type(syn_typ)
+ resolver.resolve_parametrized_type(params, |resolver, tvrefs| resolver.resolve_type(syn_typ).map(|t| (tvrefs, t)))
}
struct Resolver<'c> {
@@ 154,28 159,33 @@ impl<'c> Resolver<'c> {
ps.iter().map(|p| self.resolve_type(p)).collect::<Result<_>>()?,
Box::new(self.resolve_type(r)?),
)),
- parse::Type::PolyFun(tvs, ps, r) => {
- let ss: Vec<(Loc, StrRef)> = tvs
- .iter()
- .map(|&span| {
- let s = self.pub_ident(span)?;
- let sref = self.str_arena.add(&s.s);
- Ok((s.loc, sref))
- })
- .collect::<Result<_>>()?;
- let tvrefs = ss.iter().map(|&(loc, sref)| self.tvars.add(TVar::Univ(loc, sref))).collect::<Vec<_>>();
- let scope = zip(ss.iter().map(|&(_, sref)| self.str_arena[sref].to_string()), tvrefs.clone()).collect();
- self.tvar_scopes.push(scope);
- let ps_res = ps.iter().map(|p| self.resolve_type(p)).collect::<Result<_>>()?;
- let r_res = Box::new(self.resolve_type(r)?);
- self.tvar_scopes.pop();
+ parse::Type::PolyFun(tvs, ps, r) => self.resolve_parametrized_type(tvs, |self_, tvrefs| {
+ let ps_res = ps.iter().map(|p| self_.resolve_type(p)).collect::<Result<_>>()?;
+ let r_res = Box::new(self_.resolve_type(r)?);
Ok(Type::PolyFun(tvrefs, ps_res, r_res))
- }
+ }),
parse::Type::Const(cname) => match self.resolve_parsed_name(cname)? {
- Res::Syn(id) => self.cache.fetch_synonymous_type(id).cloned(),
+ Res::Syn(id) => match self.cache.fetch_synonymous_type(id)? {
+ ([], t) => Ok(t.clone()),
+ (_params, _) => todo!(), // Err
+ },
+ _ => todo!(),
+ },
+ parse::Type::App(f, xs) => match self.resolve_parsed_name(f)? {
+ Res::Syn(id) => {
+ let xs =
+ xs.iter().map(|x| self.resolve_type(x).map(check::TVar::Subst)).collect::<Result<Vec<_>>>()?;
+ let (params, t) = self.cache.fetch_synonymous_type(id)?;
+ if xs.len() != params.len() {
+ todo!() // return Err
+ }
+ let lookup = |tv: kapo::TVarRef| {
+ params.iter().position(|p| *p == tv).map(|i| &xs[i]).unwrap_or(&check::TVar::Exist)
+ };
+ Ok(check::subst_type_once(&lookup, t))
+ }
_ => todo!(),
},
- parse::Type::App(..) => todo!(),
&parse::Type::TVar(span) => {
let s = self.cache.fetch_source(self.file).map(|src| span.substr_in(src))?;
self.tvar_scopes
@@ 189,6 199,27 @@ impl<'c> Resolver<'c> {
}
}
+ fn resolve_parametrized_type<T>(
+ &mut self,
+ tvars: &[parse::IdentSpan],
+ f: impl FnOnce(&mut Self, Vec<TVarRef>) -> Result<T>,
+ ) -> Result<T> {
+ let ss: Vec<(Loc, StrRef)> = tvars
+ .iter()
+ .map(|&span| {
+ let s = self.pub_ident(span)?;
+ let sref = self.str_arena.add(&s.s);
+ Ok((s.loc, sref))
+ })
+ .collect::<Result<_>>()?;
+ let tvrefs = ss.iter().map(|&(loc, sref)| self.tvars.add(TVar::Univ(loc, sref))).collect::<Vec<_>>();
+ let scope = zip(ss.iter().map(|&(_, sref)| self.str_arena[sref].to_string()), tvrefs.clone()).collect();
+ self.tvar_scopes.push(scope);
+ let result = f(self, tvrefs)?;
+ self.tvar_scopes.pop();
+ Ok(result)
+ }
+
fn resolve_pat(&mut self, pat: &parse::Pat) -> Result<(HashMap<String, Res>, PatRef)> {
let mut vars = HashMap::new();
let pref = self.resolve_pat_mut(&mut vars, pat)?;