package semantic
import (
"fmt"
"git.sr.ht/~mna/snow/pkg/token"
)
// maxDeferredLoops is the upper bound on the number of rounds typeassign will
// run over the deferred callbacks before giving up on still-unresolved nodes.
const maxDeferredLoops = 100
// typeassign is the type-assignment pass. It walks the unit once, then keeps
// re-running the callbacks that were deferred for unresolved nodes, for at
// most maxDeferredLoops rounds. The goal is that, upon completion, all Expr
// have a Type and TypeContext, all Decl have a Type, and all *Ident have their
// Ref set if valid. Errors are reported through errh.
func typeassign(unit *Unit, errh func(token.Pos, string)) {
	v := &typeassignVisitor{errh: errh}
	Walk(v, unit)

	// A round of deferred processing can leave *more* deferred nodes than it
	// started with, because resolving one node may unlock a sub-tree that still
	// cannot be resolved. The count of deferred calls is therefore not a usable
	// progress indicator; only an empty queue or the round limit stops the loop.
	for round := 0; round < maxDeferredLoops; round++ {
		if len(v.deferred) == 0 {
			break
		}

		// snapshot the queue and empty it, so that callbacks registered while
		// this batch runs are collected for the next round
		batch := append([]func(){}, v.deferred...)
		v.deferred = v.deferred[:0]
		for _, retry := range batch {
			retry()
		}
	}
}
// typeassignVisitor holds the state of the typeassign pass while the tree is
// being walked.
type typeassignVisitor struct {
// errh is called with a position and a message for each error reported by
// the pass.
errh func(token.Pos, string)
// deferred holds the callbacks queued by addDeferred; each one re-walks a
// node whose type was not valid when it was first visited.
deferred []func()
// flag indicating that the Identifier is part of a generic instantiation expression.
genericInstIdent bool
// flag indicating that we are visiting the left-hand side of a selector.
selectorLHS bool
}
// addDeferred queues a callback that walks n again with this visitor. The
// values of the genericInstIdent and selectorLHS flags at the time of the call
// are captured and re-applied to the visitor right before the node is
// re-walked, so that the retry sees the same flags as the first visit.
func (t *typeassignVisitor) addDeferred(n Node) {
	inGenericInst, onSelectorLHS := t.genericInstIdent, t.selectorLHS

	rewalk := func() {
		t.genericInstIdent, t.selectorLHS = inGenericInst, onSelectorLHS
		Walk(t, n)
	}
	t.deferred = append(t.deferred, rewalk)
}
// errUndefined reports ident as undefined through the error handler and marks
// it as invalid and unresolved. When in is non-nil, the message mentions it as
// the type in which the identifier was looked up.
func (t *typeassignVisitor) errUndefined(ident *Ident, in Type) {
	pos := ident.Pos()

	var msg string
	if in == nil {
		msg = fmt.Sprintf("undefined: %s", ident.Name)
	} else {
		msg = fmt.Sprintf("undefined in %s: %s", in, ident.Name)
	}
	t.errh(pos, msg)

	ident.ctx, ident.typ = Invalid, unresolvedType{}
}
// Visit assigns a type (and, for expressions, a type context) to node n.
//
// For the container and statement nodes listed in the first case it returns t
// without assigning anything; presumably Walk then descends into the children
// with the returned visitor, as go/ast.Walk does - confirm against Walk. For
// every other node it walks the relevant children itself, sets the node's
// type, and returns nil.
//
// After the switch, any node that implements Typed and whose type is not
// Valid is queued with addDeferred so that it is walked again later.
//
// It panics if n is non-nil and of a node type not handled by the switch.
func (t *typeassignVisitor) Visit(n Node) Visitor {
switch n := n.(type) {
case *Unit, *File, *Block, *Return, *Assign, *ExprStmt, *If, *Guard:
// no type is assigned to these nodes here
return t
// ************** DECLARATIONS *****************
case *Fn:
// build the signature type from the parameters and the return expression
sigt := &SignatureType{
Params: make([]Type, len(n.Params)),
}
for _, attr := range n.Attrs {
Walk(t, attr)
}
if n.GenericParams != nil {
for _, elem := range n.GenericParams.Elems {
Walk(t, elem)
}
}
for i, param := range n.Params {
Walk(t, param)
sigt.Params[i] = param.Type()
}
if n.ReturnExpr != nil {
Walk(t, n.ReturnExpr)
sigt.Return = n.ReturnExpr.Type()
} else {
// no return expression means a Void return type
sigt.Return = &BasicType{Kind: Void}
}
if n.Body != nil {
Walk(t, n.Body)
}
// note that the type of the function is only set once its body has been
// walked
n.typ = sigt
case *Var:
// type of var is either its explicit type or the type of its initialization
var typ Type = unresolvedType{}
if n.TypeExpr != nil {
Walk(t, n.TypeExpr)
typ = n.TypeExpr.Type()
}
if n.Value != nil {
// the value is always walked, but its type is only used when there is no
// explicit type expression
Walk(t, n.Value)
if n.TypeExpr == nil {
typ = n.Value.Type()
}
}
n.typ = typ
case *Struct:
st := &StructType{Decl: n}
if n.GenericParams != nil {
for _, elem := range n.GenericParams.Elems {
Walk(t, elem)
}
}
for _, v := range n.Vars {
Walk(t, v)
}
for _, fn := range n.Fns {
Walk(t, fn)
}
for _, str := range n.Structs {
Walk(t, str)
}
// as for functions, the type is set after all members have been walked
n.typ = st
case *GenericElem:
// TODO: when traits/capabilities are added, this would lookup the matching
// generic clause on the parent, and add the specific traits constraints on
// the type.
n.typ = &GenericType{
Name: n.Ident(),
ScopeID: n.Scope().ID,
}
// ************** EXPRESSIONS *****************
case *FnTypeExpr:
// build the signature type of this function literal
sigt := &SignatureType{
Params: make([]Type, len(n.Params)),
}
for i, param := range n.Params {
Walk(t, param)
sigt.Params[i] = param.Type()
}
if n.Return != nil {
Walk(t, n.Return)
sigt.Return = n.Return.Type()
} else {
sigt.Return = &BasicType{Kind: Void}
}
n.ctx = Typ
n.typ = sigt
case *TupleTypeExpr:
// tuple type made of the type of each field, in order
tupt := &TupleType{
Fields: make([]Type, len(n.Fields)),
}
for i, typ := range n.Fields {
Walk(t, typ)
tupt.Fields[i] = typ.Type()
}
n.ctx = Typ
n.typ = tupt
case *TupleVal:
// tuple type made of the type of each value, in order
tupt := &TupleType{
Fields: make([]Type, len(n.Values)),
}
for i, val := range n.Values {
Walk(t, val)
tupt.Fields[i] = val.Type()
}
n.ctx = Value
n.typ = tupt
case *Binary:
// assign a type to each operand
Walk(t, n.Left)
Walk(t, n.Right)
// the result stays unresolved unless both operands are basic types and the
// table yields a valid kind for the operator
var newType Type = unresolvedType{}
lt, rt := AsBasicType(n.Left.Type()), AsBasicType(n.Right.Type())
if lt != nil && rt != nil {
kl, kr := lt.Kind, rt.Kind
// order the kinds so that kl <= kr before indexing the table (presumably
// the table is only filled for that ordering - confirm)
if kr < kl {
kl, kr = kr, kl
}
opKind := binaryOpsTable[n.Op][kl][kr]
if bt := (&BasicType{Kind: opKind}); bt.Valid() {
newType = bt
}
}
n.ctx = Value
n.typ = newType
case *Unary:
// start with type of operand
Walk(t, n.Right)
rt := n.Right.Type()
// type of unary is the same as long as operand is an allowed type for the
// operator
if !IsBasicOfKind(rt, unaryOpsTable[n.Op]...) {
rt = unresolvedType{}
}
n.ctx = Value
n.typ = rt
case *Paren:
// type is that of its value, and context remains the same
Walk(t, n.Value)
n.ctx = n.Value.TypeContext()
n.typ = n.Value.Type()
case *Call:
Walk(t, n.Fun)
for _, arg := range n.Args {
Walk(t, arg)
}
// the type of the call expression is the return type of the signature type
// of n.Fun, if it is a signature type, or the type of the struct or
// attribute if it is a struct initializer or func attribute. Otherwise
// default to unresolved.
lt := n.Fun.Type()
n.ctx = Value
switch lt := lt.(type) {
case *SignatureType:
n.typ = lt.Return
case *StructType:
// record the struct declaration that this call initializes
n.typ = lt
n.InitOf = lt.Decl
default:
n.typ = unresolvedType{}
}
case *Selector:
slhs := t.selectorLHS // flag can be recursive, do not set to false after Walk, set to old value
t.selectorLHS = true
Walk(t, n.Left)
t.selectorLHS = slhs
// must not call walk on n.Sel, as it will not find the right symbol in
// the lookup chain - must look in n.Left.
switch sel := n.Sel.(type) {
case *Ident:
t.typeAssignSelector(n.Left.Type(), n.Left.TypeContext(), sel)
case *GenericInst:
t.typeAssignSelector(n.Left.Type(), n.Left.TypeContext(), sel.GenericDecl)
// at this point sel.Type() should be a generic type and sel.Ref a generic decl,
// now instantiate it.
t.typeAssignGenericInst(sel)
}
// the context of the selector is derived from the contexts of both sides,
// its type is the type of the selected member
n.ctx = selectorTypeContext[n.Left.TypeContext()][n.Sel.TypeContext()]
n.typ = n.Sel.Type()
case *GenericInst:
// the flag is only raised while the generic declaration identifier is
// walked; it is reset to false (not restored to a previous value) before
// the instantiation itself is processed
t.genericInstIdent = true
Walk(t, n.GenericDecl)
t.genericInstIdent = false
t.typeAssignGenericInst(n)
case *Ident:
decl := n.Scope().LookupChain(n.Name, n.Pos())
if decl == nil {
t.errUndefined(n, nil)
// this only leaves the switch, the deferred check below still runs
break
}
if !t.genericInstIdent && !t.selectorLHS && decl.IsGeneric() {
// cannot access a generic identifier without an instantiation, unless it is used
// on the left-hand side of a selector to access an inner member.
t.errh(n.Pos(), "cannot use generic declaration without instantiation")
}
n.Ref = decl
n.ctx = decl.TypeContext()
// the declaration may not have a type yet, in which case the identifier is
// left unresolved
if n.typ = decl.Type(); n.typ == nil {
n.typ = unresolvedType{}
}
case *LitString:
n.ctx = Constant
n.typ = &BasicType{Kind: String}
case *LitInt:
n.ctx = Constant
n.typ = &BasicType{Kind: Int}
default:
if n != nil {
panic(fmt.Sprintf("invalid node type: %T", n))
}
}
// n is the Node parameter again here, not the type-switched variable; queue
// the node for another walk if its type is still not valid
if typed, ok := n.(Typed); ok {
if !typed.Type().Valid() {
t.addDeferred(n)
}
}
return nil
}
// instantiateGeneric returns the concrete type of an instantiated generic
// declaration.
//
// gi is the instantiation expression, gen is the generic declaration being
// instantiated and types holds the type arguments, in order. An error is
// reported at gi's position when the number of type arguments is not the
// number of elements in the generic clause of gen; the instantiation is still
// recorded and typed in that case.
//
// For *Fn and *Struct declarations the type arguments are recorded in the
// declaration's GenericInsts map under gi. A function yields its own type with
// the generic types resolved; a struct yields a StructType carrying the type
// arguments. Any other kind of declaration is reported as an error and yields
// an unresolved type.
func (t *typeassignVisitor) instantiateGeneric(gi *GenericInst, gen GenericDecl, types []Type) Type {
	// make the map of generic type names to actual types
	gc := gen.GenClause()
	resolve := makeGenericResolveMap(gc, types)

	if want, got := len(gc.Elems), len(types); got != want {
		t.errh(gi.Pos(), fmt.Sprintf("wrong number of types provided in generic instantiation, want %d, got %d", want, got))
	}

	switch gen := gen.(type) {
	case *Fn:
		if gen.GenericInsts == nil {
			gen.GenericInsts = make(map[*GenericInst][]Type)
		}
		gen.GenericInsts[gi] = types
		return gen.Type().resolveGeneric(resolve)

	case *Struct:
		if gen.GenericInsts == nil {
			gen.GenericInsts = make(map[*GenericInst][]Type)
		}
		gen.GenericInsts[gi] = types
		return &StructType{Decl: gen, Inst: types}

	default:
		t.errh(gi.Pos(), fmt.Sprintf("invalid generic declaration type: %T", gen))
	}
	return unresolvedType{}
}
// typeAssignGenericInst sets the type context and type of the generic
// instantiation gi. The instantiation starts out as invalid and unresolved,
// and is left that way when the generic declaration identifier has no valid
// type, when it has no referenced declaration or one without a valid type, or
// when the referenced declaration is not generic (which is also reported as an
// error). Otherwise the type arguments are walked and the declaration is
// instantiated with their types.
func (t *typeassignVisitor) typeAssignGenericInst(gi *GenericInst) {
	gi.ctx, gi.typ = Invalid, unresolvedType{}

	if !gi.GenericDecl.Type().Valid() {
		return
	}
	decl := gi.GenericDecl.Ref
	if decl == nil || !decl.Type().Valid() {
		return
	}
	if !decl.IsGeneric() {
		t.errh(gi.Pos(), "invalid instantiation of a non-generic declaration")
		return
	}

	// assign a type to each type argument and collect them in order
	args := make([]Type, 0, len(gi.TypeExprs))
	for _, expr := range gi.TypeExprs {
		Walk(t, expr)
		args = append(args, expr.Type())
	}

	if declCtx := decl.TypeContext(); declCtx == Typ {
		// instantiated struct generics are types
		gi.ctx = Typ
	} else if declCtx.isAnyOf(TypeContextValues...) {
		// instantiated fn generics are values
		gi.ctx = Value
	}
	gi.typ = t.instantiateGeneric(gi, decl.(GenericDecl), args)
}
// typeAssignSelector sets the type context and type of sel, the right-hand
// side identifier of a selector whose left-hand side has type left and type
// context leftCtx.
//
// For a struct, sel is looked up by name in the body scope of the struct
// declaration and, if found, takes its context, type and Ref from that
// declaration. For a tuple, sel takes the context of the left-hand side and
// the type of the field at sel.Index. A missing struct member or out-of-range
// tuple index is reported as undefined. For any other left-hand side type, sel
// is marked invalid and unresolved without reporting an error.
//
// sel is never added to the deferred list here: it is not walked as an Ident
// but as part of the Selector expression. If sel ends up unresolved, the
// Selector expression is unresolved too and gets deferred instead.
func (t *typeassignVisitor) typeAssignSelector(left Type, leftCtx TypeContext, sel *Ident) {
	switch lhs := left.(type) {
	case *StructType:
		member := lhs.Decl.BodyScope.Lookup(sel.Name)
		if member == nil {
			t.errUndefined(sel, lhs)
			return
		}
		sel.ctx = member.TypeContext()
		sel.typ = lhs.typeOfSel(member.Type())
		sel.Ref = member

	case *TupleType:
		sel.ctx = leftCtx
		if idx := sel.Index; idx >= 0 && idx < len(lhs.Fields) {
			sel.typ = lhs.Fields[idx]
			return
		}
		t.errUndefined(sel, lhs)

	default:
		sel.ctx, sel.typ = Invalid, unresolvedType{}
	}
}