~jojo/Carth

e6d89630754e050627e6e1b98567ac1b54146a65 — JoJo 1 year, 1 month ago 3404f50
Infer: Separate constraint solver from inference logic

sdiehl suggested this separation in a tutorial, and it makes sense. Performance is
mostly unchanged; solving still needs to be made faster / do less work.
2 files changed, 121 insertions(+), 114 deletions(-)

M src/Infer.hs
M src/Subst.hs
M src/Infer.hs => src/Infer.hs +121 -112
@@ 4,19 4,18 @@
module Infer (inferTopDefs, checkType', checkType'') where

import Prelude hiding (span)
import Lens.Micro.Platform (assign, makeLenses, over, use, view, mapped, to, Lens')
import Lens.Micro.Platform (makeLenses, over, view, mapped, to)
import Control.Applicative hiding (Const(..))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor
import Data.Functor
import Data.Graph (SCC(..), stronglyConnComp)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import Control.Arrow ((>>>))

import Misc


@@ 33,9 32,11 @@ import TypeAst hiding (TConst)
-- | One side of a type constraint: the type that was expected.
newtype ExpectedType = Expected Type
-- | The other side of a type constraint: the type that was found, together with the
--   source position that is attached to the error if unification fails.
data FoundType = Found SrcPos Type

-- | A pending unification between an expected and a found type. Constraints are
--   collected during inference and unified later by the constraint solver.
type Constraint = (ExpectedType, FoundType)

data Env = Env
    { _envTypeDefs :: TypeDefs
    -- Separarate global defs and local defs, because ~generalize~ only has to look as
    -- Separate global defs and local defs, because `generalize` only has to look at
    -- local defs.
    , _envGlobDefs :: Map String Scheme
    , _envLocalDefs :: Map String Scheme


@@ 47,30 48,24 @@ data Env = Env
    }
makeLenses ''Env

data St = St
    { _tvCount :: Word
    , _substs :: Subst
    }
    deriving (Show)
makeLenses ''St
-- | State of the inference monad: the counter from which fresh implicit type variables
--   are numbered. Starts at 0 and is incremented every time a fresh variable is drawn.
type TVCount = Word

type Infer a = ReaderT Env (StateT St (Except TypeErr)) a
-- | The inference monad. From the outside in:
--
--   * @WriterT [Constraint]@ accumulates the constraints emitted by `unify`,
--   * @ReaderT Env@ carries the typing environment,
--   * @StateT TVCount@ supplies fresh type variables,
--   * @Except TypeErr@ reports type errors.
type Infer a = WriterT [Constraint] (ReaderT Env (StateT TVCount (Except TypeErr))) a

------------------------------------------------------------------------------------------
-- Inference
------------------------------------------------------------------------------------------

inferTopDefs :: TypeDefs -> Ctors -> Externs -> [Parsed.Def] -> Except TypeErr Defs
inferTopDefs tdefs ctors externs defs =
    let initEnv = Env { _envTypeDefs = tdefs
                      , _envGlobDefs = builtinVirtuals
                      , _envLocalDefs = Map.empty
                      , _envCtors = ctors
                      }
        initSt = St { _tvCount = 0, _substs = Map.empty }
    in  evalStateT (runReaderT inferTopDefs' initEnv) initSt
  where
    inferTopDefs' = do
        let externs' = fmap (first (Forall Set.empty)) externs
        defs'' <- augment envGlobDefs (fmap fst externs') (inferDefs envGlobDefs defs)
        pure defs''
    let initEnv = Env
            { _envTypeDefs = tdefs
            , _envGlobDefs = Map.union (fmap (Forall Set.empty . fst) externs)
                                       builtinVirtuals
            , _envLocalDefs = Map.empty
            , _envCtors = ctors
            }
    in  evalStateT (runReaderT (fmap fst (runWriterT (inferGlobalDefs defs))) initEnv) 0

-- | Check a parsed type inside the inference monad.
--
--   Reads the type definitions from the environment and delegates to `checkType'`.
checkType :: SrcPos -> Parsed.Type -> Infer Type
checkType pos t = do
    tdefs <- view envTypeDefs
    checkType' tdefs pos t


@@ 98,13 93,37 @@ checkType'' tdefsParams pos = go
                else throwError (TypeInstArityMismatch pos x expectedN foundN)
        Nothing -> throwError (UndefType pos x)

inferDefs :: Lens' Env (Map String Scheme) -> [Parsed.Def] -> Infer Defs
inferDefs envDefs defs = do
inferGlobalDefs :: [Parsed.Def] -> Infer Defs
inferGlobalDefs defs = do
    checkNoDuplicateDefs defs
    let ordered = orderDefs defs
    inferDefsComponents envDefs ordered
    foldr
        (\scc inferRest -> do
            (def, constraints) <- censor (const []) $ listen $ inferComponent scc
            sub <- lift $ lift $ lift $ solve constraints
            let def' = substDef sub def
            Topo rest <- augment envGlobDefs (Map.fromList (defSigs def)) inferRest
            pure (Topo (def' : rest))
        )
        (pure (Topo []))
        ordered

-- | Infer a group of local definitions.
--
--   Duplicate names are rejected first. The definitions are then split into strongly
--   connected components with `orderDefs`, and the components are inferred in that
--   order. While the remaining components are inferred, the signatures of the
--   component just inferred are added to the local definitions of the environment.
--
--   No constraint solving happens at this level; whatever constraints the components
--   emit are simply passed along.
inferLocalDefs :: [Parsed.Def] -> Infer Defs
inferLocalDefs defs = do
    checkNoDuplicateDefs defs
    foldr inferThenRest (pure (Topo [])) (orderDefs defs)
  where
    -- Infer one component, then the rest with this component's signatures in scope.
    inferThenRest component inferRest = do
        def <- inferComponent component
        let sigs = Map.fromList (defSigs def)
        Topo rest <- augment envLocalDefs sigs inferRest
        pure (Topo (def : rest))

checkNoDuplicateDefs :: [Parsed.Def] -> Infer ()
checkNoDuplicateDefs = checkNoDuplicateDefs' Set.empty
  where
    checkNoDuplicateDefs = checkNoDuplicateDefs' Set.empty
    checkNoDuplicateDefs' already = \case
        (Id (WithPos p x), _) : ds -> if Set.member x already
            then throwError (ConflictingVarDef p x)


@@ 123,22 142,10 @@ orderDefs :: [Parsed.Def] -> [SCC Parsed.Def]
-- Build the dependency graph of the definitions and split it into strongly connected
-- components. Each definition is a node keyed by its name, with an outgoing edge to
-- every free variable of the definition.
orderDefs defs = stronglyConnComp
    [ (def, name, Set.toList (freeVars def)) | def@(name, _) <- defs ]

inferDefsComponents :: Lens' Env (Map String Scheme) -> [SCC Parsed.Def] -> Infer Defs
inferDefsComponents envDefs = flip foldr (pure (Topo [])) $ \scc inferRest -> do
    def <- inferComponent scc
    Topo rest <- augment envDefs (Map.fromList (defSigs def)) inferRest
    pure (Topo (def : rest))
  where
    inferComponent :: SCC Parsed.Def -> Infer Def
    inferComponent d = do
        -- TODO: Why is this fine even when we're in a LetRec? Seems like it would mess
        --       things up seriously? Or do I just not cover any cases like this in my
        --       tests?
        assign substs Map.empty
        d' <- case d of
            AcyclicSCC vert -> fmap VarDef (inferVarDef vert)
            CyclicSCC verts -> fmap RecDefs (inferRecDefs verts)
        use substs <&> \s -> substDef s d'
-- | Infer one strongly connected component of definitions: an acyclic component is a
--   single definition and becomes a `VarDef`, a cyclic component is a group of
--   definitions and becomes `RecDefs`.
inferComponent :: SCC Parsed.Def -> Infer Def
inferComponent (AcyclicSCC vert) = VarDef <$> inferVarDef vert
inferComponent (CyclicSCC verts) = RecDefs <$> inferRecDefs verts

inferVarDef :: Parsed.Def -> Infer VarDef
inferRecDefs :: [Parsed.Def] -> Infer RecDefs


@@ 146,17 153,20 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
  where
    inferVarDef' (lhs, WithPos defPos (mayscm, body)) = do
        t <- fresh
        body' <- inferDef t lhs mayscm (getPos body) (infer body)
        scm <- generalize t
        (body', cs) <- listen $ inferDef t lhs mayscm (getPos body) (infer body)
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scm = generalize (substEnv sub env) (subst sub t)
        pure (idstr lhs, WithPos defPos (scm, body'))

    inferRecDefs' ds = do
        ts <- replicateM (length ds) fresh
        let dummyScms = map (Forall Set.empty) ts
        let (names, poss) = unzip (map (bimap idstr getPos) ds)
        fs <- augment envLocalDefs (Map.fromList (zip names dummyScms))
            $ zipWithM inferRecDef ts ds
        scms <- mapM generalize ts
        let dummyDefs = Map.fromList (zip names (map (Forall Set.empty) ts))
        (fs, cs) <- listen $ augment envLocalDefs dummyDefs $ zipWithM inferRecDef ts ds
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scms = map (generalize (substEnv sub env) . subst sub) ts
        pure (zip names (zipWith3 (curry . WithPos) poss scms fs))

    inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)


@@ 183,7 193,8 @@ checkScheme = curry $ \case
    (_, Just (Parsed.Forall pos vs t)) -> do
        t' <- checkType pos t
        let s1 = Forall vs t'
        s2 <- generalize t'
        env <- view envLocalDefs
        let s2 = generalize env t'
        if (s1 == s2) then pure (Just s1) else throwError (InvalidUserTypeSig pos s1 s2)

infer :: Parsed.Expr -> Infer (Type, Expr)


@@ 211,7 222,7 @@ infer (WithPos pos e) = fmap (second (WithPos pos)) $ case e of
        (t, body') <- augment1 envLocalDefs (defSig def') (infer body)
        pure (t, Let (VarDef def') body')
    Parsed.LetRec defs b -> do
        Topo defs' <- inferDefs envLocalDefs defs
        Topo defs' <- inferLocalDefs defs
        let withDef def inferX = do
                (tx, x') <- withLocals (defSigs def) inferX
                pure (tx, WithPos pos (Let def x'))


@@ 345,78 356,76 @@ lookupEnv (Id (WithPos pos x)) = do
-- | Run an inference action with the given name/scheme pairs added to the local
--   definitions of the environment.
withLocals :: [(String, Scheme)] -> Infer a -> Infer a
withLocals sigs inferX = augment envLocalDefs (Map.fromList sigs) inferX

unify :: ExpectedType -> FoundType -> Infer ()
unify (Expected t1) (Found pos t2) = do
    s1 <- use substs
    s2 <- unify' (Expected (subst s1 t1)) (Found pos (subst s1 t2))
    assign substs (composeSubsts s2 s1)

unify' :: ExpectedType -> FoundType -> Infer Subst
unify' (Expected t1) (Found pos t2) = lift $ lift $ withExcept
    (\case
        InfiniteType'' a t -> InfType pos t1 t2 a t
        UnificationFailed'' t'1 t'2 -> UnificationFailed pos t1 t2 t'1 t'2
    )
    (unify'' t1 t2)

data UnifyErr'' = InfiniteType'' TVar Type | UnificationFailed'' Type Type

unify'' :: Type -> Type -> Except UnifyErr'' Subst
unify'' = curry $ \case
    (TPrim a, TPrim b) | a == b -> pure Map.empty
    (TConst (c0, ts0), TConst (c1, ts1)) | c0 == c1 -> if length ts0 /= length ts1
        then ice "lengths of TConst params differ in unify"
        else unifys ts0 ts1
    (TVar a, TVar b) | a == b -> pure Map.empty
    (TVar a, t) | occursIn a t -> throwError (InfiniteType'' a t)
    -- Do not allow "override" of explicit (user given) type variables.
    (a@(TVar (TVExplicit _)), b@(TVar (TVImplicit _))) -> unify'' b a
    (a@(TVar (TVExplicit _)), b) -> throwError (UnificationFailed'' a b)
    (TVar a, t) -> pure (Map.singleton a t)
    (t, TVar a) -> unify'' (TVar a) t
    (TFun t1 t2, TFun u1 u2) -> unifys [t1, t2] [u1, u2]
    (TBox t, TBox u) -> unify'' t u
    (t1, t2) -> throwError (UnificationFailed'' t1 t2)

unifys :: [Type] -> [Type] -> Except UnifyErr'' Subst
unifys ts us = foldM
    (\s (t, u) -> fmap (flip composeSubsts s) (unify'' (subst s t) (subst s u)))
    Map.empty
    (zip ts us)

occursIn :: TVar -> Type -> Bool
occursIn a t = Set.member a (ftv t)

-- | Instantiate a type scheme: replace each quantified type variable in the body with
--   a fresh type variable.
instantiate :: Scheme -> Infer Type
instantiate (Forall params t) = do
    -- One fresh type per quantified variable, drawn in ascending order of the variables.
    freshSub <- sequenceA (Map.fromSet (const fresh) params)
    pure (subst freshSub t)

generalize :: Type -> Infer Scheme
generalize t = do
    env <- ask
    s <- use substs
    let t' = subst s t
    pure (Forall (generalize' (substEnv s env) t') t')
generalize :: Map String Scheme -> Type -> Scheme
generalize env t = Forall (Set.difference (ftv t) (ftvEnv env)) t
  where
    generalize' :: Env -> Type -> Set TVar
    generalize' env t = Set.difference (ftv t) (ftvEnv env)

    substEnv :: Subst -> Env -> Env
    substEnv s = over (envLocalDefs . mapped . scmBody) (subst s)

    ftvEnv :: Env -> Set TVar
    ftvEnv = Set.unions . map (ftvScheme . snd) . Map.toList . view envLocalDefs

    ftvScheme :: Scheme -> Set TVar
    ftvEnv env = Set.unions (map ftvScheme (Map.elems env))
    ftvScheme (Forall tvs t) = Set.difference (ftv t) tvs

-- | Apply a substitution to the body of every scheme in a map of definitions.
substEnv :: Subst -> Map String Scheme -> Map String Scheme
substEnv s = fmap (over scmBody (subst s))

-- | Draw a fresh type variable, wrapped up as a type.
fresh :: Infer Type
fresh = TVar <$> fresh'

fresh' :: Infer TVar
fresh' = fmap TVImplicit fresh''
fresh' = fmap TVImplicit (get <* modify (+ 1))

fresh'' :: Infer Word
fresh'' = tvCount <<+= 1
-- | Record that the expected and the found type must unify.
--
--   Nothing is unified here: the pair is only emitted as a constraint to the writer
--   layer of `Infer`, to be solved later.
unify :: ExpectedType -> FoundType -> Infer ()
unify expected found = tell [constraint] where constraint = (expected, found)

------------------------------------------------------------------------------------------
-- Constraint solver
------------------------------------------------------------------------------------------

-- | Position-less unification error, as produced by `unifies`:
--
--   * @UInfType a t@: the occurs check failed, type variable @a@ occurs in @t@.
--   * @UFailed t1 t2@: the two types cannot be unified.
--
--   `toTypeErr` turns it into a `TypeErr` by adding the source position and the types
--   of the constraint being solved.
data UnifyErr = UInfType TVar Type | UFailed Type Type

-- | Solve a list of constraints, front to back, yielding the composed substitution.
--
--   Fails with the `TypeErr` of the first constraint whose two types cannot be unified.
solve :: [Constraint] -> Except TypeErr Subst
solve = go Map.empty
  where
    -- `acc` is the substitution built so far. The substitution from each solved
    -- constraint is applied to all remaining constraints right away, so the head
    -- constraint can always be unified as it stands.
    go :: Subst -> [Constraint] -> Except TypeErr Subst
    go acc [] = pure acc
    go acc ((Expected expected, Found pos found) : rest) = do
        new <- withExcept (toTypeErr pos expected found) (unifies expected found)
        go (composeSubsts new acc) [ applySub new c | c <- rest ]

    -- Apply a substitution to both sides of a constraint, keeping the position.
    applySub sub (Expected t, Found pos u) =
        (Expected (subst sub t), Found pos (subst sub u))

-- | Turn a `UnifyErr` into a `TypeErr` by attaching the source position and the two
--   types of the constraint that was being solved when the error occurred.
toTypeErr :: SrcPos -> Type -> Type -> UnifyErr -> TypeErr
toTypeErr pos t1 t2 (UInfType a t) = InfType pos t1 t2 a t
toTypeErr pos t1 t2 (UFailed t'1 t'2) = UnificationFailed pos t1 t2 t'1 t'2

-- | Unify two types, yielding a substitution that makes them equal.
--
--   Fails with `UInfType` if the occurs check fails, and with `UFailed` if the types
--   cannot be unified. The equations are tried top to bottom and their order matters.
unifies :: Type -> Type -> Except UnifyErr Subst
unifies (TPrim a) (TPrim b) | a == b = pure Map.empty
unifies (TConst (c0, ts0)) (TConst (c1, ts1))
    | c0 == c1 = if length ts0 == length ts1
        then unifiesMany ts0 ts1
        else ice "lengths of TConst params differ in unify"
unifies (TVar a) (TVar b) | a == b = pure Map.empty
-- Occurs check: binding `a` to a type that mentions `a` would give an infinite type.
unifies (TVar a) t | occursIn a t = throwError (UInfType a t)
-- Do not allow "override" of explicit (user given) type variables. An implicit
-- variable may still be bound to an explicit one, so flip the arguments in that case.
unifies explicit@(TVar (TVExplicit _)) implicit@(TVar (TVImplicit _)) =
    unifies implicit explicit
unifies explicit@(TVar (TVExplicit _)) other = throwError (UFailed explicit other)
unifies (TVar a) t = pure (Map.singleton a t)
-- A variable on the right only: flip, so that the variable rules above apply.
unifies t (TVar a) = unifies (TVar a) t
unifies (TFun t1 t2) (TFun u1 u2) = unifiesMany [t1, t2] [u1, u2]
unifies (TBox t) (TBox u) = unifies t u
unifies t1 t2 = throwError (UFailed t1 t2)

-- | Unify two lists of types pairwise, left to right.
--
--   The substitution accumulated so far is applied to each pair before that pair is
--   unified. The lists are paired with `zip`, so surplus elements of the longer list
--   are ignored.
unifiesMany :: [Type] -> [Type] -> Except UnifyErr Subst
unifiesMany ts us = foldM step Map.empty (zip ts us)
  where
    step acc (t, u) = do
        new <- unifies (subst acc t) (subst acc u)
        pure (composeSubsts new acc)

-- | Occurs check: is the type variable among the free type variables of the type?
occursIn :: TVar -> Type -> Bool
occursIn a = Set.member a . ftv

M src/Subst.hs => src/Subst.hs +0 -2
@@ 10,11 10,9 @@ import Data.Maybe
import SrcPos
import Inferred


-- | Map of substitutions from type-variables to more specific types.
--
--   A type variable that is a key of the map is to be replaced with the type it maps to.
type Subst = Map TVar Type


substDef :: Subst -> Def -> Def
substDef s = \case
    VarDef d -> VarDef (second (mapPosd (second (substExpr s))) d)