package semantic
import (
"fmt"
"path/filepath"
"reflect"
"sort"
"strings"
"git.sr.ht/~mna/snow/pkg/token"
)
// typecheck runs the type-checking pass over unit, reporting every problem it
// finds through errh. The pass validates that all types in the unit are valid
// in their context; if it succeeds, all statements use valid types. Where
// type unification required a conversion that was made implicitly, the
// converted expression is wrapped in an ImplicitConv node.
//
// Note that no type is modified by this pass: all types are as they were
// assigned in the typeassign pass. This is expected, since the same tree is
// walked with no additional information, so there is no reason to end up
// with different types.
//
// After the walk, the unit must have recorded a main function (unit.Main is
// set by typecheckMainFn when the top-level main function is checked).
func typecheck(unit *Unit, errh func(token.Pos, string)) {
	checker := &typecheckVisitor{errh: errh, unit: unit}
	Walk(checker, unit)

	// TODO: eventually make this check only if building an executable
	if unit.Main != nil {
		return
	}
	errh(token.NoPos, "main function missing")
}
type typecheckVisitor struct {
errh func(token.Pos, string)
unit *Unit
// curFn is the current function we're in
curFn *Fn
// curStr is the current struct we're in
curStr *Struct
// curProp is the current property declaration we're processing - this isn't
// set in a clone call, only manually before each Var property declaration, and
// unset after all properties are done (because we can't move from a prop to
// inside a function, so no need to clone for this).
curProp *Var
// refMethod indicates if we're in a ref method context (i.e. curStr is
// set and is the struct that can be modified, and curFn is a ref method
// or a plain method nested in a ref method).
refMethod bool
}
// cloneInFn returns the visitor used for the body of fn. It keeps the current
// struct context and sets curFn to fn. The ref-method flag is inherited from
// t and also becomes true when fn itself is a ref method. curProp is not
// carried over.
func (t *typecheckVisitor) cloneInFn(fn *Fn) *typecheckVisitor {
	return &typecheckVisitor{
		unit:      t.unit,
		errh:      t.errh,
		curStr:    t.curStr,
		curFn:     fn,
		refMethod: fn.IsRef || t.refMethod,
	}
}
// cloneInStruct returns the visitor used for the body of str. The current
// function, property and ref-method contexts are left at their zero values:
// they do not relate to this struct in any way, and the new visitor is only
// used to visit the body of that struct.
func (t *typecheckVisitor) cloneInStruct(str *Struct) *typecheckVisitor {
	sv := &typecheckVisitor{unit: t.unit, errh: t.errh}
	sv.curStr = str
	return sv
}
// expectTypeCtx checks typed against three requirements, in order:
//  1. its type can be stored in target via AsType (target is a pointer to the
//     variable that receives the type, e.g. *Type or **StructType);
//  2. that type is Valid;
//  3. its type context is one of ctxs, or anything but Invalid when ctxs is
//     empty.
//
// It reports only the first failed requirement through t.errh and returns
// false. Otherwise it returns true.
func (t *typecheckVisitor) expectTypeCtx(typed Typed, target interface{}, ctxs ...TypeContext) bool {
	typ := typed.Type()
	ctx := typed.TypeContext()

	if !AsType(typ, target) {
		// name the expected type after the pointed-to element of target
		want := reflect.TypeOf(target)
		for want.Kind() == reflect.Ptr {
			want = want.Elem()
		}
		t.errh(typed.Pos(), fmt.Sprintf("expected type %s; got %s", want.Name(), typ))
		return false
	}

	if !typ.Valid() {
		var msg string
		switch v := typed.(type) {
		case *Ident:
			msg = fmt.Sprintf("invalid type for identifier %s: %s", v.Name, typ)
		case Decl:
			msg = fmt.Sprintf("invalid type for declaration %s: %s", v.Ident(), typ)
		default:
			msg = fmt.Sprintf("invalid type: %s", typ)
		}
		t.errh(typed.Pos(), msg)
		return false
	}

	if (len(ctxs) == 0 && ctx != Invalid) || ctx.isAnyOf(ctxs...) {
		return true
	}

	// describe the accepted contexts for the error message
	names := make([]string, len(ctxs))
	for i, c := range ctxs {
		names[i] = c.String()
	}
	label := strings.Join(names, ", ")
	if label == "" {
		label = "valid"
	}
	expectation := "one of " + label
	if len(ctxs) == 1 {
		expectation = label
	}
	t.errh(typed.Pos(), fmt.Sprintf("expected type context to be %s; got %s", expectation, ctx))
	return false
}
func (t *typecheckVisitor) Visit(n Node) Visitor {
switch n := n.(type) {
case *Unit, *File, *Block:
return t
// ************** DECLARATIONS *****************
case *Fn:
for _, attr := range n.Attrs {
Walk(t, attr)
}
// if the function is generic, maintain a set of generic placeholder names that must be used
// in the signature.
var set map[string]bool
if n.GenericParams != nil {
set = make(map[string]bool, len(n.GenericParams.Elems))
for _, ge := range n.GenericParams.Elems {
Walk(t, ge)
gt := AsGenericType(ge.Type())
set[gt.Name] = true
}
}
for _, p := range n.Params {
// params are Var and are type-verified there (e.g. for Typ context)
Walk(t, p)
}
if n.ReturnExpr != nil {
Walk(t, n.ReturnExpr)
var T Type
t.expectTypeCtx(n.ReturnExpr, &T, Typ)
}
var sigt *SignatureType
if !t.expectTypeCtx(n, &sigt, Immutable) {
break
}
t.typecheckFnDecl(n, sigt, set)
if n.Body != nil {
tt := t.cloneInFn(n)
Walk(tt, n.Body)
}
case *Var:
if n.TypeExpr != nil {
Walk(t, n.TypeExpr)
var T Type
if !t.expectTypeCtx(n.TypeExpr, &T, Typ) {
return nil
}
}
if n.Value != nil {
Walk(t, n.Value)
var T Type
if !t.expectTypeCtx(n.Value, &T, TypeContextValues...) {
return nil
}
if !T.AssignableTo(n.Type()) {
t.errh(n.Pos(), fmt.Sprintf("cannot assign type %s to variable of type %s", T, n.Type()))
return nil
}
if !T.IdenticalTo(n.Type()) {
n.Value = createImplicitConv(n.Value, n.Type())
}
}
var T Type
t.expectTypeCtx(n, &T, Immutable, Mutable)
case *Struct:
// All expressions within the struct body must not use any symbol in outer
// scopes except for Universe and TopLevel ones.
tt := t.cloneInStruct(n)
// TODO: validate that all generic params are used in the struct, and only those (well, others
// would not type-check as they would be undefined symbols)
if n.GenericParams != nil {
for _, ge := range n.GenericParams.Elems {
Walk(tt, ge)
}
}
for _, v := range n.Vars {
tt.curProp = v
Walk(tt, v)
}
tt.curProp = nil
for _, fn := range n.Fns {
Walk(tt, fn)
}
for _, str := range n.Structs {
Walk(tt, str)
}
var st *StructType
t.expectTypeCtx(n, &st, Typ)
case *Interface:
if n.GenericParams != nil {
for _, ge := range n.GenericParams.Elems {
Walk(t, ge)
}
}
for _, fn := range n.Methods {
Walk(t, fn)
}
var it *InterfaceType
t.expectTypeCtx(n, &it, Typ)
case *GenericElem:
var gt *GenericType
t.expectTypeCtx(n, >, Typ)
// ************** STATEMENTS *****************
case *Return:
// at this point, if present, returnType is guaranteed to be Valid().
if t.curFn == nil {
panic("no return type expected, but return statement encountered")
}
retType := AsSignatureType(t.curFn.Type()).Return
// expression type must type-check for the expected function's
// return type.
if n.Value != nil {
Walk(t, n.Value)
var valt Type
if !t.expectTypeCtx(n.Value, &valt, TypeContextValues...) {
return nil
}
if !valt.AssignableTo(retType) {
t.errh(n.Value.Pos(), fmt.Sprintf("invalid type for return value: expected %s, got %s", retType, valt))
return nil
}
// insert implicit conversion if required
if !valt.IdenticalTo(retType) {
n.Value = createImplicitConv(n.Value, retType)
}
return nil
}
// otherwise there is no return value, so the function's return type must be void
if !IsBasicOfKind(retType, Void) {
t.errh(n.Pos(), fmt.Sprintf("missing return value, expected a value of type %s", retType))
return nil
}
case *Assign:
Walk(t, n.Left)
Walk(t, n.Right)
var lt, rt Type
// left must be a mutable variable
if !t.expectTypeCtx(n.Left, <, Mutable) {
return nil
}
// if left is a struct property, assignable only in a "ref" context
if vv := asVarDeclRef(n.Left); vv != nil && vv.PropOf != nil && vv.PropOf == t.curStr && !t.refMethod {
t.errh(n.Left.Pos(), fmt.Sprintf("cannot assign to property %s; current method is not marked as ref", vv.Ident()))
}
// right must be any kind of value
if !t.expectTypeCtx(n.Right, &rt, TypeContextValues...) {
return nil
}
// type of left and right must be compatible
if !rt.AssignableTo(lt) {
t.errh(n.Right.Pos(), fmt.Sprintf("cannot assign type %s to variable of type %s", rt, lt))
return nil
}
// insert implicit conversion if required
if !rt.IdenticalTo(lt) {
n.Right = createImplicitConv(n.Right, lt)
}
case *ExprStmt:
Walk(t, n.Value)
var T Type
t.expectTypeCtx(n.Value, &T, TypeContextValues...)
case *If:
t.typecheckConds(n.Conds)
if n.Body != nil {
Walk(t, n.Body)
}
if n.Else != nil {
Walk(t, n.Else)
}
case *Guard:
t.typecheckConds(n.Conds)
if n.Else != nil {
Walk(t, n.Else)
}
// ************** EXPRESSIONS *****************
case *FnTypeExpr:
var T Type
for _, p := range n.Params {
Walk(t, p)
t.expectTypeCtx(p, &T, Typ)
}
if n.Return != nil {
Walk(t, n.Return)
t.expectTypeCtx(n.Return, &T, Typ)
}
t.expectTypeCtx(n, &T, Typ)
case *TupleTypeExpr:
var T Type
for _, f := range n.Fields {
Walk(t, f)
t.expectTypeCtx(f, &T, Typ)
}
t.expectTypeCtx(n, &T, Typ)
case *TupleVal:
var T Type
for _, v := range n.Values {
Walk(t, v)
t.expectTypeCtx(n, &T, TypeContextValues...)
}
t.expectTypeCtx(n, &T, Value)
case *Binary:
// operands must be valid for the operator, and type of the binary
// expression must unify the types of the operands.
Walk(t, n.Left)
Walk(t, n.Right)
var lt, rt *BasicType
if !t.expectTypeCtx(n.Left, <, TypeContextValues...) {
return nil
}
if !t.expectTypeCtx(n.Right, &rt, TypeContextValues...) {
return nil
}
// type of binary expression is invalid if operands cannot be unified for this operator
if !n.Type().Valid() {
t.errh(n.Pos(), fmt.Sprintf("incompatible operand types: %s %s %s", lt, n.Op, rt))
return nil
}
// if types are not identical, insert implicit conversion nodes the
// converted operand is always the smaller one
lsz, rsz := basicKindSizes[lt.Kind], basicKindSizes[rt.Kind]
if lsz < rsz {
n.Left = createImplicitConv(n.Left, rt)
} else if rsz < lsz {
n.Right = createImplicitConv(n.Right, lt)
}
case *Unary:
Walk(t, n.Right)
// operators only valid on basic kinds
var rt *BasicType
if !t.expectTypeCtx(n.Right, &rt, TypeContextValues...) {
return nil
}
if !IsBasicOfKind(rt, unaryOpsTable[n.Op]...) {
t.errh(n.Pos(), fmt.Sprintf("invalid operation: %s %s", n.Op, rt))
return nil
}
case *Paren:
Walk(t, n.Value)
T := n.Value.Type()
t.expectTypeCtx(n, &T, n.Value.TypeContext())
case *Call:
Walk(t, n.Fun)
for _, arg := range n.Args {
Walk(t, arg)
}
if n.InitOf != nil {
t.typecheckStructInit(n)
} else {
t.typecheckFnCall(n)
}
case *Selector:
// the type context is taken care of in the type assignment pass, and may be invalid
// this is validated in the larger expression or statement, depending on context
// (e.g. if a type is expected or a mutable value, etc.).
Walk(t, n.Left)
Walk(t, n.Sel)
var T Type
if !t.expectTypeCtx(n, &T) {
return nil
}
// if the selector is a ref method, only valid if the left side is a mutable context
// (i.e. must be a var struct to get a method that can mutate the struct).
var ref Decl
switch sel := n.Sel.(type) {
case *Ident:
ref = sel.Ref
case *GenericInst:
ref = sel.GenericDecl.Ref
}
if fn := AsFnDecl(ref); fn != nil && fn.IsRef {
if lctx := n.Left.TypeContext(); lctx != Mutable {
t.errh(n.Sel.Pos(), fmt.Sprintf("cannot access ref fn %s; left-hand side must be %s, is %s", ref.Ident(), Mutable, lctx))
}
}
case *GenericInst:
var T Type
t.expectTypeCtx(n, &T, Typ, Value)
// TODO: also, should the number of types be validated here instead? Conceptually would make more sense,
// but currently done in type-assign and it works well.
case *Ident:
// check that it has a valid type and context
var idt Type
ctxs := []TypeContext{Typ, Mutable, Immutable}
// tuple field selector is an identifier but it can be in a value context, and not a Typ one
if n.Index >= 0 {
// replace Typ with Value
ctxs[0] = Value
}
if !t.expectTypeCtx(n, &idt, ctxs...) {
return nil
}
if t.curStr != nil {
minScopeID := t.curStr.BodyScope.ID
// in a struct scope, cannot access outer expressions
if ref := n.Ref; ref != nil && ref.TypeContext().isAnyOf(Mutable, Immutable) {
if scope := ref.Scope(); scope.ID < minScopeID && !scope.IsTopLevel() && !scope.IsUniverse() {
t.errh(n.Pos(), fmt.Sprintf("%s is not a field on %s nor a top-level symbol", n.Name, t.curStr.Type()))
}
}
}
if t.curProp != nil {
// a struct property cannot access a struct method or another struct property during initialization
if fn := AsFnDecl(n.Ref); fn != nil && fn.MethodOf != nil && fn.MethodOf == t.curStr {
t.errh(n.Pos(), fmt.Sprintf("cannot access method %s in property initializer", fn.Ident()))
} else if prop := AsVarDecl(n.Ref); prop != nil && prop.PropOf != nil && prop.PropOf == t.curStr {
t.errh(n.Pos(), fmt.Sprintf("cannot access property %s in property initializer", prop.Ident()))
}
}
case *LitString:
var T Type
t.expectTypeCtx(n, &T, Constant)
case *LitInt:
var T Type
t.expectTypeCtx(n, &T, Constant)
default:
if n != nil {
panic(fmt.Sprintf("invalid node type: %T", n))
}
}
return nil
}
// typecheckFnDecl runs the declaration-level checks on function fn, whose
// signature type is st. gens holds the generic placeholder names declared by
// fn (nil or empty for a non-generic function). It is passed on to, and
// modified by, typecheckGenFnSig.
func (t *typecheckVisitor) typecheckFnDecl(fn *Fn, st *SignatureType, gens map[string]bool) {
	// only a top-level function named main gets the entry-point checks
	isEntryPoint := fn.Ident() == MainFnName && fn.Scope().IsTopLevel()
	if isEntryPoint {
		t.typecheckMainFn(fn, st)
	}

	// the ref modifier only applies to struct methods
	if fn.MethodOf == nil && fn.IsRef {
		t.errh(fn.Pos(), fmt.Sprintf("function %s cannot have a ref modifier; only valid for struct functions (aka methods)", fn.Ident()))
	}

	t.typecheckFnAttrs(fn)

	if len(gens) != 0 {
		t.typecheckGenFnSig(fn, st, gens)
	}
}
// typecheckGenFnSig reports the generic placeholders of gens that are not used
// in fn's signature st. A placeholder counts as used when AsGenericType
// recognizes it as a parameter type or as the return type. Names found are
// deleted from gens, so the caller's map is modified in place.
//
// NOTE(review): only the parameter/return types themselves are inspected, so a
// placeholder nested inside a composite type appears not to count as used —
// confirm this is intended.
func (t *typecheckVisitor) typecheckGenFnSig(fn *Fn, st *SignatureType, gens map[string]bool) {
	markUsed := func(typ Type) {
		if gt := AsGenericType(typ); gt != nil {
			delete(gens, gt.Name)
		}
	}
	for _, pt := range st.Params {
		markUsed(pt)
	}
	markUsed(st.Return)

	if len(gens) == 0 {
		return
	}
	unused := make([]string, 0, len(gens))
	for name := range gens {
		unused = append(unused, name)
	}
	// sorted for a deterministic error message
	sort.Strings(unused)
	t.errh(fn.Pos(), fmt.Sprintf("unused generic type(s) in function signature: %v", unused))
}
// typecheckFnAttrs validates the attributes applied to fn. Each attribute may
// appear at most once. A function without the @extern attribute
// (ExternAttrName) must have a body unless fn.AbstractMethodOf is set. An
// @extern function must not have a body or, as a method, a ref modifier. Its
// package identifier (the "pkg" value, else the last element of the "import"
// value) must not already be declared in fn's scope chain.
func (t *typecheckVisitor) typecheckFnAttrs(fn *Fn) {
	var extern *Call
	seen := make(map[string]bool, len(fn.Attrs))
	for _, attr := range fn.Attrs {
		// an unresolved attribute (symbol not found) has already raised an
		// error; just skip it
		if attr.InitOf == nil {
			continue
		}
		name := attr.InitOf.Ident()
		if seen[name] {
			t.errh(attr.Pos(), fmt.Sprintf("duplicate attribute @%s applied to function %s", name, fn.Ident()))
			continue
		}
		seen[name] = true
		if name == ExternAttrName {
			extern = attr
		}
	}

	if extern == nil {
		if fn.Body == nil && fn.AbstractMethodOf == nil {
			t.errh(fn.Pos(), fmt.Sprintf("function %s must have a body", fn.Ident()))
		}
		return
	}

	// @extern-specific validations
	if fn.Body != nil {
		t.errh(fn.Pos(), fmt.Sprintf("@%s function %s cannot have a body", ExternAttrName, fn.Ident()))
	}
	if fn.IsRef && fn.MethodOf != nil {
		t.errh(fn.Pos(), fmt.Sprintf("@%s function %s cannot have a ref modifier", ExternAttrName, fn.Ident()))
	}

	vals := extern.Values()
	if len(vals) == 0 {
		return
	}
	// missing or non-string values are deliberately treated as empty strings
	pkg, _ := vals["pkg"].(string)
	if pkg == "" {
		imp, _ := vals["import"].(string)
		pkg = filepath.Base(imp)
	}
	if fn.Scope().LookupChain(pkg, fn.Pos()) != nil {
		t.errh(extern.Pos(), fmt.Sprintf("external package identifier %s already declared", pkg))
	}
}
// typecheckMainFn validates the program entry point fn with signature st. It
// must have no parameter, a Void return type and no attribute. The first main
// function seen is recorded in t.unit.Main. Any later one is reported as a
// duplicate, citing the position of the first.
func (t *typecheckVisitor) typecheckMainFn(fn *Fn, st *SignatureType) {
	if !(len(st.Params) == 0 && IsBasicOfKind(st.Return, Void)) {
		t.errh(fn.Pos(), fmt.Sprintf("main function must not have any parameter and return value, is %s", st))
	}
	if len(fn.Attrs) != 0 {
		t.errh(fn.Pos(), "main function must not have any attribute")
	}

	first := t.unit.Main
	if first == nil {
		t.unit.Main = fn
		return
	}
	where := t.unit.FileSet.Position(first.Pos())
	t.errh(fn.Pos(), fmt.Sprintf("main function already declared at %s", where))
}
// typecheckConds walks each condition of an if/guard statement and checks that
// it is a value of the Bool basic kind. All conditions are checked, even after
// an earlier one failed.
func (t *typecheckVisitor) typecheckConds(conds []Expr) {
	for _, cond := range conds {
		Walk(t, cond)
		var bt *BasicType
		if ok := t.expectTypeCtx(cond, &bt, TypeContextValues...); ok && !IsBasicOfKind(bt, Bool) {
			t.errh(cond.Pos(), fmt.Sprintf("non-bool (type %s) used as condition", bt))
		}
	}
}
// typecheckStructInit checks a struct initializer call n, where n.InitOf is the
// struct being initialized. The checks are:
//   - n.Fun must be a struct type in a Typ context;
//   - every argument must be labelled with a field name, given at most once;
//   - each value must be assignable to its field's type (an implicit
//     conversion is inserted when the types are not identical);
//   - fields declared without a value must all be provided;
//   - the call must be a Value of the struct type.
func (t *typecheckVisitor) typecheckStructInit(n *Call) {
	var st *StructType
	if !t.expectTypeCtx(n.Fun, &st, Typ) {
		return
	}

	decl := n.InitOf
	// sanity check: the callee's type must be the initialized struct, or, for a
	// generic struct, a type whose Decl is that struct
	if !st.IdenticalTo(decl.Type()) && !(decl.IsGeneric() && st.Decl == decl) {
		panic(fmt.Sprintf("struct init with Fun type %s has unexpected InitOf.Type of %s", st, decl.Type()))
	}

	// field name -> expected type, and the fields without a value that must be
	// provided; labels are required but order does not matter
	fieldTypes := make(map[string]Type, len(decl.Vars))
	mustProvide := make(map[string]bool, len(decl.Vars))
	for _, field := range decl.Vars {
		fieldTypes[field.Ident()] = st.typeOfSel(field.Type())
		if field.Value == nil {
			mustProvide[field.Ident()] = true
		}
	}

	seen := make(map[string]bool, len(n.Args))
	for i, arg := range n.Args {
		label := n.Labels[i]
		if label == "" {
			t.errh(arg.Pos(), "label required for struct initializer")
			// a "missing required field" error would be unreliable without all
			// labels, so disable it (deleting from a nil map is a no-op)
			mustProvide = nil
			continue
		}
		if seen[label] {
			t.errh(arg.Pos(), fmt.Sprintf("field already provided: %s", label))
			continue
		}
		seen[label] = true

		fieldType, isField := fieldTypes[label]
		if !isField {
			t.errh(arg.Pos(), fmt.Sprintf("invalid field: %s", label))
			continue
		}
		delete(mustProvide, label)

		var argt Type
		switch {
		case !t.expectTypeCtx(arg, &argt, TypeContextValues...):
			// error already reported
		case !argt.AssignableTo(fieldType):
			t.errh(arg.Pos(), fmt.Sprintf("invalid type for field %s: expected %s, got %s", label, fieldType, argt))
		case !argt.IdenticalTo(fieldType):
			n.Args[i] = createImplicitConv(arg, fieldType)
		}
	}

	if len(mustProvide) > 0 {
		missing := make([]string, 0, len(mustProvide))
		for name := range mustProvide {
			missing = append(missing, name)
		}
		sort.Strings(missing)
		noun := "field"
		if len(missing) > 1 {
			noun = "fields"
		}
		t.errh(n.Pos(), fmt.Sprintf("required %s not provided: %s", noun, strings.Join(missing, ",")))
	}

	// the call expression itself must be a value of the struct type
	var callt *StructType
	if t.expectTypeCtx(n, &callt, Value) && !callt.IdenticalTo(st) {
		t.errh(n.Pos(), fmt.Sprintf("invalid type for call: expected %s, got %s", st, callt))
	}
}
// typecheckFnCall checks a regular function call n. The checks are:
//   - n.Fun must be a value with a signature type;
//   - the number of arguments must equal the number of parameters;
//   - each argument must be assignable to its parameter type (an implicit
//     conversion is inserted when the types are not identical).
//
// Labels are only allowed when asFnDeclRef resolves n.Fun to a function
// declaration. In that case, arguments must be either all labelled or all
// unlabelled, and each label must equal the corresponding parameter name.
// Finally, the call must be a Value whose type is assignable to the
// signature's return type.
func (t *typecheckVisitor) typecheckFnCall(n *Call) {
	var st *SignatureType
	if !t.expectTypeCtx(n.Fun, &st, Value, Mutable, Immutable) {
		return
	}
	if want, got := len(st.Params), len(n.Args); want != got {
		t.errh(n.Pos(), fmt.Sprintf("wrong number of arguments in call: expected %d, got %d", want, got))
		return
	}

	// TODO: for calls via interface, the labels are not allowed?
	// when labels are allowed, the expected ones are the parameter names
	var paramNames []string
	if decl := asFnDeclRef(n.Fun); decl != nil {
		paramNames = make([]string, 0, len(decl.Params))
		for _, p := range decl.Params {
			paramNames = append(paramNames, p.Ident())
		}
	}

	labelled := false
	for i, arg := range n.Args {
		param := st.Params[i]

		if lbl := n.Labels[i]; lbl != "" {
			switch {
			case len(paramNames) == 0:
				t.errh(arg.Pos(), fmt.Sprintf("label provided but not allowed on function value %s", st))
			case i > 0 && !labelled:
				t.errh(arg.Pos(), "invalid mix of labelled and unlabelled arguments")
			case lbl != paramNames[i]:
				t.errh(arg.Pos(), fmt.Sprintf("expected label %q at argument index %d, got %q", paramNames[i], i, lbl))
			}
			labelled = true
		} else if labelled && len(paramNames) > 0 {
			t.errh(arg.Pos(), "invalid mix of labelled and unlabelled arguments")
		}

		var argt Type
		switch {
		case !t.expectTypeCtx(arg, &argt, TypeContextValues...):
			// error already reported
		case !argt.AssignableTo(param):
			t.errh(arg.Pos(), fmt.Sprintf("invalid type for argument: expected %s, got %s", param, argt))
		case !argt.IdenticalTo(param):
			n.Args[i] = createImplicitConv(arg, param)
		}
	}

	// the call's own type must match the signature's return type
	var callt Type
	if t.expectTypeCtx(n, &callt, Value) && !callt.AssignableTo(st.Return) {
		t.errh(n.Pos(), fmt.Sprintf("invalid type for call: expected %s, got %s", st.Return, callt))
	}
}
// asFnDeclRef returns the function declaration that expr refers to, or nil.
// Only expressions with an Immutable type context are considered. The
// referenced declaration is found with asIdentDeclRef.
func asFnDeclRef(expr Expr) *Fn {
	if ctx := expr.TypeContext(); ctx == Immutable {
		return AsFnDecl(asIdentDeclRef(expr))
	}
	return nil
}
// asVarDeclRef returns the variable declaration that expr refers to, or nil.
// Only expressions with a Mutable or Immutable type context are considered.
// The referenced declaration is found with asIdentDeclRef.
func asVarDeclRef(expr Expr) *Var {
	if ctx := expr.TypeContext(); ctx.isAnyOf(Mutable, Immutable) {
		return AsVarDecl(asIdentDeclRef(expr))
	}
	return nil
}
// asIdentDeclRef returns the declaration referenced by the identifier that expr
// designates, descending only through parentheses and selectors. Every *Ident
// reached overwrites the result, so the last identifier visited wins. For a
// selector this is presumably the selected member (assuming Inspect visits
// Left before Sel; confirm). It returns nil if no identifier is reached.
func asIdentDeclRef(expr Expr) Decl {
	var found Decl
	Inspect(expr, func(node Node) bool {
		switch node := node.(type) {
		case *Paren, *Selector:
			return true
		case *Ident:
			found = node.Ref
		}
		// a nil node (end of children) and any other node kind stop the descent
		return false
	})
	return found
}
// createImplicitConv wraps expr in a new ImplicitConv node that converts it to
// toType. The node takes expr's position and scope and has a Value type
// context.
func createImplicitConv(expr Expr, toType Type) *ImplicitConv {
	conv := new(ImplicitConv)
	conv.Value = expr
	conv.pos, conv.scope = expr.Pos(), expr.Scope()
	conv.typ, conv.ctx = toType, Value
	return conv
}