@@ 4,7 4,7 @@
module Infer (inferTopDefs, checkType', checkType'') where
import Prelude hiding (span)
-import Lens.Micro.Platform (makeLenses, over, view, mapped, to)
+import Lens.Micro.Platform (makeLenses, over, view, mapped, to, Lens')
import Control.Applicative hiding (Const(..))
import Control.Monad.Except
import Control.Monad.Reader
@@ 65,7 65,9 @@ inferTopDefs tdefs ctors externs defs =
, _envLocalDefs = Map.empty
, _envCtors = ctors
}
- in evalStateT (runReaderT (fmap fst (runWriterT (inferGlobalDefs defs))) initEnv) 0
+ in evalStateT
+ (runReaderT (fmap fst (runWriterT (inferDefs envGlobDefs defs))) initEnv)
+ 0
checkType :: SrcPos -> Parsed.Type -> Infer Type
checkType pos t = view envTypeDefs >>= \tds -> checkType' tds pos t
@@ 93,29 95,14 @@ checkType'' tdefsParams pos = go
else throwError (TypeInstArityMismatch pos x expectedN foundN)
Nothing -> throwError (UndefType pos x)
-inferGlobalDefs :: [Parsed.Def] -> Infer Defs
-inferGlobalDefs defs = do
- checkNoDuplicateDefs defs
- let ordered = orderDefs defs
- foldr
- (\scc inferRest -> do
- (def, constraints) <- censor (const []) $ listen $ inferComponent scc
- sub <- lift $ lift $ lift $ solve constraints
- let def' = substDef sub def
- Topo rest <- augment envGlobDefs (Map.fromList (defSigs def)) inferRest
- pure (Topo (def' : rest))
- )
- (pure (Topo []))
- ordered
-
-inferLocalDefs :: [Parsed.Def] -> Infer Defs
-inferLocalDefs defs = do
+inferDefs :: Lens' Env (Map String Scheme) -> [Parsed.Def] -> Infer Defs
+inferDefs envDefs defs = do
checkNoDuplicateDefs defs
let ordered = orderDefs defs
foldr
(\scc inferRest -> do
def <- inferComponent scc
- Topo rest <- augment envLocalDefs (Map.fromList (defSigs def)) inferRest
+ Topo rest <- augment envDefs (Map.fromList (defSigs def)) inferRest
pure (Topo (def : rest))
)
(pure (Topo []))
@@ 157,7 144,8 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
sub <- lift $ lift $ lift $ solve cs
env <- view envLocalDefs
let scm = generalize (substEnv sub env) (subst sub t)
- pure (idstr lhs, WithPos defPos (scm, body'))
+ let body'' = substExpr sub body'
+ pure (idstr lhs, WithPos defPos (scm, body''))
inferRecDefs' ds = do
ts <- replicateM (length ds) fresh
@@ 167,7 155,8 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
sub <- lift $ lift $ lift $ solve cs
env <- view envLocalDefs
let scms = map (generalize (substEnv sub env) . subst sub) ts
- pure (zip names (zipWith3 (curry . WithPos) poss scms fs))
+ let fs' = map (mapPosd (substFunMatch sub)) fs
+ pure (zip names (zipWith3 (curry . WithPos) poss scms fs'))
inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)
inferRecDef t = uncurry $ \(Id lhs) -> unpos >>> \case
@@ 222,7 211,7 @@ infer (WithPos pos e) = fmap (second (WithPos pos)) $ case e of
(t, body') <- augment1 envLocalDefs (defSig def') (infer body)
pure (t, Let (VarDef def') body')
Parsed.LetRec defs b -> do
- Topo defs' <- inferLocalDefs defs
+ Topo defs' <- inferDefs envLocalDefs defs
let withDef def inferX = do
(tx, x') <- withLocals (defSigs def) inferX
pure (tx, WithPos pos (Let def x'))
@@ 1,6 1,6 @@
{-# LANGUAGE LambdaCase #-}
-module Subst (Subst, subst, substDef, composeSubsts) where
+module Subst (Subst, subst, substExpr, substFunMatch, composeSubsts) where
import qualified Data.Map as Map
import Data.Map (Map)