~jojo/Carth

c5a394b01aa60cb95605d66f6118cfb231955aa7 — JoJo 2 years ago ce0a08a
Fix class constraint inference bug when explicit scheme given

See test/tests/bad/transmute-size-mismatch2.carth. This didn't fail
before. Users type sig was not respected, and (SameSize a b) was
inferred and added to the scheme implicitly.
3 files changed, 65 insertions(+), 38 deletions(-)

M src/Infer.hs
M src/Misc.hs
M test/tests/bad/transmute-size-mismatch2.carth
M src/Infer.hs => src/Infer.hs +58 -33
@@ 200,42 200,52 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
            params
    inferNonrecDef' (Parsed.VarDef dpos lhs mayscm body) = do
        t <- fresh
        (body', cs) <- listen $ inferDef t lhs mayscm (getPos body) (infer body)
        mayscm' <- checkScheme (idstr lhs) mayscm
        (body', cs) <- listen $ inferDef t mayscm' (getPos body) (infer body)
        (sub, ccs) <- solve cs
        env <- view envLocalDefs
        let scm = generalize (substEnv sub env) ccs (subst sub t)
        scm <- generalize (substEnv sub env)
                          (fmap _scmConstraints mayscm')
                          ccs
                          (subst sub t)
        let body'' = substExpr sub body'
        pure (idstr lhs, WithPos dpos (scm, body''))

    inferRecDefs' ds = do
        ts <- replicateM (length ds) fresh
        let (names, poss) = unzip $ flip map ds $ \case
                Parsed.FunDef p x _ _ _ -> (idstr x, p)
                Parsed.VarDef p x _ _ -> (idstr x, p)
        let dummyDefs = Map.fromList (zip names (map (Forall Set.empty Set.empty) ts))
        (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ zipWithM inferRecDef ts ds
        (names, poss, mayscms', ts) <- fmap unzip4 $ forM ds $ \d -> do
            let (name, pos, mayscm) = case d of
                    Parsed.FunDef p x s _ _ -> (idstr x, p, s)
                    Parsed.VarDef p x s _ -> (idstr x, p, s)
            t <- fresh
            mayscm' <- checkScheme name mayscm
            pure (name, pos, mayscm', t)
        let dummyDefs = Map.fromList $ zip names (map (Forall Set.empty Set.empty) ts)
        (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ mapM
            (uncurry3 inferRecDef)
            (zip3 mayscms' ts ds)
        (sub, cs) <- solve ucs
        env <- view envLocalDefs
        let scms = map (generalize (substEnv sub env) cs . subst sub) ts
        scms <- zipWithM
            (\s -> generalize (substEnv sub env) (fmap _scmConstraints s) cs . subst sub)
            mayscms'
            ts
        let fs' = map (mapPosd (substFunMatch sub)) fs
        pure (zip names (zipWith3 (curry . WithPos) poss scms fs'))

    inferRecDef :: Type -> Parsed.Def -> Infer (WithPos FunMatch)
    inferRecDef t = \case
        Parsed.FunDef fpos lhs mayscm params body ->
    inferRecDef :: Maybe Scheme -> Type -> Parsed.Def -> Infer (WithPos FunMatch)
    inferRecDef mayscm t = \case
        Parsed.FunDef fpos _ _ params body ->
            let (initps, lastp) = fromJust $ unsnoc params
            in  fmap (WithPos fpos) $ inferDef t lhs mayscm fpos $ inferFunMatch $ foldr
            in  fmap (WithPos fpos) $ inferDef t mayscm fpos $ inferFunMatch $ foldr
                    (\p cs -> [(p, WithPos fpos (Parsed.FunMatch cs))])
                    [(lastp, body)]
                    initps
        Parsed.VarDef fpos lhs mayscm (WithPos _ (Parsed.FunMatch cs)) ->
            fmap (WithPos fpos) $ inferDef t lhs mayscm fpos (inferFunMatch cs)
        Parsed.VarDef fpos _ _ (WithPos _ (Parsed.FunMatch cs)) ->
            fmap (WithPos fpos) $ inferDef t mayscm fpos (inferFunMatch cs)
        Parsed.VarDef _ (Id lhs) _ _ -> throwError (RecursiveVarDef lhs)

    inferDef t lhs mayscm bodyPos inferBody = do
        checkScheme (idstr lhs) mayscm >>= \case
            Just (Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
            Nothing -> pure ()
    inferDef t mayscm bodyPos inferBody = do
        whenJust mayscm $ \(Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
        (t', body') <- inferBody
        unify (Expected t) (Found bodyPos t')
        pure body'


@@ 254,7 264,7 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
            t' <- checkType pos t
            let s1 = Forall vs Set.empty t'
            env <- view envLocalDefs
            let s2 = generalize env Set.empty t'
            s2 <- generalize env (Just (_scmConstraints s1)) Map.empty t'
            if (s1 == s2)
                then pure (Just s1)
                else throwError (InvalidUserTypeSig pos s1 s2)


@@ 432,8 442,14 @@ instantiate pos (Forall params constraints t) = do
    forM_ constraints $ \c -> unifyClass pos (substClassConstraint s c)
    pure (subst s t)

generalize :: Map String Scheme -> Set ClassConstraint -> Type -> Scheme
generalize env allConstraints t = Forall vs constraints t
generalize
    :: (MonadError TypeErr m)
    => Map String Scheme
    -> Maybe (Set ClassConstraint)
    -> Map ClassConstraint SrcPos
    -> Type
    -> m Scheme
generalize env mayGivenCs allCs t = fmap (\cs -> Forall vs cs t) constraints
  where
    -- A constraint should be included in a signature if the type variables include at
    -- least one of the signature's forall-qualified tvars, and the rest of the tvars


@@ 443,11 459,20 @@ generalize env allConstraints t = Forall vs constraints t
    --
    -- TODO: Maybe we should handle the propagation of class constraints in a better way,
    --       so that ones belonging to inner definitions no longer exist at this point.
    constraints :: Set ClassConstraint
    constraints = flip Set.filter allConstraints $ \c ->
        let vcs = ftvClassConstraint c
        in  any (flip Set.member vs) vcs
                && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
    constraints =
        fmap (Set.fromList . map fst) $ flip filterM (Map.toList allCs) $ \(c, pos) ->
            let vcs = ftvClassConstraint c
                belongs =
                    any (flip Set.member vs) vcs
                        && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
            in  if belongs
                    then if matchesGiven c
                        then pure True
                        else throwError (NoClassInstance pos c)
                    else pure False
    matchesGiven = case mayGivenCs of
        Just gcs -> flip Set.member gcs
        Nothing -> const True
    vs = Set.difference (ftv t) ftvEnv
    ftvEnv = Set.unions (map ftvScheme (Map.elems env))
    ftvScheme (Forall tvs _ t) = Set.difference (ftv t) tvs


@@ 479,7 504,7 @@ unifyClass p c = tell ([], [(p, c)])

data UnifyErr = UInfType TVar Type | UFailed Type Type

solve :: Constraints -> Infer (Subst, Set ClassConstraint)
solve :: Constraints -> Infer (Subst, (Map ClassConstraint SrcPos))
solve (eqcs, ccs) = do
    sub <- lift $ lift $ lift $ solveUnis Map.empty eqcs
    ccs' <- solveClassCs (map (second (substClassConstraint sub)) ccs)


@@ 492,24 517,24 @@ solve (eqcs, ccs) = do
            sub2 <- withExcept (toTypeErr pos et ft) (unifies et ft)
            solveUnis (composeSubsts sub2 sub1) (map (substConstraint sub2) cs)

    solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Set ClassConstraint)
    solveClassCs = fmap Set.unions . mapM
    solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Map ClassConstraint SrcPos)
    solveClassCs = fmap Map.unions . mapM
        (\(pos, c) -> case c of
            ("SameSize", [ta, tb]) -> sameSize pos ta tb
            ("SameSize", ts) -> ice $ "solveClassCs: invalid SameSize " ++ show ts
            _ -> ice $ "solveClassCs: unknown class constraint " ++ show c
        )

    sameSize :: SrcPos -> Type -> Type -> Infer (Set ClassConstraint)
    sameSize :: SrcPos -> Type -> Type -> Infer (Map ClassConstraint SrcPos)
    sameSize pos ta tb = do
        sizeof' <- fmap sizeof (view envTypeDefs)
        let c = ("SameSize", [ta, tb])
        case liftA2 (==) (sizeof' ta) (sizeof' tb) of
            Just True -> pure Set.empty
            Just True -> pure Map.empty
            Just False -> throwError (NoClassInstance pos c)
            -- One or both of the two types are of unknown size due to polymorphism, so
            -- propagate the constraint to the scheme of the definition.
            Nothing -> pure (Set.singleton c)
            Nothing -> pure (Map.singleton c pos)

    substConstraint sub (Expected t1, Found pos t2) =
        (Expected (subst sub t1), Found pos (subst sub t2))

M src/Misc.hs => src/Misc.hs +6 -0
@@ 111,3 111,9 @@ maximumOr :: Ord a => a -> [a] -> a
maximumOr b = \case
    [] -> b
    as -> maximum as

whenJust :: Monad m => Maybe a -> (a -> m ()) -> m ()
whenJust = flip (maybe (pure ()))

uncurry3 :: (a -> b -> c -> d) -> ((a, b, c) -> d)
uncurry3 f (a, b, c) = f a b c

M test/tests/bad/transmute-size-mismatch2.carth => test/tests/bad/transmute-size-mismatch2.carth +1 -5
@@ 1,9 1,5 @@
;; foo should fail here, not bar
;; NoClassInstance

(define: (foo a)
    (forall (a b) (Fun a b))
  (transmute a))

(define: (bar Unit)
    (Fun Unit Int8)
  (foo (: 123 Int)))