~jojo/Carth

55fb4f948f1f3797078b584dc60b4f7dd68b37ed — JoJo a month ago ebd61e1
Check `cast` in Infer instead of Gen

Using the new virtual typeclass `Cast a b`, which as a predicate is
true then `a` can be cast to `b`, which is the case from all primitive
numeric types to eachother, and from one type to itself.
M src/Gen.hs => src/Gen.hs +56 -54
@@ 466,9 466,7 @@ genAppBuiltinVirtual (TypedVar g t) as = do
            Ast.TFun a b -> (a, genType b, \x -> genTransmute x a b)
            _ -> ice "genAppBuiltinVirtual: t not TFun for transmute"
        "cast" -> wrap1 $ case t of
            Ast.TFun a b -> case pos of
                Just p -> (a, genType b, \x -> genCast p x a b)
                Nothing -> ice "genAppBuiltinVirtual: cast: srcPos is Nothing"
            Ast.TFun a b -> (a, genType b, \x -> genCast x a b)
            _ -> ice "genAppBuiltinVirtual: t not TFun for cast"
        "deref" -> wrap1 $ case t of
            Ast.TFun a b -> (a, genType b, genDeref)


@@ 484,10 482,59 @@ genAppBuiltinVirtual (TypedVar g t) as = do
        a' <- genType a
        b' <- genType b
        transmute a' b' x
    genCast :: SrcPos -> Val -> Ast.Type -> Ast.Type -> Gen Val
    genCast pos x a b = do

        -- | Assumes that the from-type and to-type are of the same size.
    transmute :: Type -> Type -> Val -> Gen Val
    transmute t u x = case (t, u) of
        (FunctionType _ _ _, _) -> transmuteIce
        (_, FunctionType _ _ _) -> transmuteIce
        (MetadataType, _) -> transmuteIce
        (_, MetadataType) -> transmuteIce
        (LabelType, _) -> transmuteIce
        (_, LabelType) -> transmuteIce
        (TokenType, _) -> transmuteIce
        (_, TokenType) -> transmuteIce
        (VoidType, _) -> transmuteIce
        (_, VoidType) -> transmuteIce

        (IntegerType _, IntegerType _) -> bitcast'
        (IntegerType _, PointerType _ _) ->
            getLocal x >>= \x' -> emitAnonReg (inttoptr x' u) <&> VLocal
        (IntegerType _, FloatingPointType _) -> bitcast'
        (IntegerType _, VectorType _ _) -> bitcast'

        (PointerType pt _, PointerType pu _) | pt == pu -> pure x
                                             | otherwise -> bitcast'
        (PointerType _ _, IntegerType _) ->
            getLocal x >>= \x' -> emitAnonReg (ptrtoint x' u) <&> VLocal
        (PointerType _ _, _) -> stackCast
        (_, PointerType _ _) -> stackCast

        (FloatingPointType _, FloatingPointType _) -> pure x
        (FloatingPointType _, IntegerType _) -> bitcast'
        (FloatingPointType _, VectorType _ _) -> bitcast'

        (VectorType _ vt, VectorType _ vu) | vt == vu -> pure x
                                           | otherwise -> bitcast'
        (VectorType _ _, IntegerType _) -> bitcast'
        (VectorType _ _, FloatingPointType _) -> bitcast'

        (StructureType _ _, _) -> stackCast
        (_, StructureType _ _) -> stackCast
        (ArrayType _ _, _) -> stackCast
        (_, ArrayType _ _) -> stackCast
        (NamedTypeReference _, _) -> stackCast
        (_, NamedTypeReference _) -> stackCast
      where
        transmuteIce = ice $ "transmute " ++ show t ++ " to " ++ show u
        bitcast' = getLocal x >>= \x' -> emitAnonReg (bitcast x' u) <&> VLocal
        stackCast = getVar x >>= \x' -> emitAnonReg (bitcast x' (LLType.ptr u)) <&> VVar

    genCast :: Val -> Ast.Type -> Ast.Type -> Gen Val
    genCast x a b = do
        a' <- genType a
        b' <- genType b
        let err = ice $ "genCast: " ++ show a' ++ " to " ++ show b'
        let emit' instr = getLocal x >>= \x' -> emitAnonReg (instr x' b') <&> VLocal
        case (a', b') of
            _ | a' == b' -> pure x


@@ 500,18 547,20 @@ genAppBuiltinVirtual (TypedVar g t) as = do
                (_, FloatFP) -> emit' fptrunc
                (DoubleFP, _) -> emit' fpext
                (_, DoubleFP) -> emit' fptrunc
                _ -> throwError (CastErr pos a b)
                _ -> err
            (IntegerType _, FloatingPointType _) ->
                emit' $ if isInt a then sitofp else uitofp
            (FloatingPointType _, IntegerType _) ->
                emit' $ if isInt b then fptosi else fptoui
            _ -> throwError (CastErr pos a b)
            _ -> err

    genStore :: Val -> Val -> Gen Val
    genStore x p = do
        x' <- getLocal x
        p' <- getLocal p
        emitDo (store x' p')
        pure p

    isNat = \case
        TNat _ -> True
        TNatSize -> True


@@ 557,53 606,6 @@ genDeref = \case
    VVar x -> fmap VVar (emitAnonReg (load x))
    VLocal x -> pure (VVar x)

-- | Assumes that the from-type and to-type are of the same size.
transmute :: Type -> Type -> Val -> Gen Val
transmute t u x = case (t, u) of
    (FunctionType _ _ _, _) -> transmuteIce
    (_, FunctionType _ _ _) -> transmuteIce
    (MetadataType, _) -> transmuteIce
    (_, MetadataType) -> transmuteIce
    (LabelType, _) -> transmuteIce
    (_, LabelType) -> transmuteIce
    (TokenType, _) -> transmuteIce
    (_, TokenType) -> transmuteIce
    (VoidType, _) -> transmuteIce
    (_, VoidType) -> transmuteIce

    (IntegerType _, IntegerType _) -> bitcast'
    (IntegerType _, PointerType _ _) ->
        getLocal x >>= \x' -> emitAnonReg (inttoptr x' u) <&> VLocal
    (IntegerType _, FloatingPointType _) -> bitcast'
    (IntegerType _, VectorType _ _) -> bitcast'

    (PointerType pt _, PointerType pu _) | pt == pu -> pure x
                                         | otherwise -> bitcast'
    (PointerType _ _, IntegerType _) ->
        getLocal x >>= \x' -> emitAnonReg (ptrtoint x' u) <&> VLocal
    (PointerType _ _, _) -> stackCast
    (_, PointerType _ _) -> stackCast

    (FloatingPointType _, FloatingPointType _) -> pure x
    (FloatingPointType _, IntegerType _) -> bitcast'
    (FloatingPointType _, VectorType _ _) -> bitcast'

    (VectorType _ vt, VectorType _ vu) | vt == vu -> pure x
                                       | otherwise -> bitcast'
    (VectorType _ _, IntegerType _) -> bitcast'
    (VectorType _ _, FloatingPointType _) -> bitcast'

    (StructureType _ _, _) -> stackCast
    (_, StructureType _ _) -> stackCast
    (ArrayType _ _, _) -> stackCast
    (_, ArrayType _ _) -> stackCast
    (NamedTypeReference _, _) -> stackCast
    (_, NamedTypeReference _) -> stackCast
  where
    transmuteIce = ice $ "transmute " ++ show t ++ " to " ++ show u
    bitcast' = getLocal x >>= \x' -> emitAnonReg (bitcast x' u) <&> VLocal
    stackCast = getVar x >>= \x' -> emitAnonReg (bitcast x' (LLType.ptr u)) <&> VVar

callBuiltin :: String -> [Operand] -> Gen FunInstr
callBuiltin f as = do
    (_, rt) <- view (builtins . to (Map.lookup f)) <&> \case

M src/Infer.hs => src/Infer.hs +52 -18
@@ 120,7 120,11 @@ inferTopDefs tdefs ctors externs defs =
                             Set.empty
                             (TFun ta (TFun (TBox ta) (TBox ta)))
                    )
                  , ("cast", Forall (Set.fromList [tva, tvb]) Set.empty (TFun ta tb))
                  , ( "cast"
                    , Forall (Set.fromList [tva, tvb])
                             (Set.singleton ("Cast", [ta, tb]))
                             (TFun ta tb)
                    )
                  ]

checkType :: SrcPos -> Parsed.Type -> Infer Type


@@ 518,23 522,53 @@ solve (eqcs, ccs) = do
            solveUnis (composeSubsts sub2 sub1) (map (substConstraint sub2) cs)

    solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Map ClassConstraint SrcPos)
    solveClassCs = fmap Map.unions . mapM
        (\(pos, c) -> case c of
            ("SameSize", [ta, tb]) -> sameSize pos ta tb
            ("SameSize", ts) -> ice $ "solveClassCs: invalid SameSize " ++ show ts
            _ -> ice $ "solveClassCs: unknown class constraint " ++ show c
        )

    sameSize :: SrcPos -> Type -> Type -> Infer (Map ClassConstraint SrcPos)
    sameSize pos ta tb = do
        sizeof' <- fmap sizeof (view envTypeDefs)
        let c = ("SameSize", [ta, tb])
        case liftA2 (==) (sizeof' ta) (sizeof' tb) of
            Just True -> pure Map.empty
            Just False -> throwError (NoClassInstance pos c)
            -- One or both of the two types are of unknown size due to polymorphism, so
            -- propagate the constraint to the scheme of the definition.
            Nothing -> pure (Map.singleton c pos)
    solveClassCs = fmap Map.unions . mapM solveClassConstraint

    solveClassConstraint
        :: (SrcPos, ClassConstraint) -> Infer (Map ClassConstraint SrcPos)
    solveClassConstraint (pos, c) = case c of
            -- Virtual classes
        ("SameSize", [ta, tb]) -> sameSize (ta, tb)
        ("Cast", [ta, tb]) -> cast (ta, tb)
        -- "Real classes"
        -- ... TODO
        _ -> ice $ "solveClassCs: invalid class constraint " ++ show c
      where
        ok = pure Map.empty
        propagate = pure (Map.singleton c pos)
        err = throwError (NoClassInstance pos c)

        -- TODO: Maybe we should move the check against user-provided explicit signature from
        --       `generalize` to here. Like, we could keep the explicit scheme (if there is
        --       one) in the `Env`.
        --
        -- | As the name indicates, a predicate that is true / class that is instanced when
        --   two types are of the same size. If the size for either cannot be determined yet
        --   due to polymorphism, the constraint is propagated.
        sameSize :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
        sameSize (ta, tb) = do
            sizeof' <- fmap sizeof (view envTypeDefs)
            case liftA2 (==) (sizeof' ta) (sizeof' tb) of
                _ | ta == tb -> ok
                Just True -> ok
                Just False -> err
                -- One or both of the two types are of unknown size due to polymorphism, so
                -- propagate the constraint to the scheme of the definition.
                Nothing -> propagate

        -- | This class is instanced when the first type can be `cast` to the other.
        cast :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
        cast = \case
            (ta, tb) | ta == tb -> ok
            (TPrim _, TPrim _) -> ok
            (TVar _, _) -> propagate
            (_, TVar _) -> propagate
            (TConst _, _) -> err
            (_, TConst _) -> err
            (TFun _ _, _) -> err
            (_, TFun _ _) -> err
            (TBox _, _) -> err
            (_, TBox _) -> err

    substConstraint sub (Expected t1, Found pos t2) =
        (Expected (subst sub t1), Found pos (subst sub t2))

A test/tests/bad/cast-mono.carth => test/tests/bad/cast-mono.carth +5 -0
@@ 0,0 1,5 @@
;; NoClassInstance

(define: (foo a)
    (Fun Int (Box Nat8))
  (cast a))

A test/tests/bad/cast-poly.carth => test/tests/bad/cast-poly.carth +5 -0
@@ 0,0 1,5 @@
;; NoClassInstance

(define: (foo a)
    (forall (a b) (Fun a b))
  (cast a))

A test/tests/bad/cast-poly2.carth => test/tests/bad/cast-poly2.carth +10 -0
@@ 0,0 1,10 @@
;; NoClassInstance

(import std)

(define (foo a)
  (cast a))

(define: (bar Unit)
    (Fun Unit Int)
  (foo (Some 5)))

A test/tests/good/cast-poly.carth => test/tests/good/cast-poly.carth +9 -0
@@ 0,0 1,9 @@
;; 420

(import std)

(define (to-int a) (: (cast a) Int))

(define main
  (display (show-int (+ (to-int (: 400.20     F64))
                        (to-int (: (cast 20) Nat))))))

A test/tests/good/cast.carth => test/tests/good/cast.carth +6 -0
@@ 0,0 1,6 @@
;; 1337

(import std)

(define main
  (display (show-int (cast (: (* 13.37 100.0) F64)))))

A test/tests/good/transmute-mono.carth => test/tests/good/transmute-mono.carth +6 -0
@@ 0,0 1,6 @@
;; -1337

(import std)

(define main
  (display (show-int ((: transmute (Fun Nat Int)) (cast 18446744073709550279)))))

A test/tests/good/transmute-poly.carth => test/tests/good/transmute-poly.carth +10 -0
@@ 0,0 1,10 @@
;; 123456

(import std)

(define tr transmute)

(define main
  (display (show-int ((: tr (Fun Nat Int))
                      ((: tr (Fun Int Nat))
                       123456)))))