@@ 200,42 200,52 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
params
inferNonrecDef' (Parsed.VarDef dpos lhs mayscm body) = do
t <- fresh
- (body', cs) <- listen $ inferDef t lhs mayscm (getPos body) (infer body)
+ mayscm' <- checkScheme (idstr lhs) mayscm
+ (body', cs) <- listen $ inferDef t mayscm' (getPos body) (infer body)
(sub, ccs) <- solve cs
env <- view envLocalDefs
- let scm = generalize (substEnv sub env) ccs (subst sub t)
+ scm <- generalize (substEnv sub env)
+ (fmap _scmConstraints mayscm')
+ ccs
+ (subst sub t)
let body'' = substExpr sub body'
pure (idstr lhs, WithPos dpos (scm, body''))
inferRecDefs' ds = do
- ts <- replicateM (length ds) fresh
- let (names, poss) = unzip $ flip map ds $ \case
- Parsed.FunDef p x _ _ _ -> (idstr x, p)
- Parsed.VarDef p x _ _ -> (idstr x, p)
- let dummyDefs = Map.fromList (zip names (map (Forall Set.empty Set.empty) ts))
- (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ zipWithM inferRecDef ts ds
+ (names, poss, mayscms', ts) <- fmap unzip4 $ forM ds $ \d -> do
+ let (name, pos, mayscm) = case d of
+ Parsed.FunDef p x s _ _ -> (idstr x, p, s)
+ Parsed.VarDef p x s _ -> (idstr x, p, s)
+ t <- fresh
+ mayscm' <- checkScheme name mayscm
+ pure (name, pos, mayscm', t)
+ let dummyDefs = Map.fromList $ zip names (map (Forall Set.empty Set.empty) ts)
+ (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ mapM
+ (uncurry3 inferRecDef)
+ (zip3 mayscms' ts ds)
(sub, cs) <- solve ucs
env <- view envLocalDefs
- let scms = map (generalize (substEnv sub env) cs . subst sub) ts
+ scms <- zipWithM
+ (\s -> generalize (substEnv sub env) (fmap _scmConstraints s) cs . subst sub)
+ mayscms'
+ ts
let fs' = map (mapPosd (substFunMatch sub)) fs
pure (zip names (zipWith3 (curry . WithPos) poss scms fs'))
- inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)
- inferRecDef t = \case
- Parsed.FunDef fpos lhs mayscm params body ->
+ inferRecDef :: Maybe Scheme -> Type -> Parsed.Def -> Infer (WithPos FunMatch)
+ inferRecDef mayscm t = \case
+ Parsed.FunDef fpos _ _ params body ->
let (initps, lastp) = fromJust $ unsnoc params
- in fmap (WithPos fpos) $ inferDef t lhs mayscm fpos $ inferFunMatch $ foldr
+ in fmap (WithPos fpos) $ inferDef t mayscm fpos $ inferFunMatch $ foldr
(\p cs -> [(p, WithPos fpos (Parsed.FunMatch cs))])
[(lastp, body)]
initps
- Parsed.VarDef fpos lhs mayscm (WithPos _ (Parsed.FunMatch cs)) ->
- fmap (WithPos fpos) $ inferDef t lhs mayscm fpos (inferFunMatch cs)
+ Parsed.VarDef fpos _ _ (WithPos _ (Parsed.FunMatch cs)) ->
+ fmap (WithPos fpos) $ inferDef t mayscm fpos (inferFunMatch cs)
Parsed.VarDef _ (Id lhs) _ _ -> throwError (RecursiveVarDef lhs)
- inferDef t lhs mayscm bodyPos inferBody = do
- checkScheme (idstr lhs) mayscm >>= \case
- Just (Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
- Nothing -> pure ()
+ inferDef t mayscm bodyPos inferBody = do
+ whenJust mayscm $ \(Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
(t', body') <- inferBody
unify (Expected t) (Found bodyPos t')
pure body'
@@ 254,7 264,7 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
t' <- checkType pos t
let s1 = Forall vs Set.empty t'
env <- view envLocalDefs
- let s2 = generalize env Set.empty t'
+ s2 <- generalize env (Just (_scmConstraints s1)) Map.empty t'
if (s1 == s2)
then pure (Just s1)
else throwError (InvalidUserTypeSig pos s1 s2)
@@ 432,8 442,14 @@ instantiate pos (Forall params constraints t) = do
forM_ constraints $ \c -> unifyClass pos (substClassConstraint s c)
pure (subst s t)
-generalize :: Map String Scheme -> Set ClassConstraint -> Type -> Scheme
-generalize env allConstraints t = Forall vs constraints t
+generalize
+ :: (MonadError TypeErr m)
+ => Map String Scheme
+ -> Maybe (Set ClassConstraint)
+ -> Map ClassConstraint SrcPos
+ -> Type
+ -> m Scheme
+generalize env mayGivenCs allCs t = fmap (\cs -> Forall vs cs t) constraints
where
-- A constraint should be included in a signature if the type variables include at
-- least one of the signature's forall-qualified tvars, and the rest of the tvars
@@ 443,11 459,20 @@ generalize env allConstraints t = Forall vs constraints t
--
-- TODO: Maybe we should handle the propagation of class constraints in a better way,
-- so that ones belonging to inner definitions no longer exist at this point.
- constraints :: Set ClassConstraint
- constraints = flip Set.filter allConstraints $ \c ->
- let vcs = ftvClassConstraint c
- in any (flip Set.member vs) vcs
- && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
+ constraints =
+ fmap (Set.fromList . map fst) $ flip filterM (Map.toList allCs) $ \(c, pos) ->
+ let vcs = ftvClassConstraint c
+ belongs =
+ any (flip Set.member vs) vcs
+ && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
+ in if belongs
+ then if matchesGiven c
+ then pure True
+ else throwError (NoClassInstance pos c)
+ else pure False
+ matchesGiven = case mayGivenCs of
+ Just gcs -> flip Set.member gcs
+ Nothing -> const True
vs = Set.difference (ftv t) ftvEnv
ftvEnv = Set.unions (map ftvScheme (Map.elems env))
ftvScheme (Forall tvs _ t) = Set.difference (ftv t) tvs
@@ 479,7 504,7 @@ unifyClass p c = tell ([], [(p, c)])
data UnifyErr = UInfType TVar Type | UFailed Type Type
-solve :: Constraints -> Infer (Subst, Set ClassConstraint)
+solve :: Constraints -> Infer (Subst, (Map ClassConstraint SrcPos))
solve (eqcs, ccs) = do
sub <- lift $ lift $ lift $ solveUnis Map.empty eqcs
ccs' <- solveClassCs (map (second (substClassConstraint sub)) ccs)
@@ 492,24 517,24 @@ solve (eqcs, ccs) = do
sub2 <- withExcept (toTypeErr pos et ft) (unifies et ft)
solveUnis (composeSubsts sub2 sub1) (map (substConstraint sub2) cs)
- solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Set ClassConstraint)
- solveClassCs = fmap Set.unions . mapM
+ solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Map ClassConstraint SrcPos)
+ solveClassCs = fmap Map.unions . mapM
(\(pos, c) -> case c of
("SameSize", [ta, tb]) -> sameSize pos ta tb
("SameSize", ts) -> ice $ "solveClassCs: invalid SameSize " ++ show ts
_ -> ice $ "solveClassCs: unknown class constraint " ++ show c
)
- sameSize :: SrcPos -> Type -> Type -> Infer (Set ClassConstraint)
+ sameSize :: SrcPos -> Type -> Type -> Infer (Map ClassConstraint SrcPos)
sameSize pos ta tb = do
sizeof' <- fmap sizeof (view envTypeDefs)
let c = ("SameSize", [ta, tb])
case liftA2 (==) (sizeof' ta) (sizeof' tb) of
- Just True -> pure Set.empty
+ Just True -> pure Map.empty
Just False -> throwError (NoClassInstance pos c)
-- One or both of the two types are of unknown size due to polymorphism, so
-- propagate the constraint to the scheme of the definition.
- Nothing -> pure (Set.singleton c)
+ Nothing -> pure (Map.singleton c pos)
substConstraint sub (Expected t1, Found pos t2) =
(Expected (subst sub t1), Found pos (subst sub t2))
@@ 111,3 111,9 @@ maximumOr :: Ord a => a -> [a] -> a
maximumOr b = \case
[] -> b
as -> maximum as
+
+whenJust :: Monad m => Maybe a -> (a -> m ()) -> m ()
+whenJust = flip (maybe (pure ()))
+
+uncurry3 :: (a -> b -> c -> d) -> ((a, b, c) -> d)
+uncurry3 f (a, b, c) = f a b c