~jojo/Carth

24655e2783e0112a21fc013d68530e9fcc2d56f6 — JoJo 23 days ago ecde8c3 master
Explicit class constraints in type sigs

    (define: transmute'
        (forall (a b) (where (SameSize a b))
                (Fun a b))
      transmute)
M src/Front/Infer.hs => src/Front/Infer.hs +11 -6
@@ 263,15 263,20 @@ inferRecDefs :: [Parsed.Def] -> Infer RecDefs
    checkScheme :: String -> Maybe Parsed.Scheme -> Infer (Maybe Scheme)
    checkScheme = curry $ \case
        ("main", Nothing) -> pure (Just (Forall Set.empty Set.empty mainType))
        ("main", Just s@(Parsed.Forall pos vs t)) | Set.size vs /= 0 || t /= mainType ->
            throwError (WrongMainType pos s)
        ("main", Just s@(Parsed.Forall pos vs cs t))
            | Set.size vs /= 0 || Set.size cs /= 0 || t /= mainType -> throwError
                (WrongMainType pos s)
        (_, Nothing) -> pure Nothing
        (_, Just (Parsed.Forall pos vs t)) -> do
        (_, Just (Parsed.Forall pos vs cs t)) -> do
            t' <- checkType pos t
            let s1 = Forall vs Set.empty t'
            cs' <- mapM (secondM (mapM (uncurry checkType))) (Set.toList cs)
            let s1 = Forall vs (Set.fromList cs') t'
            env <- view envLocalDefs
            s2 <- generalize env (Just (_scmConstraints s1)) Map.empty t'
            if (s1 == s2)
            s2@(Forall vs2 _ t2) <- generalize env
                                               (Just (_scmConstraints s1))
                                               Map.empty
                                               t'
            if ((vs, t') == (vs2, t2))
                then pure (Just s1)
                else throwError (InvalidUserTypeSig pos s1 s2)


M src/Front/Lex.hs => src/Front/Lex.hs +1 -0
@@ 139,6 139,7 @@ keyword = andSkipSpaceAfter $ choice $ (++)
        , string "define:" $> KdefineColon
        , string "extern" $> Kextern
        , string "forall" $> Kforall
        , string "where" $> Kwhere
        , string "fmatch" $> Kfmatch
        , string "match" $> Kmatch
        , string "if" $> Kif

M src/Front/Lexd.hs => src/Front/Lexd.hs +1 -1
@@ 12,7 12,7 @@ data Const

data Keyword
    = Kcolon | Kdot
    | Kforall | KFun | KBox
    | Kforall | Kwhere | KFun | KBox
    | Kdefine | KdefineColon
    | Kimport | Kextern | Kdata
    | Kfmatch | Kmatch | Kcase

M src/Front/Parse.hs => src/Front/Parse.hs +11 -5
@@ 7,6 7,7 @@ import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Except
import Control.Monad.Combinators
import Data.Bifunctor
import Data.Maybe
import qualified Data.Set as Set
import Data.List


@@ 167,9 168,12 @@ pat = choice [patInt, patStr, patCtor, patVar, patTuple, ppat]
scheme :: Parser Scheme
scheme = do
    pos <- getSrcPos
    let wrap = fmap (Forall pos Set.empty)
        universal = reserved Kforall *> liftA2 (Forall pos) tvars type_
        tvars = parens (fmap Set.fromList (many tvar))
    let wrap = fmap (Forall pos Set.empty Set.empty)
        universal =
            reserved Kforall
                *> liftA3 (Forall pos) tvars (option Set.empty (try constrs)) type_
        tvars = parens (fmap Set.fromList (some tvar))
        constrs = parens (reserved Kwhere *> fmap Set.fromList (some (parens tapp)))
    wrap nonptype <|> (parens (universal <|> wrap ptype))

type_ :: Parser Type


@@ 202,7 206,7 @@ tuple p unit f = brackets $ do
    pure $ foldr f r ls

ptype :: Parser Type
ptype = choice [tfun, tbox, tapp]
ptype = choice [tfun, tbox, fmap (TConst . second (map snd)) tapp]
  where
    tfun = do
        reserved KFun


@@ 210,7 214,9 @@ ptype = choice [tfun, tbox, tapp]
        ts <- some type_
        pure (foldr1 TFun (t : ts))
    tbox = reserved KBox *> fmap TBox type_
    tapp = liftA2 (TConst .* (,) . idstr) big (some type_)

tapp :: Parser (String, [(SrcPos, Type)])
tapp = liftA2 ((,) . idstr) big (some (liftA2 (,) getSrcPos type_))

tvar :: Parser TVar
tvar = fmap TVExplicit small

M src/Front/Parsed.hs => src/Front/Parsed.hs +3 -2
@@ 11,7 11,6 @@ import FreeVars
import Front.TypeAst
import Front.Lexd (Const (..))


data IdCase = Big | Small

newtype Id (case' :: IdCase) = Id (WithPos String)


@@ 31,7 30,9 @@ data Type
    | TBox Type
    deriving (Show, Eq, Ord)

data Scheme = Forall SrcPos (Set TVar) Type
type ClassConstraint = (String, [(SrcPos, Type)])

data Scheme = Forall SrcPos (Set TVar) (Set ClassConstraint) Type
    deriving (Show, Eq)

data Pat

M src/Pretty.hs => src/Pretty.hs +10 -6
@@ 35,6 35,7 @@ instance Pretty Lexd.Keyword where
        Lexd.Kcolon -> ":"
        Lexd.Kdot -> "."
        Lexd.Kforall -> "forall"
        Lexd.Kwhere -> "where"
        Lexd.KFun -> "Fun"
        Lexd.KBox -> "Box"
        Lexd.Kdefine -> "define"


@@ 57,7 58,8 @@ instance Pretty Lexd.Keyword where


instance Pretty Parsed.Scheme where
    pretty' _ (Parsed.Forall _ ps t) = prettyScheme ps t
    pretty' _ (Parsed.Forall _ ps cs t) =
        prettyScheme ps (map (second (map snd)) (Set.toList cs)) t
instance Pretty Parsed.Type where
    pretty' _ = prettyType
instance Pretty Parsed.TPrim where


@@ 67,9 69,12 @@ instance Pretty Parsed.TVar where
instance Pretty (Parsed.Id a) where
    pretty' _ = Parsed.idstr

prettyScheme :: (Pretty p, Pretty t) => Set p -> t -> String
prettyScheme ps t =
    concat ["(forall (" ++ spcPretty (Set.toList ps) ++ ") ", pretty t ++ ")"]
prettyScheme :: (Pretty p, Pretty t) => Set p -> [(String, [t])] -> t -> String
prettyScheme ps cs t = concat
    [ "(forall (" ++ spcPretty (Set.toList ps) ++ ") "
    , "(where " ++ unwords (map prettyTConst cs) ++ ") "
    , pretty t ++ ")"
    ]

prettyType :: Parsed.Type -> String
prettyType = \case


@@ 111,9 116,8 @@ prettyTVar = \case
    Parsed.TVExplicit v -> Parsed.idstr v
    Parsed.TVImplicit v -> "•" ++ v


instance Pretty Inferred.Scheme where
    pretty' _ (Inferred.Forall ps _ t) = prettyScheme ps t
    pretty' _ (Inferred.Forall ps cs t) = prettyScheme ps (Set.toList cs) t
instance Pretty Inferred.Type where
    pretty' _ = prettyAnType


M std/array.carth => std/array.carth +5 -5
@@ 31,11 31,11 @@

(extern memcpy (Fun (Box Nat8) (Box Nat8) Nat (Box Nat8)))

(define (memcpy' dest src count)
  (: (transmute (memcpy (transmute (: dest (Box a)))
                        (transmute (: src (Box a)))
                        (* count (sizeof a))))
     (Box a)))
(define: (memcpy' dest src count)
    (forall (a) (Fun (Box a) (Box a) Nat (Box a)))
  (transmute (memcpy (transmute dest)
                     (transmute src)
                     (* count (sizeof a)))))

(define: (array/append (Array px nx) (Array py ny))
    (forall (a) (Fun (Array a) (Array a) (Array a)))

M std/math.carth => std/math.carth +5 -1
@@ 7,7 7,11 @@
(extern cos (Fun F64 F64))
(extern tan (Fun F64 F64))

(define (inc n) (+ n (cast 1)))
(define: (inc n)
    (forall (a) (where (Num a) (Cast Int a))
            (Fun a a))
  (+ n (cast 1)))

(define (dec n) (- n (cast 1)))

(define (neg x) (- 0 x))

A test/tests/good/function-sig-class.carth => test/tests/good/function-sig-class.carth +12 -0
@@ 0,0 1,12 @@
;; 123456

(import std)

(define: tr (forall (a b) (where (SameSize a b))
                    (Fun a b))
  transmute)

(define main
  (display (show-int ((: tr (Fun Nat Int))
                      ((: tr (Fun Int Nat))
                       123456)))))