~jojo/Carth

a07d1ce73f8cb27624ea7647d1a339c2af537ca5 — JoJo 2 years ago e6d8963
Infer: Minor refactor
2 files changed, 13 insertions(+), 24 deletions(-)

M src/Infer.hs
M src/Subst.hs
M src/Infer.hs => src/Infer.hs +12 -23
@@ 4,7 4,7 @@
module Infer (inferTopDefs, checkType', checkType'') where

import Prelude hiding (span)
import Lens.Micro.Platform (makeLenses, over, view, mapped, to)
import Lens.Micro.Platform (makeLenses, over, view, mapped, to, Lens')
import Control.Applicative hiding (Const(..))
import Control.Monad.Except
import Control.Monad.Reader


@@ 65,7 65,9 @@ inferTopDefs tdefs ctors externs defs =
            , _envLocalDefs = Map.empty
            , _envCtors = ctors
            }
    in  evalStateT (runReaderT (fmap fst (runWriterT (inferGlobalDefs defs))) initEnv) 0
    in  evalStateT
            (runReaderT (fmap fst (runWriterT (inferDefs envGlobDefs defs))) initEnv)
            0

checkType :: SrcPos -> Parsed.Type -> Infer Type
checkType pos t = view envTypeDefs >>= \tds -> checkType' tds pos t


@@ 93,29 95,14 @@ checkType'' tdefsParams pos = go
                else throwError (TypeInstArityMismatch pos x expectedN foundN)
        Nothing -> throwError (UndefType pos x)

inferGlobalDefs :: [Parsed.Def] -> Infer Defs
inferGlobalDefs defs = do
    checkNoDuplicateDefs defs
    let ordered = orderDefs defs
    foldr
        (\scc inferRest -> do
            (def, constraints) <- censor (const []) $ listen $ inferComponent scc
            sub <- lift $ lift $ lift $ solve constraints
            let def' = substDef sub def
            Topo rest <- augment envGlobDefs (Map.fromList (defSigs def)) inferRest
            pure (Topo (def' : rest))
        )
        (pure (Topo []))
        ordered

inferLocalDefs :: [Parsed.Def] -> Infer Defs
inferLocalDefs defs = do
inferDefs :: Lens' Env (Map String Scheme) -> [Parsed.Def] -> Infer Defs
inferDefs envDefs defs = do
    checkNoDuplicateDefs defs
    let ordered = orderDefs defs
    foldr
        (\scc inferRest -> do
            def <- inferComponent scc
            Topo rest <- augment envLocalDefs (Map.fromList (defSigs def)) inferRest
            Topo rest <- augment envDefs (Map.fromList (defSigs def)) inferRest
            pure (Topo (def : rest))
        )
        (pure (Topo []))


@@ 157,7 144,8 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scm = generalize (substEnv sub env) (subst sub t)
        pure (idstr lhs, WithPos defPos (scm, body'))
        let body'' = substExpr sub body'
        pure (idstr lhs, WithPos defPos (scm, body''))

    inferRecDefs' ds = do
        ts <- replicateM (length ds) fresh


@@ 167,7 155,8 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
        sub <- lift $ lift $ lift $ solve cs
        env <- view envLocalDefs
        let scms = map (generalize (substEnv sub env) . subst sub) ts
        pure (zip names (zipWith3 (curry . WithPos) poss scms fs))
        let fs' = map (mapPosd (substFunMatch sub)) fs
        pure (zip names (zipWith3 (curry . WithPos) poss scms fs'))

    inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)
    inferRecDef t = uncurry $ \(Id lhs) -> unpos >>> \case


@@ 222,7 211,7 @@ infer (WithPos pos e) = fmap (second (WithPos pos)) $ case e of
        (t, body') <- augment1 envLocalDefs (defSig def') (infer body)
        pure (t, Let (VarDef def') body')
    Parsed.LetRec defs b -> do
        Topo defs' <- inferLocalDefs defs
        Topo defs' <- inferDefs envLocalDefs defs
        let withDef def inferX = do
                (tx, x') <- withLocals (defSigs def) inferX
                pure (tx, WithPos pos (Let def x'))

M src/Subst.hs => src/Subst.hs +1 -1
@@ 1,6 1,6 @@
{-# LANGUAGE LambdaCase #-}

module Subst (Subst, subst, substDef, composeSubsts) where
module Subst (Subst, subst, substExpr, substFunMatch, composeSubsts) where

import qualified Data.Map as Map
import Data.Map (Map)