~jojo/Carth

wip: on-demand parsing, common trees that grow, & module system

I'll include the contents of "0 About this branch.txt" here:

----------------

I just had too much stuff going on at the same time.
It was shaping up to become a way too big and messy commit.
I'm gonna attack some of these things one at a time.

- On-demand parsing (and maybe on-demand typechecking as well?)
  As a first step on the road to query-based compilation,
  I thought it proper to start with converting the lexer and parser.
  This had consequences.
  For an on-demand parser, we need a better module system.
  We can't have everything in the same Module / Program.
- Module system / namespacing
  This would not be a huge diff in itself, but we need qualified names.
- Qualified names
  This has a dominoe effect.
  All parts of the program assume that everything lives in the same global namespace,
  and just uses local, relative, unqualified identifiers.
  Mainly to consider, adding qualified identifiers (and module paths in imports) affects the type checker.
- Common Trees That Grow, Concrete and Abstract
  As part of making the qualified names change, I needed two new modules: Resolve{d}.
  But a resolved parse-tree would basically be identical to Parsed, except the names.
  So we should have a common Tree That Grows for both of them: Concrete.
  And we're gonna have to update both Inferred and Checked with the new qualified names.
  Why not merge these two into a single Abstract as well while we're at it?

And so you see how it all grew pretty wild.
Now, I'm thinking of putting all this mess aside for a little bit, and try to make the changes more progressively.
In a first (series of) commits, we'll focus on: namespacing & qualified names, which depends on adding Resolve{d} as well.
We really might as well split Parsed into Concrete while we're at it.
I deem that valuable enough to warrant the increase in size of the changes.

Then we'll look back at the stuff in this branch and decide on what to
extract and finish next!

----------------
A 0 About this branch.txt => 0 About this branch.txt +33 -0
@@ 0,0 1,33 @@
I just had too much stuff going on at the same time.
It was shaping up to become a way too big and messy commit.
I'm gonna attack some of these things one at a time.

- On-demand parsing (and maybe on-demand typechecking as well?)
  As a first step on the road to query-based compilation,
  I thought it proper to start with converting the lexer and parser.
  This had consequences.
  For an on-demand parser, we need a better module system.
  We can't have everything in the same Module / Program.
- Module system / namespacing
  This would not be a huge diff in itself, but we need qualified names.
- Qualified names
  This has a dominoe effect.
  All parts of the program assume that everything lives in the same global namespace,
  and just uses local, relative, unqualified identifiers.
  Mainly to consider, adding qualified identifiers (and module paths in imports) affects the type checker.
- Common Trees That Grow, Concrete and Abstract
  As part of making the qualified names change, I needed two new modules: Resolve{d}.
  But a resolved parse-tree would basically be identical to Parsed, except the names.
  So we should have a common Tree That Grows for both of them: Concrete.
  And we're gonna have to update both Inferred and Checked with the new qualified names.
  Why not merge these two into a single Abstract as well while we're at it?

And so you see how it all grew pretty wild.
Now, I'm thinking of putting all this mess aside for a little bit, and try to make the changes more progressively.
In a first (series of) commits, we'll focus on: namespacing & qualified names, which depends on adding Resolve{d} as well.
We really might as well split Parsed into Concrete while we're at it.
I deem that valuable enough to warrant the increase in size of the changes.

Then we'll look back at the stuff in this branch and decide on what to extract and finish next!
  
  

M app/Main.hs => app/Main.hs +9 -8
@@ 61,14 61,15 @@ runFile cfg = do
frontend :: Config cfg => cfg -> FilePath -> IO Ast.Program
frontend cfg f = do
    let d = getDebug cfg
    verbose cfg "   Lexing"
    !tts <- lex f
    when d $ writeFile ".dbg.lexd" (show tts)
    verbose cfg "   Expanding macros"
    !tts' <- expandMacros f tts
    when d $ writeFile ".dbg.expanded" (show tts')
    verbose cfg "   Parsing"
    !ast <- parse f tts'
    -- verbose cfg "   Lexing"
    -- !tts <- lex f
    -- when d $ writeFile ".dbg.lexd" (show tts)
    -- verbose cfg "   Expanding macros"
    -- !tts' <- expandMacros f tts
    -- when d $ writeFile ".dbg.expanded" (show tts')
    -- verbose cfg "   Parsing"
    -- !ast <- parse f tts'
    !ast <- queryParsedModule (QMName "main" ["main", "self"])
    verbose cfg "   Typechecking"
    !ann <- typecheck' f ast
    when d $ writeFile ".dbg.checked" (show ann)

M carth.cabal => carth.cabal +7 -1
@@ 26,8 26,12 @@ library
      Misc
      Pretty
      Sizeof
      Query
      Name
      SrcPos

      Front.SrcPos
      Front.Concrete
      Front.Abstract
      Front.Subst
      Front.Err
      Front.TypeAst


@@ 37,6 41,8 @@ library
      Front.Parse
      Front.Parser
      Front.Parsed
      Front.Resolve
      Front.Resolved
      Front.Check
      Front.Checked
      Front.Infer

M src/EnvVars.hs => src/EnvVars.hs +3 -3
@@ 1,9 1,9 @@
module EnvVars (modulePaths) where
module EnvVars (packagePaths) where

import System.Environment (lookupEnv)

import Misc


modulePaths :: IO [FilePath]
modulePaths = fmap (maybe [] (splitOn ":")) (lookupEnv "CARTH_MODULE_PATH")
packagePaths :: IO [FilePath]
packagePaths = fmap (maybe [] (splitOn ":")) (lookupEnv "CARTH_PACKAGE_PATH")

A src/Front/Abstract.hs => src/Front/Abstract.hs +139 -0
@@ 0,0 1,139 @@
{-# LANGUAGE TemplateHaskell #-}

-- TODO: Can this and Checked be merged to a single, parametrized AST?

-- | Type annotated AST as a result of typechecking
module Front.Abstract where

import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import Data.Bifunctor
import Lens.Micro.Platform (makeLenses)

import Misc
import SrcPos
import Name
import Front.Concrete (Const (..))
import Front.Resolved
-- import qualified Front.Parsed as Parsed
-- import Front.Concrete (Type, TConst, TVar(..), Const(..))
-- import Front.Parsed
import Front.TypeAst

type LhsName = Ident

type TConst = ResolvedTConst
type Type = ResolvedType

data TypeErr
    = MainNotDefined
    | InvalidUserTypeSig SrcPos Scheme Scheme
    | CtorArityMismatch SrcPos String Int Int
    | ConflictingPatVarDefs SrcPos String
    | UndefCtor SrcPos String
    | UndefVar SrcPos String
    | InfType SrcPos Type Type TVar Type
    | UnificationFailed SrcPos Type Type Type Type
    | ConflictingTypeDef SrcPos String
    | ConflictingCtorDef SrcPos String
    | RedundantCase SrcPos
    | InexhaustivePats SrcPos String
    | ExternNotMonomorphic Ident TVar
    | FoundHole SrcPos
    | RecTypeDef String SrcPos
    | UndefType SrcPos QualName
    | WrongMainType SrcPos ResolvedScheme
    | RecursiveVarDef (WithPos String)
    | TypeInstArityMismatch SrcPos QualName Int Int
    | ConflictingVarDef SrcPos String
    | NoClassInstance SrcPos ClassConstraint
    | FunCaseArityMismatch SrcPos Int Int
    | FunArityMismatch SrcPos Int Int
    | DeBruijnIndexOutOfRange SrcPos Word
    | FreeVarsInData SrcPos TVar
    | FreeVarsInAlias SrcPos TVar
    deriving (Show)

type ClassConstraint = (QualName, [Type])

data Scheme = Forall
    { _scmParams :: Set TVar
    , _scmConstraints :: Set ClassConstraint
    , _scmBody :: Type
    }
    deriving (Show, Eq)
makeLenses ''Scheme

type VariantIx = Integer

type Span = Integer

data Variant = VariantIx VariantIx | VariantStr String
    deriving (Show, Eq, Ord)

data Con = Con
    { variant :: Variant
    , span :: Span
    , argTs :: [Type]
    }
    deriving Show

data Pat'
    = PVar LhsName Type
    | PWild
    | PCon Con [Pat]
    | PBox Pat
    deriving Show

data Pat = Pat SrcPos Type Pat'
    deriving Show

type Fun match = ([(LhsName, Type)], (Expr match, Type))

data Expr match typ
    = Lit Const
    | EVar QualName typ
    | App (Expr match typ) [Expr match typ] typ
    | If (Expr match typ) (Expr match typ) (Expr match typ)
    | Let (Def match) (Expr match typ)
    | Fun (Fun match)
    | Match [Expr match typ] match
    | Ctor VariantIx Span TConst [typ]
    | Sizeof Type
    deriving Show

type Defs match = TopologicalOrder (Def match)
data Def match = VarDef (VarDef match) | RecDefs (RecDefs match) deriving Show
type VarDef match = (LhsName, (Scheme, Expr match))
type RecDefs match = [(LhsName, (Scheme, Fun match))]
data TypeDefRhs = Data [(WithPos QualName, [Type])] | Alias SrcPos Type
    deriving Show
type TypeDefs = Map QualName ([TVar], TypeDefRhs)
type TypeAliases = Map QualName ([TVar], Type)
type Ctors = Map LhsName (VariantIx, (String, [TVar]), [Type], Span)
type Externs = Map LhsName Type


instance Eq Con where
    (==) (Con c1 _ _) (Con c2 _ _) = c1 == c2

instance Ord Con where
    compare (Con c1 _ _) (Con c2 _ _) = compare c1 c2


ftv :: Type -> Set TVar
ftv = \case
    TVar tv -> Set.singleton tv
    TPrim _ -> Set.empty
    TFun pts rt -> Set.unions (ftv rt : map ftv pts)
    TBox t -> ftv t
    TConst (_, ts) -> Set.unions (map ftv ts)

defSigs :: Def match -> [(LhsName, Scheme)]
defSigs = \case
    VarDef d -> [defSig d]
    RecDefs ds -> map defSig ds

defSig :: (LhsName, (Scheme, a)) -> (LhsName, Scheme)
defSig = second fst

M src/Front/Check.hs => src/Front/Check.hs +251 -253
@@ 1,281 1,279 @@
{-# LANGUAGE DataKinds #-}
--{-# LANGUAGE DataKinds #-}

module Front.Check (typecheck) where
--module Front.Check (typecheck) where
module Front.Check where

import Prelude hiding (span)
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Bifunctor
import Data.Bitraversable
import Data.Foldable
import Data.Functor
import Control.Applicative
import qualified Data.Map as Map
import Data.Map (Map)
import qualified Data.Set as Set
import Data.Set (Set)
-- import Prelude hiding (span)
-- import Control.Monad.Except
-- import Control.Monad.Reader
-- import Control.Monad.State.Strict
-- import Data.Bifunctor
-- import Data.Bitraversable
-- import Data.Foldable
-- import Data.Functor
-- import Control.Applicative
-- import qualified Data.Map as Map
-- import Data.Map (Map)
-- import qualified Data.Set as Set
-- import Data.Set (Set)

import Misc
import Front.SrcPos
import Front.Subst
import qualified Front.Parsed as Parsed
import Front.Parsed (Id(..), TVar(..), idstr)
import Front.Err
import qualified Front.Inferred as Inferred
import Front.Match
import Front.Infer
import Front.TypeAst
import qualified Front.Checked as Checked
import Front.Checked (Virt(..))
-- import Misc
-- import SrcPos
-- import Front.Subst
-- import Front.Concrete (Id(..), TVar(..), idstr)
-- import Front.Err
-- import qualified Front.Inferred as Inferred
-- import Front.Match
-- import Front.Infer
-- import Front.TypeAst
-- import qualified Front.Checked as Checked
-- import Front.Checked (Virt(..))


typecheck :: Parsed.Program -> Either TypeErr Checked.Program
typecheck (Parsed.Program defs tdefs externs) = runExcept $ do
    (tdefs', ctors) <- checkTypeDefs tdefs
    externs' <- checkExterns tdefs' externs
    inferred <- inferTopDefs tdefs' ctors externs' defs
    let bound = unboundTypeVarsToUnit inferred
    -- Remove aliases. They are should be replaced in the AST and no longer needed at this point.
    let tdefs'' = Map.mapMaybe
            (secondM $ \case
                Inferred.Data rhs -> Just rhs
                Inferred.Alias _ _ -> Nothing
            )
            tdefs'
    let mTypeDefs = fmap (map (unpos . fst) . snd) tdefs''
    compiled <- compileDecisionTrees mTypeDefs bound
    checkMainDefined compiled
    let tdefs''' = fmap (second (map (first unpos))) tdefs''
    pure (Checked.Program compiled tdefs''' externs')
  where
    checkMainDefined ds =
        unless ("main" `elem` map fst (Checked.flattenDefs ds)) (throwError MainNotDefined)
-- typecheck :: Ast.Program -> Either TypeErr Checked.Program
-- typecheck (Ast.Program defs tdefs externs) = runExcept $ do
--     (tdefs', ctors) <- checkTypeDefs tdefs
--     externs' <- checkExterns tdefs' externs
--     inferred <- inferTopDefs tdefs' ctors externs' defs
--     let bound = unboundTypeVarsToUnit inferred
--     -- Remove aliases. They are should be replaced in the AST and no longer needed at this point.
--     let tdefs'' = Map.mapMaybe
--             (secondM $ \case
--                 Inferred.Data rhs -> Just rhs
--                 Inferred.Alias _ _ -> Nothing
--             )
--             tdefs'
--     let mTypeDefs = fmap (map (unpos . fst) . snd) tdefs''
--     compiled <- compileDecisionTrees mTypeDefs bound
--     checkMainDefined compiled
--     let tdefs''' = fmap (second (map (first unpos))) tdefs''
--     pure (Checked.Program compiled tdefs''' externs')
--   where
--     checkMainDefined ds =
--         unless ("main" `elem` map fst (Checked.flattenDefs ds)) (throwError MainNotDefined)

type CheckTypeDefs a
    = ReaderT
          (Map String (Either Int ([TVar], Parsed.Type)))
          (StateT (Inferred.TypeDefs, Inferred.Ctors) (Except TypeErr))
          a
-- type CheckTypeDefs a
--     = ReaderT
--           (Map String (Either Int ([TVar], Ast.Type)))
--           (StateT (Inferred.TypeDefs, Inferred.Ctors) (Except TypeErr))
--           a

checkTypeDefs :: [Parsed.TypeDef] -> Except TypeErr (Inferred.TypeDefs, Inferred.Ctors)
checkTypeDefs tdefs = do
    let tdefs' = Map.union (fmap (Left . length . fst) builtinTypeDefs) $ Map.fromList
            (map
                (\case
                    Parsed.TypeDef x ps _ -> (idstr x, Left (length ps))
                    Parsed.TypeAlias x ps t -> (idstr x, Right (map TVExplicit ps, t))
                )
                tdefs
            )
    (tdefs', ctors) <- execStateT (runReaderT (forM_ tdefs checkTypeDef) tdefs')
                                  (builtinTypeDefs, builtinConstructors)
    forM_ (Map.toList tdefs') (assertNoRec tdefs')
    pure (tdefs', ctors)
-- checkTypeDefs :: [Ast.TypeDef] -> Except TypeErr (Inferred.TypeDefs, Inferred.Ctors)
-- checkTypeDefs tdefs = do
--     let tdefs' = Map.union (fmap (Left . length . fst) builtinTypeDefs) $ Map.fromList
--             (map
--                 (\case
--                     Ast.TypeDef x ps _ -> (idstr x, Left (length ps))
--                     Ast.TypeAlias x ps t -> (idstr x, Right (map TVExplicit ps, t))
--                 )
--                 tdefs
--             )
--     (tdefs', ctors) <- execStateT (runReaderT (forM_ tdefs checkTypeDef) tdefs')
--                                   (builtinTypeDefs, builtinConstructors)
--     forM_ (Map.toList tdefs') (assertNoRec tdefs')
--     pure (tdefs', ctors)

checkTypeDef :: Parsed.TypeDef -> CheckTypeDefs ()
checkTypeDef = \case
    Parsed.TypeDef (Parsed.Id x) ps cs -> do
        checkNotAlreadyDefined x
        let ps' = map TVExplicit ps
        cs' <- checkCtors (unpos x, ps') cs
        forM_ (foldMap (foldMap Inferred.ftv . snd) cs')
            $ \tv -> unless (tv `elem` ps') (throwError (FreeVarsInData (getPos x) tv))
        modify (first (Map.insert (unpos x) (ps', Inferred.Data cs')))
    Parsed.TypeAlias (Parsed.Id x) ps t -> do
        checkNotAlreadyDefined x
        let ps' = map TVExplicit ps
        t' <- checkType' (getPos x) t
        forM_ (Inferred.ftv t')
            $ \tv -> unless (tv `elem` ps') (throwError (FreeVarsInAlias (getPos x) tv))
        modify (first (Map.insert (unpos x) (ps', Inferred.Alias (getPos x) t')))
  where
    checkNotAlreadyDefined (WithPos xpos x) = do
        alreadyDefined <- gets (Map.member x . fst)
        when alreadyDefined (throwError (ConflictingTypeDef xpos x))
-- checkTypeDef :: Ast.TypeDef -> CheckTypeDefs ()
-- checkTypeDef = \case
--     Ast.TypeDef (Ast.Id x) ps cs -> do
--         checkNotAlreadyDefined x
--         let ps' = map TVExplicit ps
--         cs' <- checkCtors (unpos x, ps') cs
--         forM_ (foldMap (foldMap Inferred.ftv . snd) cs')
--             $ \tv -> unless (tv `elem` ps') (throwError (FreeVarsInData (getPos x) tv))
--         modify (first (Map.insert (unpos x) (ps', Inferred.Data cs')))
--     Ast.TypeAlias (Ast.Id x) ps t -> do
--         checkNotAlreadyDefined x
--         let ps' = map TVExplicit ps
--         t' <- checkType' (getPos x) t
--         forM_ (Inferred.ftv t')
--             $ \tv -> unless (tv `elem` ps') (throwError (FreeVarsInAlias (getPos x) tv))
--         modify (first (Map.insert (unpos x) (ps', Inferred.Alias (getPos x) t')))
--   where
--     checkNotAlreadyDefined (WithPos xpos x) = do
--         alreadyDefined <- gets (Map.member x . fst)
--         when alreadyDefined (throwError (ConflictingTypeDef xpos x))

checkCtors
    :: (String, [TVar])
    -> Parsed.ConstructorDefs
    -> CheckTypeDefs [(WithPos String, [Inferred.Type])]
checkCtors parent (Parsed.ConstructorDefs cs) =
    let cspan = fromIntegral (length cs) in mapM (checkCtor cspan) (zip [0 ..] cs)
  where
    checkCtor cspan (i, (Id c'@(WithPos pos c), ts)) = do
        cAlreadyDefined <- gets (Map.member c . snd)
        when cAlreadyDefined (throwError (ConflictingCtorDef pos c))
        ts' <- mapM (checkType' pos) ts
        modify (second (Map.insert c (i, parent, ts', cspan)))
        pure (c', ts')
-- checkCtors
--     :: (String, [TVar]) -> Ast.ConstructorDefs -> CheckTypeDefs [(WithPos String, [Inferred.Type])]
-- checkCtors parent (Ast.ConstructorDefs cs) =
--     let cspan = fromIntegral (length cs) in mapM (checkCtor cspan) (zip [0 ..] cs)
--   where
--     checkCtor cspan (i, (Id c'@(WithPos pos c), ts)) = do
--         cAlreadyDefined <- gets (Map.member c . snd)
--         when cAlreadyDefined (throwError (ConflictingCtorDef pos c))
--         ts' <- mapM (checkType' pos) ts
--         modify (second (Map.insert c (i, parent, ts', cspan)))
--         pure (c', ts')

checkType' :: SrcPos -> Parsed.Type -> CheckTypeDefs Inferred.Type
checkType' pos t = do
    tdefs <- ask
    let checkTConst (x, args) = case Map.lookup x tdefs of
            Nothing -> throwError (UndefType pos x)
            Just (Left expectedN) ->
                let foundN = length args
                in  if expectedN == foundN
                        then do
                            args' <- mapM go args
                            pure (TConst (x, args'))
                        else throwError (TypeInstArityMismatch pos x expectedN foundN)
            Just (Right (params, u)) -> subst (Map.fromList (zip params args)) <$> go u
        go = checkType checkTConst
    go t
-- checkType' :: SrcPos -> Ast.Type -> CheckTypeDefs Inferred.Type
-- checkType' pos t = do
--     tdefs <- ask
--     let checkTConst (x, args) = case Map.lookup x tdefs of
--             Nothing -> throwError (UndefType pos x)
--             Just (Left expectedN) ->
--                 let foundN = length args
--                 in  if expectedN == foundN
--                         then do
--                             args' <- mapM go args
--                             pure (TConst (x, args'))
--                         else throwError (TypeInstArityMismatch pos x expectedN foundN)
--             Just (Right (params, u)) -> subst (Map.fromList (zip params args)) <$> go u
--         go = checkType checkTConst
--     go t

builtinTypeDefs :: Inferred.TypeDefs
builtinTypeDefs = Map.fromList $ map
    (\(x, ps, cs) ->
        (x, (ps, Inferred.Data $ map (first (WithPos (SrcPos "<builtin>" 0 0 Nothing))) cs))
    )
    builtinDataTypes'
-- builtinTypeDefs :: Inferred.TypeDefs
-- builtinTypeDefs = Map.fromList $ map
--     (\(x, ps, cs) ->
--         (x, (ps, Inferred.Data $ map (first (WithPos (SrcPos "<builtin>" 0 0 Nothing))) cs))
--     )
--     builtinDataTypes'

builtinConstructors :: Inferred.Ctors
builtinConstructors = Map.unions (map builtinConstructors' builtinDataTypes')
  where
    builtinConstructors' (x, ps, cs) =
        let cSpan = fromIntegral (length cs)
        in  foldl' (\csAcc (i, (cx, cps)) -> Map.insert cx (i, (x, ps), cps, cSpan) csAcc)
                   Map.empty
                   (zip [0 ..] cs)
-- builtinConstructors :: Inferred.Ctors
-- builtinConstructors = Map.unions (map builtinConstructors' builtinDataTypes')
--   where
--     builtinConstructors' (x, ps, cs) =
--         let cSpan = fromIntegral (length cs)
--         in  foldl' (\csAcc (i, (cx, cps)) -> Map.insert cx (i, (x, ps), cps, cSpan) csAcc)
--                    Map.empty
--                    (zip [0 ..] cs)

builtinDataTypes' :: [(String, [TVar], [(String, [Inferred.Type])])]
builtinDataTypes' =
    [ ( "Array"
      , [TVImplicit "a"]
      , [("Array", [Inferred.TBox (Inferred.TVar (TVImplicit "a")), Inferred.TPrim TNatSize])]
      )
    , ("Str", [], [("Str", [tArray (Inferred.TPrim (TNat 8))])])
    , ( "Cons"
      , [TVImplicit "a", TVImplicit "b"]
      , [("Cons", [Inferred.TVar (TVImplicit "a"), Inferred.TVar (TVImplicit "b")])]
      )
    , ("Unit", [], [unit'])
    , ("RealWorld", [], [("UnsafeRealWorld", [])])
    , ("Bool", [], [("False", []), ("True", [])])
    ]
    where unit' = ("Unit", [])
-- builtinDataTypes' :: [(String, [TVar], [(String, [Inferred.Type])])]
-- builtinDataTypes' =
--     [ ( "Array"
--       , [TVImplicit "a"]
--       , [("Array", [Inferred.TBox (Inferred.TVar (TVImplicit "a")), Inferred.TPrim TNatSize])]
--       )
--     , ("Str", [], [("Str", [tArray (Inferred.TPrim (TNat 8))])])
--     , ( "Cons"
--       , [TVImplicit "a", TVImplicit "b"]
--       , [("Cons", [Inferred.TVar (TVImplicit "a"), Inferred.TVar (TVImplicit "b")])]
--       )
--     , ("Unit", [], [unit'])
--     , ("RealWorld", [], [("UnsafeRealWorld", [])])
--     , ("Bool", [], [("False", []), ("True", [])])
--     ]
--     where unit' = ("Unit", [])

assertNoRec :: Inferred.TypeDefs -> (String, ([TVar], Inferred.TypeDefRhs)) -> Except TypeErr ()
assertNoRec tdefs' (x, (xinst, rhs)) = assertNoRec' (Set.singleton (x, map TVar xinst))
                                                    rhs
                                                    Map.empty
  where
    assertNoRec' seen (Inferred.Data cs) s =
        forM_ cs $ \(WithPos cpos _, cts) -> forM_ cts (assertNoRecType seen cpos . subst s)
    assertNoRec' seen (Inferred.Alias pos t) s = assertNoRecType seen pos (subst s t)
    assertNoRecType seen cpos = \case
        Inferred.TConst (y, yinst) -> do
            when (Set.member (y, yinst) seen) $ throwError (RecTypeDef x cpos)
            let (tvs, cs) = Map.findWithDefault
                    (ice $ "assertNoRec: type id " ++ show y ++ " not in tdefs")
                    y
                    tdefs'
            let substs = Map.fromList (zip tvs yinst)
            assertNoRec' (Set.insert (y, yinst) seen) cs substs
        _ -> pure ()
-- assertNoRec :: Inferred.TypeDefs -> (String, ([TVar], Inferred.TypeDefRhs)) -> Except TypeErr ()
-- assertNoRec tdefs' (x, (xinst, rhs)) = assertNoRec' (Set.singleton (x, map TVar xinst))
--                                                     rhs
--                                                     Map.empty
--   where
--     assertNoRec' seen (Inferred.Data cs) s =
--         forM_ cs $ \(WithPos cpos _, cts) -> forM_ cts (assertNoRecType seen cpos . subst s)
--     assertNoRec' seen (Inferred.Alias pos t) s = assertNoRecType seen pos (subst s t)
--     assertNoRecType seen cpos = \case
--         Inferred.TConst (y, yinst) -> do
--             when (Set.member (y, yinst) seen) $ throwError (RecTypeDef x cpos)
--             let (tvs, cs) = Map.findWithDefault
--                     (ice $ "assertNoRec: type id " ++ show y ++ " not in tdefs")
--                     y
--                     tdefs'
--             let substs = Map.fromList (zip tvs yinst)
--             assertNoRec' (Set.insert (y, yinst) seen) cs substs
--         _ -> pure ()

checkExterns :: Inferred.TypeDefs -> [Parsed.Extern] -> Except TypeErr Inferred.Externs
checkExterns tdefs = fmap (Map.union Checked.builtinExterns . Map.fromList) . mapM checkExtern
  where
    checkExtern (Parsed.Extern name t) = do
        t' <- checkType (checkTConst tdefs (getPos name)) t
        case Set.lookupMin (Inferred.ftv t') of
            Just tv -> throwError (ExternNotMonomorphic name tv)
            Nothing -> pure (idstr name, t')
-- checkExterns :: Inferred.TypeDefs -> [Ast.Extern] -> Except TypeErr Inferred.Externs
-- checkExterns tdefs = fmap (Map.union Checked.builtinExterns . Map.fromList) . mapM checkExtern
--   where
--     checkExtern (Ast.Extern name t) = do
--         t' <- checkType (checkTConst tdefs (getPos name)) t
--         case Set.lookupMin (Inferred.ftv t') of
--             Just tv -> throwError (ExternNotMonomorphic name tv)
--             Nothing -> pure (idstr name, t')

-- Any free / unbound type variables left in the AST after Infer are replacable with any
-- type, unless there's a bug in the compiler. Therefore, replace them all with Unit now.
unboundTypeVarsToUnit :: Inferred.Defs -> Inferred.Defs
unboundTypeVarsToUnit (Topo defs) = Topo $ runReader (mapM goDef defs) Set.empty
  where
    goDef :: Inferred.Def -> Reader (Set TVar) Inferred.Def
    goDef = \case
        Inferred.VarDef d -> Inferred.VarDef <$> secondM (goDefRhs goExpr) d
        Inferred.RecDefs ds -> Inferred.RecDefs <$> mapM (secondM (goDefRhs goFun)) ds
-- -- Any free / unbound type variables left in the AST after Infer are replacable with any
-- -- type, unless there's a bug in the compiler. Therefore, replace them all with Unit now.
-- unboundTypeVarsToUnit :: Inferred.Defs -> Inferred.Defs
-- unboundTypeVarsToUnit (Topo defs) = Topo $ runReader (mapM goDef defs) Set.empty
--   where
--     goDef :: Inferred.Def -> Reader (Set TVar) Inferred.Def
--     goDef = \case
--         Inferred.VarDef d -> Inferred.VarDef <$> secondM (goDefRhs goExpr) d
--         Inferred.RecDefs ds -> Inferred.RecDefs <$> mapM (secondM (goDefRhs goFun)) ds

    goDefRhs f (scm, x) = (scm, ) <$> local (Set.union (Inferred._scmParams scm)) (f x)
--     goDefRhs f (scm, x) = (scm, ) <$> local (Set.union (Inferred._scmParams scm)) (f x)

    goMatch :: Inferred.Match -> Reader (Set TVar) Inferred.Match
    goMatch (WithPos pos (ms, cs, tps, tb)) = do
        ms' <- mapM goExpr ms
        cs' <- mapM (bimapM (mapPosdM (mapM goPat)) goExpr) cs
        tps' <- mapM subst tps
        tb' <- subst tb
        pure (WithPos pos (ms', cs', tps', tb'))
--     goMatch :: Inferred.Match -> Reader (Set TVar) Inferred.Match
--     goMatch (WithPos pos (ms, cs, tps, tb)) = do
--         ms' <- mapM goExpr ms
--         cs' <- mapM (bimapM (mapPosdM (mapM goPat)) goExpr) cs
--         tps' <- mapM subst tps
--         tb' <- subst tb
--         pure (WithPos pos (ms', cs', tps', tb'))

    goFun :: Inferred.Fun -> Reader (Set TVar) Inferred.Fun
    goFun (params, body) = liftA2 (,) (mapM (secondM subst) params) (bimapM goExpr subst body)
--     goFun :: Inferred.Fun -> Reader (Set TVar) Inferred.Fun
--     goFun (params, body) = liftA2 (,) (mapM (secondM subst) params) (bimapM goExpr subst body)

    goExpr :: Inferred.Expr -> Reader (Set TVar) Inferred.Expr
    goExpr = \case
        Inferred.Lit c -> pure (Inferred.Lit c)
        Inferred.Var v -> Inferred.Var <$> secondM goTypedVar v
        Inferred.App f as tr -> liftA3 Inferred.App (goExpr f) (mapM goExpr as) (subst tr)
        Inferred.If p c a -> liftA3 Inferred.If (goExpr p) (goExpr c) (goExpr a)
        Inferred.Let ld b -> liftA2 Inferred.Let (goDef ld) (goExpr b)
        Inferred.Fun f -> fmap Inferred.Fun (goFun f)
        Inferred.Match m -> fmap Inferred.Match (goMatch m)
        Inferred.Ctor v sp inst ts ->
            liftA2 (Inferred.Ctor v sp) (secondM (mapM subst) inst) (mapM subst ts)
        Inferred.Sizeof t -> fmap Inferred.Sizeof (subst t)
--     goExpr :: Inferred.Expr -> Reader (Set TVar) Inferred.Expr
--     goExpr = \case
--         Inferred.Lit c -> pure (Inferred.Lit c)
--         Inferred.Var v -> Inferred.Var <$> secondM goTypedVar v
--         Inferred.App f as tr -> liftA3 Inferred.App (goExpr f) (mapM goExpr as) (subst tr)
--         Inferred.If p c a -> liftA3 Inferred.If (goExpr p) (goExpr c) (goExpr a)
--         Inferred.Let ld b -> liftA2 Inferred.Let (goDef ld) (goExpr b)
--         Inferred.Fun f -> fmap Inferred.Fun (goFun f)
--         Inferred.Match m -> fmap Inferred.Match (goMatch m)
--         Inferred.Ctor v sp inst ts ->
--             liftA2 (Inferred.Ctor v sp) (secondM (mapM subst) inst) (mapM subst ts)
--         Inferred.Sizeof t -> fmap Inferred.Sizeof (subst t)

    goPat :: Inferred.Pat -> Reader (Set TVar) Inferred.Pat
    goPat (Inferred.Pat pos t pat) = liftA2 (Inferred.Pat pos) (subst t) $ case pat of
        Inferred.PVar v -> fmap Inferred.PVar (goTypedVar v)
        Inferred.PWild -> pure Inferred.PWild
        Inferred.PCon con ps -> liftA2
            Inferred.PCon
            (fmap (\ts -> con { argTs = ts }) (mapM subst (argTs con)))
            (mapM goPat ps)
        Inferred.PBox p -> fmap Inferred.PBox (goPat p)
--     goPat :: Inferred.Pat -> Reader (Set TVar) Inferred.Pat
--     goPat (Inferred.Pat pos t pat) = liftA2 (Inferred.Pat pos) (subst t) $ case pat of
--         Inferred.PVar v -> fmap Inferred.PVar (goTypedVar v)
--         Inferred.PWild -> pure Inferred.PWild
--         Inferred.PCon con ps -> liftA2
--             Inferred.PCon
--             (fmap (\ts -> con { argTs = ts }) (mapM subst (argTs con)))
--             (mapM goPat ps)
--         Inferred.PBox p -> fmap Inferred.PBox (goPat p)

    goTypedVar (Inferred.TypedVar x t) = fmap (Inferred.TypedVar x) (subst t)
--     goTypedVar (Inferred.TypedVar x t) = fmap (Inferred.TypedVar x) (subst t)

    subst :: Inferred.Type -> Reader (Set TVar) Inferred.Type
    subst t =
        ask <&> \bound -> subst' (\tv -> if Set.member tv bound then Nothing else Just tUnit) t
--     subst :: Inferred.Type -> Reader (Set TVar) Inferred.Type
--     subst t =
--         ask <&> \bound -> subst' (\tv -> if Set.member tv bound then Nothing else Just tUnit) t

compileDecisionTrees :: MTypeDefs -> Inferred.Defs -> Except TypeErr Checked.Defs
compileDecisionTrees tdefs = compDefs
  where
    compDefs (Topo defs) = Topo <$> mapM compDef defs
-- compileDecisionTrees :: MTypeDefs -> Inferred.Defs -> Except TypeErr Checked.Defs
-- compileDecisionTrees tdefs = compDefs
--   where
--     compDefs (Topo defs) = Topo <$> mapM compDef defs

    compDef :: Inferred.Def -> Except TypeErr Checked.Def
    compDef = \case
        Inferred.VarDef (lhs, rhs) -> fmap (Checked.VarDef . (lhs, )) (secondM compExpr rhs)
        Inferred.RecDefs ds -> fmap Checked.RecDefs $ forM ds $ secondM (secondM compFun)
--     compDef :: Inferred.Def -> Except TypeErr Checked.Def
--     compDef = \case
--         Inferred.VarDef (lhs, rhs) -> fmap (Checked.VarDef . (lhs, )) (secondM compExpr rhs)
--         Inferred.RecDefs ds -> fmap Checked.RecDefs $ forM ds $ secondM (secondM compFun)

    compFun (params, (body, tbody)) = do
        body' <- compExpr body
        pure (params, (body', tbody))
--     compFun (params, (body, tbody)) = do
--         body' <- compExpr body
--         pure (params, (body', tbody))

    compMatch :: Inferred.Match -> Except TypeErr Checked.Expr
    compMatch (WithPos pos (ms, cs, tps, tb)) = do
        ms' <- mapM compExpr ms
        cs' <- mapM (secondM compExpr) cs
        case runExceptT (toDecisionTree tdefs pos tps cs') of
            Nothing -> pure (Checked.Absurd tb)
            Just e -> do
                dt <- liftEither e
                pure (Checked.Match ms' dt)
--     compMatch :: Inferred.Match -> Except TypeErr Checked.Expr
--     compMatch (WithPos pos (ms, cs, tps, tb)) = do
--         ms' <- mapM compExpr ms
--         cs' <- mapM (secondM compExpr) cs
--         case runExceptT (toDecisionTree tdefs pos tps cs') of
--             Nothing -> pure (Checked.Absurd tb)
--             Just e -> do
--                 dt <- liftEither e
--                 pure (Checked.Match ms' dt)

    compExpr :: Inferred.Expr -> Except TypeErr Checked.Expr
    compExpr ex = case ex of
        Inferred.Lit c -> pure (Checked.Lit c)
        Inferred.Var (virt, Inferred.TypedVar x t) ->
            pure (Checked.Var (virt, Checked.TypedVar x t))
        Inferred.App f as _ -> liftA2 Checked.App (compExpr f) (mapM compExpr as)
        Inferred.If p c a -> liftA3 Checked.If (compExpr p) (compExpr c) (compExpr a)
        Inferred.Let ld b -> liftA2 Checked.Let (compDef ld) (compExpr b)
        Inferred.Fun f -> fmap Checked.Fun (compFun f)
        Inferred.Match m -> compMatch m
        Inferred.Ctor v span' inst ts ->
            let xs = map (\n -> "x" ++ show n) (take (length ts) [0 ..] :: [Word])
                params = zip xs ts
                args = map (Checked.Var . (NonVirt, ) . uncurry Checked.TypedVar) params
                ret = Checked.Ction v span' inst args
                tret = Inferred.TConst inst
            in  pure $ if null params then ret else Checked.Fun (params, (ret, tret))
        Inferred.Sizeof t -> pure (Checked.Sizeof t)
--     compExpr :: Inferred.Expr -> Except TypeErr Checked.Expr
--     compExpr ex = case ex of
--         Inferred.Lit c -> pure (Checked.Lit c)
--         Inferred.Var (virt, Inferred.TypedVar x t) ->
--             pure (Checked.Var (virt, Checked.TypedVar x t))
--         Inferred.App f as _ -> liftA2 Checked.App (compExpr f) (mapM compExpr as)
--         Inferred.If p c a -> liftA3 Checked.If (compExpr p) (compExpr c) (compExpr a)
--         Inferred.Let ld b -> liftA2 Checked.Let (compDef ld) (compExpr b)
--         Inferred.Fun f -> fmap Checked.Fun (compFun f)
--         Inferred.Match m -> compMatch m
--         Inferred.Ctor v span' inst ts ->
--             let xs = map (\n -> "x" ++ show n) (take (length ts) [0 ..] :: [Word])
--                 params = zip xs ts
--                 args = map (Checked.Var . (NonVirt, ) . uncurry Checked.TypedVar) params
--                 ret = Checked.Ction v span' inst args
--                 tret = Inferred.TConst inst
--             in  pure $ if null params then ret else Checked.Fun (params, (ret, tret))
--         Inferred.Sizeof t -> pure (Checked.Sizeof t)

M src/Front/Checked.hs => src/Front/Checked.hs +121 -96
@@ 1,102 1,127 @@
module Front.Checked
    ( module Front.Checked
    , TVar(..)
    , TPrim(..)
    , Type
    , TConst
    , TConst'
    , Type'(..)
    , Scheme(..)
    , VariantIx
    , Span
    , Con(..)
    , mainType
    , Virt (..)
    )
where

import Data.Map (Map)
import Data.Word
import Data.Bifunctor
import qualified Data.Map as Map

import Misc
import Front.Inferred
    ( TVar(..)
    , TConst
    , Type
    , Scheme(..)
    , Const(..)
    , VariantIx
    , Span
    , Con(..)
    , Virt(..)
    )
import Front.TypeAst

data TypedVar = TypedVar String Type
    deriving (Show, Eq, Ord)
module Front.Checked where

import Front.Abstract

data Access
    = TopSel Word32 -- type of selectee
    = TopSel Word -- type of selectee
    | As Access Span VariantIx
    | Sel Access Word32 Span
    | Sel Access Word Span
    | ADeref Access
    deriving (Show, Eq, Ord)

type VarBindings = Map TypedVar Access

data DecisionTree
    = DLeaf (VarBindings, Expr)
    | DSwitch Span Access (Map VariantIx DecisionTree) DecisionTree
    | DSwitchStr Access (Map String DecisionTree) DecisionTree
    deriving Show

type Fun = ([(String, Type)], (Expr, Type))

type Var = (Virt, TypedVar)

data Expr
    = Lit Const
    | Var Var
    | App Expr [Expr]
    | If Expr Expr Expr
    | Fun Fun
    | Let Def Expr
    | Match [Expr] DecisionTree
    | Ction VariantIx Span TConst [Expr]
    | Sizeof Type
    | Absurd Type
    deriving (Show)

builtinExterns :: Map String Type
builtinExterns = Map.fromList
    [ ("GC_add_roots", TFun [TBox tUnit, TBox tUnit] tUnit)
    , ("GC_malloc", TFun [TPrim TNatSize] (TBox tUnit))
    , ("malloc", TFun [TPrim TNatSize] (TBox tUnit))
    , ("str_eq", TFun [tStr, tStr] tBool)
    , ("carth_backtrace_push", TFun [tStr] tUnit)
    , ("carth_backtrace_pop", TFun [] tUnit)
    ]

type Defs = TopologicalOrder Def
data Def = VarDef VarDef | RecDefs RecDefs deriving Show
type VarDef = (String, (Scheme, Expr))
type RecDefs = [(String, (Scheme, Fun))]
type TypeDefs = Map String ([TVar], [(String, [Type])])
type Externs = Map String Type

data Program = Program Defs TypeDefs Externs
    deriving Show


flattenDefs :: Defs -> [(String, (Scheme, Expr))]
flattenDefs (Topo defs) = defToVarDefs =<< defs

defToVarDefs :: Def -> [(String, (Scheme, Expr))]
defToVarDefs = \case
    VarDef d -> [d]
    RecDefs ds -> map funDefToVarDef ds

funDefToVarDef :: (String, (Scheme, Fun)) -> VarDef
funDefToVarDef = second (second Fun)
-- type VarBindings = Map TypedVar Access

-- data DecisionTree
--     = DLeaf (VarBindings, Expr)
--     | DSwitch Span Access (Map VariantIx DecisionTree) DecisionTree
--     | DSwitchStr Access (Map String DecisionTree) DecisionTree
--     deriving Show

-- type CheckedExpr = Expr DecisionTree

-- module Front.Checked
--     ( module Front.Checked
--     , TVar(..)
--     , TPrim(..)
--     , Type
--     , TConst
--     , TConst'
--     , Type'(..)
--     , Scheme(..)
--     , VariantIx
--     , Span
--     , Con(..)
-- --    , mainType
--     , Virt (..)
--     )
-- where

-- import Data.Map (Map)
-- import Data.Word
-- import Data.Bifunctor
-- import qualified Data.Map as Map

-- import Misc
-- import Name
-- import Front.Inferred
--     ( TVar(..)
--     , TConst
--     , Type
--     , Scheme(..)
--     , Const(..)
--     , VariantIx
--     , Span
--     , Con(..)
--     )
-- import Front.TypeAst

-- data TypedVar = TypedVar String Type
--     deriving (Show, Eq, Ord)

-- data Access
--     = TopSel Word32 -- type of selectee
--     | As Access Span VariantIx
--     | Sel Access Word32 Span
--     | ADeref Access
--     deriving (Show, Eq, Ord)

-- type VarBindings = Map TypedVar Access

-- data DecisionTree
--     = DLeaf (VarBindings, Expr)
--     | DSwitch Span Access (Map VariantIx DecisionTree) DecisionTree
--     | DSwitchStr Access (Map String DecisionTree) DecisionTree
--     deriving Show

-- type Fun = ([(String, Type)], (Expr, Type))

-- type Var = (Virt, TypedVar)

-- data Expr
--     = Lit Const
--     | Var Var
--     | App Expr [Expr]
--     | If Expr Expr Expr
--     | Fun Fun
--     | Let Def Expr
--     | Match [Expr] DecisionTree
--     | Ction VariantIx Span TConst [Expr]
--     | Sizeof Type
--     | Absurd Type
--     deriving (Show)

-- builtinExterns :: Map String Type
-- builtinExterns = Map.fromList
--     [ ("GC_add_roots", TFun [TBox tUnit, TBox tUnit] tUnit)
--     , ("GC_malloc", TFun [TPrim TNatSize] (TBox tUnit))
--     , ("malloc", TFun [TPrim TNatSize] (TBox tUnit))
--     , ("str_eq", TFun [tStr, tStr] tBool)
--     , ("carth_backtrace_push", TFun [tStr] tUnit)
--     , ("carth_backtrace_pop", TFun [] tUnit)
--     ]
--   where
--     tUnit = TConst (QName (Ident "Unit") QMBuiltin, [])
--     tStr = TConst (QName (Ident "Str") QMBuiltin, [])
--     tBool = TConst (QName (Ident "Bool") QMBuiltin, [])

-- type Defs = TopologicalOrder Def
-- data Def = VarDef VarDef | RecDefs RecDefs deriving Show
-- type VarDef = (String, (Scheme, Expr))
-- type RecDefs = [(String, (Scheme, Fun))]
-- type TypeDefs = Map String ([TVar], [(String, [Type])])
-- type Externs = Map String Type

-- data Program = Program Defs TypeDefs Externs
--     deriving Show


-- flattenDefs :: Defs -> [(String, (Scheme, Expr))]
-- flattenDefs (Topo defs) = defToVarDefs =<< defs

-- defToVarDefs :: Def -> [(String, (Scheme, Expr))]
-- defToVarDefs = \case
--     VarDef d -> [d]
--     RecDefs ds -> map funDefToVarDef ds

-- funDefToVarDef :: (String, (Scheme, Fun)) -> VarDef
-- funDefToVarDef = second (second Fun)

A src/Front/Concrete.hs => src/Front/Concrete.hs +175 -0
@@ 0,0 1,175 @@
-- {-# LANGUAGE DataKinds #-}

-- | Concrete syntax tree / parse tree
--
--   Not actually 100% concrete, because of macros, but it's close.
module Front.Concrete (module Front.Concrete, Const (..), TPrim (..), Ident (..)) where

-- import qualified Data.Set as Set
import Data.Set (Set)
-- import Control.Arrow ((>>>))
-- import Data.Bifunctor

import Name
import SrcPos

-- import FreeVars
import Front.TypeAst
import Front.Lexd (Const (..))

type LhsName = WithPos Ident

type ClassConstraint t n = (n, [(SrcPos, t)])

data Scheme t n = Forall SrcPos (Set TVar) (Set (ClassConstraint t n)) t
    deriving (Show, Eq)

data Pat n
    = PConstruction SrcPos n [Pat n]
    | PInt SrcPos Int
    | PStr SrcPos String
    | PVar LhsName
    | PBox SrcPos (Pat n)
    | PTuple SrcPos (Tuple (Pat n))
    -- TODO: Add special pattern for Lazy
    deriving (Show)

type FunPats n = WithPos [Pat n]

data Tuple a = Tuple [a] (Maybe a)
    deriving (Show, Eq, Ord)

data Expr' t n
    = Lit Const
    | Var n
    | App (Expr t n) [Expr t n]
    | If (Expr t n) (Expr t n) (Expr t n)
    | Let1 (DefLike t n) (Expr t n)
    | Let [DefLike t n] (Expr t n)
    | LetRec [Def t n] (Expr t n)
    | TypeAscr (Expr t n) t
    | Match (Expr t n) [(Pat n, Expr t n)]
    | Fun (FunPats n) (Expr t n)
    | DeBruijnFun Word (Expr t n)
    | DeBruijnIndex Word
    | FunMatch [(FunPats n, Expr t n)]
    | Ctor n
    | Sizeof t
    | ETuple (Tuple (Expr t n))
    deriving (Show)

type Expr t n = WithPos (Expr' t n)

data Def t n
    = VarDef SrcPos LhsName (Maybe (Scheme t n)) (Expr t n)
    | FunDef SrcPos LhsName (Maybe (Scheme t n)) (FunPats n) (Expr t n)
    | FunMatchDef SrcPos LhsName (Maybe (Scheme t n)) [(FunPats n, Expr t n)]
    deriving (Show)

data DefLike t n = Def (Def t n) | Deconstr (Pat n) (Expr t n)
    deriving (Show)

newtype ConstructorDefs t = ConstructorDefs [(LhsName, [t])]
    deriving (Show)

data TypeDef t
    = TypeDef LhsName [LhsName] (ConstructorDefs t)
    | TypeAlias LhsName [LhsName] t
    deriving (Show)

data Extern t = Extern
    { externName :: LhsName
    , externType :: t
    }
    deriving Show

data Module t n = Module
    { moduleImports :: [n]
    , moduleDefs :: [Def t n]
    , moduleTypes :: [TypeDef t]
    , moduleExterns :: [Extern t]
    }
    deriving Show


-- instance Eq Pat where
--     (==) = curry $ \case
--         (PConstruction _ x ps, PConstruction _ x' ps') -> x == x' && ps == ps'
--         (PVar x, PVar x') -> x == x'
--         _ -> False

-- instance FreeVars Def Ident where
--     freeVars = \case
--         VarDef _ _ _ rhs -> freeVars rhs
--         FunDef _ _ _ pats rhs ->
--             Set.difference (freeVars rhs) (Set.unions (map bvPat (unpos pats)))
--         FunMatchDef _ _ _ cs -> fvCases (map (first unpos) cs)

-- instance FreeVars DefLike Ident where
--     freeVars = \case
--         Def d -> freeVars d
--         Deconstr _ matchee -> freeVars matchee

-- instance FreeVars Expr Ident where
--     freeVars = fvExpr

-- instance HasPos (Id a) where
--     getPos (Id x) = getPos x

-- instance HasPos Pat where
--     getPos = \case
--         PConstruction p _ _ -> p
--         PInt p _ -> p
--         PStr p _ -> p
--         PVar v -> getPos v
--         PBox p _ -> p


-- fvExpr :: Expr -> Set Ident
-- fvExpr = unpos >>> fvExpr'
--   where
--     fvExpr' = \case
--         Lit _ -> Set.empty
--         Var x -> Set.singleton x
--         App f as -> fvApp f as
--         If p c a -> fvIf p c a
--         Let1 b e -> Set.union (freeVars b) (Set.difference (freeVars e) (bvDefLike b))
--         Let bs e -> foldr
--             (\b fvs -> Set.union (freeVars b) (Set.difference fvs (bvDefLike b)))
--             (freeVars e)
--             bs
--         LetRec ds e -> fvLet (unzip (map (\d -> (defLhs d, d)) ds)) e
--         TypeAscr e _t -> freeVars e
--         Match e cs -> fvMatch e cs
--         Fun (WithPos _ pats) e -> Set.difference (freeVars e) (Set.unions (map bvPat pats))
--         DeBruijnFun _ body -> freeVars body
--         DeBruijnIndex _ -> Set.empty
--         FunMatch cs -> fvCases (map (first unpos) cs)
--         Ctor _ -> Set.empty
--         Sizeof _t -> Set.empty
--     bvDefLike = \case
--         Def d -> Set.singleton (defLhs d)
--         Deconstr pat _ -> bvPat pat

-- defLhs :: Def -> Ident
-- defLhs = \case
--     VarDef _ lhs _ _ -> lhs
--     FunDef _ lhs _ _ _ -> lhs
--     FunMatchDef _ lhs _ _ -> lhs

-- fvMatch :: Expr -> [(Pat, Expr)] -> Set Ident
-- fvMatch e cs = Set.union (freeVars e) (fvCases (map (first pure) cs))

-- fvCases :: [([Pat], Expr)] -> Set Ident
-- fvCases = Set.unions . map (\(ps, e) -> Set.difference (freeVars e) (Set.unions (map bvPat ps)))

-- bvPat :: Pat -> Set Ident
-- bvPat = \case
--     PConstruction _ _ ps -> Set.unions (map bvPat ps)
--     PInt _ _ -> Set.empty
--     PStr _ _ -> Set.empty
--     PVar x -> Set.singleton x
--     PBox _ p -> bvPat p

-- idstr :: Id a -> String
-- idstr (Id (WithPos _ x)) = x

M src/Front/Err.hs => src/Front/Err.hs +7 -6
@@ 5,10 5,11 @@ module Front.Err (module Front.Err, TypeErr(..)) where
import Text.Megaparsec (match)

import Misc
import Front.SrcPos
import SrcPos
import Name
import Front.TypeAst
import qualified Front.Parsed as Parsed
import Front.Inferred
import Front.Abstract
--import Front.Inferred
import Pretty
import Front.Lex



@@ 50,7 51,7 @@ printTypeErr = \case
    RedundantCase p -> posd p "Redundant case in pattern match."
    InexhaustivePats p patStr -> posd p $ "Inexhaustive patterns: " ++ patStr ++ " not covered."
    ExternNotMonomorphic name tv -> case tv of
        TVExplicit (Parsed.Id (WithPos p tv')) ->
        TVExplicit (WithPos p (Ident tv')) ->
            posd p
                $ ("Extern " ++ pretty name ++ " is not monomorphic. ")
                ++ ("Type variable " ++ tv' ++ " encountered in type signature")


@@ 61,7 62,7 @@ printTypeErr = \case
            $ ("Type `" ++ x ++ "` ")
            ++ "has infinite size due to recursion without indirection.\n"
            ++ "Insert a pointer at some point to make it representable."
    UndefType p x -> posd p $ "Undefined type `" ++ x ++ "`."
    UndefType p x -> posd p $ "Undefined type `" ++ pretty x ++ "`."
    WrongMainType p s ->
        posd p
            $ "Incorrect type of `main`.\n"


@@ 71,7 72,7 @@ printTypeErr = \case
        posd p ("Non-function variable definition `" ++ x ++ "` is recursive.")
    TypeInstArityMismatch p t expected found ->
        posd p
            $ ("Arity mismatch for instantiation of type `" ++ t)
            $ ("Arity mismatch for instantiation of type `" ++ pretty t)
            ++ ("`.\nExpected " ++ show expected)
            ++ (", found " ++ show found)
    ConflictingVarDef p x -> posd p $ "Conflicting definitions for variable `" ++ x ++ "`."

M src/Front/Infer.hs => src/Front/Infer.hs +710 -700
@@ 1,704 1,714 @@
{-# LANGUAGE TemplateHaskell, DataKinds, RankNTypes #-}

module Front.Infer (inferTopDefs, checkType, checkTConst) where
module Front.Infer () where

import Prelude hiding (span)
import Lens.Micro.Platform (makeLenses, over, view, mapped, to, Lens')
import Control.Applicative hiding (Const(..))
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State.Strict
import Control.Monad.Writer
import Data.Bifunctor
import Data.Functor
import Data.Graph (SCC(..), stronglyConnComp)
import Data.List hiding (span)
import qualified Data.Map as Map
import Data.Map (Map)
import Data.Maybe
import qualified Data.Set as Set
import Data.Set (Set)
import Control.Arrow ((>>>))

import Misc
import Sizeof
import Front.SrcPos
import FreeVars
import Front.Subst
import qualified Front.Parsed as Parsed
import Front.Parsed (Id(..), IdCase(..), idstr, defLhs)
import Front.Err
import Front.Inferred
import Front.TypeAst hiding (TConst)


newtype ExpectedType = Expected Type
data FoundType = Found SrcPos Type

unFound :: FoundType -> Type
unFound (Found _ t) = t

type EqConstraint = (ExpectedType, FoundType)
type Constraints = ([EqConstraint], [(SrcPos, ClassConstraint)])

data Env = Env
    { _envTypeDefs :: TypeDefs
    -- Separarate global (and virtual) defs and local defs, because `generalize` only has to look
    -- at local defs.
    , _envVirtuals :: Map String Scheme
    , _envGlobDefs :: Map String Scheme
    , _envLocalDefs :: Map String Scheme
    -- | Maps a constructor to its variant index in the type definition it constructs, the
    --   signature/left-hand-side of the type definition, the types of its parameters, and the span
    --   (number of constructors) of the datatype
    , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
    , _freshParams :: [String]
    , _envDeBruijn :: [TypedVar]
    }
makeLenses ''Env

type FreshTVs = [String]

type Infer a = WriterT Constraints (ReaderT Env (StateT FreshTVs (Except TypeErr))) a

inferTopDefs :: TypeDefs -> Ctors -> Externs -> [Parsed.Def] -> Except TypeErr Defs
inferTopDefs tdefs ctors externs defs =
    let initEnv = Env { _envTypeDefs = tdefs
                      , _envVirtuals = builtinVirtuals
                      , _envGlobDefs = fmap (Forall Set.empty Set.empty) externs
                      , _envLocalDefs = Map.empty
                      , _envCtors = ctors
                      , _freshParams = freshParams
                      , _envDeBruijn = []
                      }
        freshTvs =
            let ls = "abcdehjkpqrstuvxyz"
                ns = map show [1 :: Word .. 99]
                vs = [ l : n | l <- ls, n <- ns ] ++ [ l : v | l <- ls, v <- vs ]
            in  vs
        freshParams = map (("generated/param" ++) . show) [0 :: Word ..]
    in  evalStateT (runReaderT (fmap fst (runWriterT (inferDefs envGlobDefs defs))) initEnv)
                   freshTvs
  where
    builtinVirtuals :: Map String Scheme
    builtinVirtuals =
        let
            tv a = TVExplicit (Parsed.Id (WithPos (SrcPos "<builtin>" 0 0 Nothing) a))
            tva = tv "a"
            ta = TVar tva
            tvb = tv "b"
            tb = TVar tvb
            arithScm =
                Forall (Set.fromList [tva]) (Set.singleton ("Num", [ta])) (TFun [ta, ta] ta)
            bitwiseScm =
                Forall (Set.fromList [tva]) (Set.singleton ("Bitwise", [ta])) (TFun [ta, ta] ta)
            relScm =
                Forall (Set.fromList [tva]) (Set.singleton ("Ord", [ta])) (TFun [ta, ta] tBool)
        in
            Map.fromList
                [ ("+", arithScm)
                , ("-", arithScm)
                , ("*", arithScm)
                , ("/", arithScm)
                , ("rem", arithScm)
                , ("shift-l", bitwiseScm)
                , ("lshift-r", bitwiseScm)
                , ("ashift-r", bitwiseScm)
                , ("bit-and", bitwiseScm)
                , ("bit-or", bitwiseScm)
                , ("bit-xor", bitwiseScm)
                , ("=", relScm)
                , ("/=", relScm)
                , (">", relScm)
                , (">=", relScm)
                , ("<", relScm)
                , ("<=", relScm)
                , ( "transmute"
                  , Forall (Set.fromList [tva, tvb])
                           (Set.singleton ("SameSize", [ta, tb]))
                           (TFun [ta] tb)
                  )
                , ("deref", Forall (Set.fromList [tva]) Set.empty (TFun [TBox ta] ta))
                , ("store", Forall (Set.fromList [tva]) Set.empty (TFun [ta, TBox ta] (TBox ta)))
                , ( "cast"
                  , Forall (Set.fromList [tva, tvb])
                           (Set.singleton ("Cast", [ta, tb]))
                           (TFun [ta] tb)
                  )
                ]

checkType :: MonadError TypeErr m => (Parsed.TConst -> m Type) -> Parsed.Type -> m Type
checkType checkTConst = go
  where
    go = \case
        Parsed.TVar v -> pure (TVar v)
        Parsed.TPrim p -> pure (TPrim p)
        Parsed.TConst tc -> checkTConst tc
        Parsed.TFun ps r -> liftA2 TFun (mapM go ps) (go r)
        Parsed.TBox t -> fmap TBox (go t)

-- TODO: Include SrcPos in Parsed.Type. The `pos` we're given here likely doesn't quite make sense.
checkType' :: SrcPos -> Parsed.Type -> Infer Type
checkType' pos t = do
    tdefs <- view envTypeDefs
    checkType (checkTConst tdefs pos) t

checkTConst :: MonadError TypeErr m => TypeDefs -> SrcPos -> Parsed.TConst -> m Type
checkTConst tdefs pos (x, args) = case Map.lookup x tdefs of
    Nothing -> throwError (UndefType pos x)
    Just (params, Data _) ->
        let expectedN = length params
            foundN = length args
        in  if expectedN == foundN
                then do
                    args' <- mapM go args
                    pure (TConst (x, args'))
                else throwError (TypeInstArityMismatch pos x expectedN foundN)
    Just (params, Alias _ u) -> subst (Map.fromList (zip params args)) <$> go u
    where go = checkType (checkTConst tdefs pos)

inferDefs :: Lens' Env (Map String Scheme) -> [Parsed.Def] -> Infer Defs
inferDefs envDefs defs = do
    checkNoDuplicateDefs Set.empty defs
    let ordered = orderDefs defs
    foldr
        (\scc inferRest -> do
            def <- inferComponent scc
            Topo rest <- augment envDefs (Map.fromList (defSigs def)) inferRest
            pure (Topo (def : rest))
        )
        (pure (Topo []))
        ordered
  where
    checkNoDuplicateDefs :: Set String -> [Parsed.Def] -> Infer ()
    checkNoDuplicateDefs already = uncons >>> fmap (first defLhs) >>> \case
        Just (Id (WithPos p x), ds) -> if Set.member x already
            then throwError (ConflictingVarDef p x)
            else checkNoDuplicateDefs (Set.insert x already) ds
        Nothing -> pure ()

-- For unification to work properly with mutually recursive functions, we need to create a
-- dependency graph of non-recursive / directly-recursive functions and groups of mutual
-- functions. We do this by creating a directed acyclic graph (DAG) of strongly connected
-- components (SCC), where a node is a definition and an edge is a reference to another
-- definition. For each SCC, we infer types for all the definitions / the single definition before
-- generalizing.
orderDefs :: [Parsed.Def] -> [SCC Parsed.Def]
orderDefs = stronglyConnComp . graph
    where graph = map (\d -> (d, defLhs d, Set.toList (freeVars d)))

inferComponent :: SCC Parsed.Def -> Infer Def
inferComponent = \case
    AcyclicSCC vert -> fmap VarDef (inferNonrecDef vert)
    CyclicSCC verts -> fmap RecDefs (inferRecDefs verts)

inferNonrecDef :: Parsed.Def -> Infer VarDef
inferNonrecDef = \case
    Parsed.FunDef dpos lhs mayscm params body -> do
        t <- fresh
        mayscm' <- checkScheme (idstr lhs) mayscm
        (fun, cs) <- listen $ inferDef t mayscm' dpos (inferFun dpos params body)
        (sub, ccs) <- solve cs
        env <- view envLocalDefs
        scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
        let fun' = substFun sub fun
        pure (idstr lhs, (scm, Fun fun'))
    Parsed.FunMatchDef dpos lhs mayscm cases -> do
        t <- fresh
        mayscm' <- checkScheme (idstr lhs) mayscm
        (fun, cs) <- listen $ inferDef t mayscm' dpos (inferFunMatch dpos cases)
        (sub, ccs) <- solve cs
        env <- view envLocalDefs
        scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
        let fun' = substFun sub fun
        pure (idstr lhs, (scm, Fun fun'))
    Parsed.VarDef dpos lhs mayscm body -> do
        t <- fresh
        mayscm' <- checkScheme (idstr lhs) mayscm
        (body', cs) <- listen $ inferDef t mayscm' dpos (infer body)
        -- TODO: Can't we get rid of this somehow? It makes our solution more complex and expensive
        --       if we have to do nested solves. Also re-solves many constraints in vain.
        --
        --       I think we should switch to bidirectional type checking. This will be fixed then.
        (sub, ccs) <- solve cs
        env <- view envLocalDefs
        scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
        let body'' = substExpr sub body'
        pure (idstr lhs, (scm, body''))

inferRecDefs :: [Parsed.Def] -> Infer RecDefs
inferRecDefs ds = do
    (names, mayscms', ts) <- fmap unzip3 $ forM ds $ \d -> do
        let (name, mayscm) = first idstr $ case d of
                Parsed.FunDef _ x s _ _ -> (x, s)
                Parsed.FunMatchDef _ x s _ -> (x, s)
                Parsed.VarDef _ x s _ -> (x, s)
        t <- fresh
        mayscm' <- checkScheme name mayscm
        pure (name, mayscm', t)
    let dummyDefs = Map.fromList $ zip names (map (Forall Set.empty Set.empty) ts)
    (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ mapM (uncurry3 inferRecDef)
                                                                (zip3 mayscms' ts ds)
    (sub, cs) <- solve ucs
    env <- view envLocalDefs
    scms <- zipWithM
        (\s -> generalize (substEnv sub env) (fmap _scmConstraints s) cs . subst sub)
        mayscms'
        ts
    let fs' = map (substFun sub) fs
    pure (zip names (zip scms fs'))
  where
    inferRecDef :: Maybe Scheme -> Type -> Parsed.Def -> Infer Fun
    inferRecDef mayscm t = \case
        Parsed.FunDef fpos _ _ params body -> inferDef t mayscm fpos $ inferFun fpos params body
        Parsed.FunMatchDef fpos _ _ cases -> inferDef t mayscm fpos $ inferFunMatch fpos cases
        Parsed.VarDef fpos _ _ (WithPos pos (Parsed.Fun params body)) ->
            inferDef t mayscm fpos (inferFun pos params body)
        Parsed.VarDef fpos _ _ (WithPos pos (Parsed.FunMatch cs)) ->
            inferDef t mayscm fpos (inferFunMatch pos cs)
        Parsed.VarDef _ (Id lhs) _ _ -> throwError (RecursiveVarDef lhs)

inferDef :: Type -> Maybe Scheme -> SrcPos -> Infer (Type, body) -> Infer body
inferDef t mayscm bodyPos inferBody = do
    whenJust mayscm $ \(Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
    (t', body') <- inferBody
    unify (Expected t) (Found bodyPos t')
    pure body'

-- | Verify that user-provided type signature schemes are valid
checkScheme :: String -> Maybe Parsed.Scheme -> Infer (Maybe Scheme)
checkScheme = curry $ \case
    ("main", Nothing) -> pure (Just (Forall Set.empty Set.empty mainType))
    ("main", Just s@(Parsed.Forall pos vs cs t))
        | Set.size vs /= 0 || Set.size cs /= 0 || t /= mainType -> throwError (WrongMainType pos s)
    (_, Nothing) -> pure Nothing
    (_, Just (Parsed.Forall pos vs cs t)) -> do
        t' <- checkType' pos t
        cs' <- mapM (secondM (mapM (uncurry checkType'))) (Set.toList cs)
        let s1 = Forall vs (Set.fromList cs') t'
        env <- view envLocalDefs
        s2@(Forall vs2 _ t2) <- generalize env (Just (_scmConstraints s1)) Map.empty t'
        if (vs, t') == (vs2, t2) then pure (Just s1) else throwError (InvalidUserTypeSig pos s1 s2)

infer :: Parsed.Expr -> Infer (Type, Expr)
infer (WithPos pos e) = case e of
    Parsed.Lit l -> pure (litType l, Lit l)
    Parsed.Var (Id (WithPos p "_")) -> throwError (FoundHole p)
    Parsed.Var x -> fmap (second Var) (lookupVar x)
    Parsed.App f as -> do
        tas <- mapM (const fresh) as
        tr <- fresh
        (tf', f') <- infer f
        case tf' of
            TFun tps _ -> unless (length tps == length tas)
                $ throwError (FunArityMismatch pos (length tps) (length tas))
            _ -> pure () -- If it's not k
        (tas', as') <- unzip <$> mapM infer as
        unify (Expected (TFun tas tr)) (Found (getPos f) tf')
        forM_ (zip3 as tas tas') $ \(a, ta, ta') -> unify (Expected ta) (Found (getPos a) ta')
        pure (tr, App f' as' tr)
    Parsed.If p c a -> do
        (tp, p') <- infer p
        (tc, c') <- infer c
        (ta, a') <- infer a
        unify (Expected tBool) (Found (getPos p) tp)
        unify (Expected tc) (Found (getPos a) ta)
        pure (tc, If p' c' a')
    Parsed.Let1 def body -> inferLet1 pos def body
    Parsed.Let defs body ->
        -- FIXME: positions
        let (def, defs') = fromJust $ uncons defs
        in  inferLet1 pos def $ foldr (\d b -> WithPos pos (Parsed.Let1 d b)) body defs'
    Parsed.LetRec defs b -> do
        Topo defs' <- inferDefs envLocalDefs defs
        let withDef def inferX = do
                (tx, x') <- withLocals (defSigs def) inferX
                pure (tx, Let def x')
        foldr withDef (infer b) defs'
    Parsed.TypeAscr x t -> do
        (tx, x') <- infer x
        t' <- checkType' pos t
        unify (Expected t') (Found (getPos x) tx)
        pure (t', x')
    Parsed.Fun param body -> fmap (second Fun) (inferFun pos param body)
    Parsed.DeBruijnFun nparams body -> fmap (second Fun) (inferDeBruijnFun nparams body)
    Parsed.DeBruijnIndex ix -> do
        args <- view envDeBruijn
        if fromIntegral ix < length args
            then let tv@(TypedVar _ t) = args !! fromIntegral ix in pure (t, Var (NonVirt, tv))
            else throwError (DeBruijnIndexOutOfRange pos ix)
    Parsed.FunMatch cases -> fmap (second Fun) (inferFunMatch pos cases)
    Parsed.Match matchee cases -> inferMatch pos matchee cases
    Parsed.Ctor c -> do
        (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
        (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
        let tCtion = TConst tdefInst
        let t = if null cParams' then tCtion else TFun cParams' tCtion
        pure (t, Ctor variantIx cSpan tdefInst cParams')
    Parsed.Sizeof t -> fmap ((TPrim TNatSize, ) . Sizeof) (checkType' pos t)

inferLet1 :: SrcPos -> Parsed.DefLike -> Parsed.Expr -> Infer (Type, Expr)
inferLet1 pos defl body = case defl of
    Parsed.Def def -> do
        def' <- inferNonrecDef def
        (t, body') <- augment1 envLocalDefs (defSig def') (infer body)
        pure (t, Let (VarDef def') body')
    Parsed.Deconstr pat matchee -> inferMatch pos matchee [(pat, body)]

inferMatch :: SrcPos -> Parsed.Expr -> [(Parsed.Pat, Parsed.Expr)] -> Infer (Type, Expr)
inferMatch pos matchee cases = do
    (tmatchee, matchee') <- infer matchee
    (tbody, cases') <- inferCases [tmatchee]
                                  (map (first (\pat -> WithPos (getPos pat) [pat])) cases)
    pure (tbody, Match (WithPos pos ([matchee'], cases', [tmatchee], tbody)))

inferFun :: SrcPos -> Parsed.FunPats -> Parsed.Expr -> Infer (Type, Fun)
inferFun pos pats body = do
    (tpats, tbody, case') <- inferCase pats body
    let tpats' = map unFound tpats
    funMatchToFun pos [case'] tpats' (unFound tbody)

inferDeBruijnFun :: Word -> Parsed.Expr -> Infer (Type, Fun)
inferDeBruijnFun nparams body = genParams nparams $ \paramNames -> do
    tparams <- replicateM (fromIntegral nparams) fresh
    let params = zip paramNames tparams
        paramSigs = map (second (Forall Set.empty Set.empty)) params
        args = map (uncurry TypedVar) params
    (tbody, body') <- locallySet envDeBruijn args $ withLocals paramSigs (infer body)
    pure (TFun tparams tbody, (params, (body', tbody)))

inferFunMatch :: SrcPos -> [(Parsed.FunPats, Parsed.Expr)] -> Infer (Type, Fun)
inferFunMatch pos cases = do
    arity <- checkCasePatternsArity
    tpats <- replicateM arity fresh
    (tbody, cases') <- inferCases tpats cases
    funMatchToFun pos cases' tpats tbody
  where
    checkCasePatternsArity = case cases of
        [] -> ice "inferFunMatch: checkCasePatternsArity: fun* has no cases, arity 0"
        (pats0, _) : rest -> do
            let arity = length (unpos pats0)
            forM_ rest $ \(WithPos pos pats, _) -> unless
                (length pats == arity)
                (throwError (FunCaseArityMismatch pos arity (length pats)))
            pure arity

funMatchToFun :: SrcPos -> Cases -> [Type] -> Type -> Infer (Type, Fun)
funMatchToFun pos cases' tpats tbody = genParams (length tpats) $ \paramNames -> do
    let paramNames' = zipWith fromMaybe paramNames $ case cases' of
            [(WithPos _ ps, _)] -> flip map ps $ \(Pat _ _ p) -> case p of
                PVar (TypedVar x _) -> Just x
                _ -> Nothing
            _ -> repeat Nothing
        params = zip paramNames' tpats
        args = map (Var . (NonVirt, ) . uncurry TypedVar) params
    pure (TFun tpats tbody, (params, (Match (WithPos pos (args, cases', tpats, tbody)), tbody)))

-- | All the patterns must be of the same types, and all the bodies must be of the same type.
inferCases
    :: [Type] -- Type of matchee(s). Expected type(s) of pattern(s).
    -> [(WithPos [Parsed.Pat], Parsed.Expr)]
    -> Infer (Type, Cases)
inferCases tmatchees cases = do
    (tpatss, tbodies, cases') <- fmap unzip3 (mapM (uncurry inferCase) cases)
    forM_ tpatss $ zipWithM (unify . Expected) tmatchees
    tbody <- fresh
    forM_ tbodies (unify (Expected tbody))
    pure (tbody, cases')

inferCase
    :: WithPos [Parsed.Pat] -> Parsed.Expr -> Infer ([FoundType], FoundType, (WithPos [Pat], Expr))
inferCase (WithPos pos ps) b = do
    (tps, ps', pvss) <- fmap unzip3 (mapM inferPat ps)
    let pvs' = map (bimap Parsed.idstr (Forall Set.empty Set.empty . TVar))
                   (Map.toList (Map.unions pvss))
    (tb, b') <- withLocals pvs' (infer b)
    let tps' = zipWith Found (map getPos ps) tps
    pure (tps', Found (getPos b) tb, (WithPos pos ps', b'))

-- | Returns the type of the pattern; the pattern in the Pat format that the Match module wants,
--   and a Map from the variables bound in the pattern to fresh schemes.
inferPat :: Parsed.Pat -> Infer (Type, Pat, Map (Id 'Small) TVar)
inferPat pat = fmap (\(t, p, ss) -> (t, Pat (getPos pat) t p, ss)) (inferPat' pat)
  where
    inferPat' = \case
        Parsed.PConstruction pos c ps -> inferPatConstruction pos c ps
        Parsed.PInt _ n -> pure (TPrim TIntSize, intToPCon n 64, Map.empty)
        Parsed.PStr _ s ->
            let span' = ice "span of Con with VariantStr"
                p = PCon (Con (VariantStr s) span' []) []
            in  pure (tStr, p, Map.empty)
        Parsed.PVar (Id (WithPos _ "_")) -> do
            tv <- fresh
            pure (tv, PWild, Map.empty)
        Parsed.PVar x@(Id (WithPos _ x')) -> do
            tv <- fresh'
            pure (TVar tv, PVar (TypedVar x' (TVar tv)), Map.singleton x tv)
        Parsed.PBox _ p -> do
            (tp', p', vs) <- inferPat p
            pure (TBox tp', PBox p', vs)

    intToPCon n w = PCon
        (Con { variant = VariantIx (fromIntegral n), span = 2 ^ (w :: Integer), argTs = [] })
        []

    inferPatConstruction
        :: SrcPos -> Id 'Big -> [Parsed.Pat] -> Infer (Type, Pat', Map (Id 'Small) TVar)
    inferPatConstruction pos c cArgs = do
        (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
        let arity = length cParams
        let nArgs = length cArgs
        unless (arity == nArgs) (throwError (CtorArityMismatch pos (idstr c) arity nArgs))
        (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
        let t = TConst tdefInst
        (cArgTs, cArgs', cArgsVars) <- fmap unzip3 (mapM inferPat cArgs)
        cArgsVars' <- nonconflictingPatVarDefs cArgsVars
        forM_ (zip3 cParams' cArgTs cArgs)
            $ \(cParamT, cArgT, cArg) -> unify (Expected cParamT) (Found (getPos cArg) cArgT)
        let con = Con { variant = VariantIx variantIx, span = cSpan, argTs = cArgTs }
        pure (t, PCon con cArgs', cArgsVars')

    nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->
        case listToMaybe (Map.keys (Map.intersection acc ks)) of
            Just (Id (WithPos pos v)) -> throwError (ConflictingPatVarDefs pos v)
            Nothing -> pure (Map.union acc ks)

instantiateConstructorOfTypeDef :: (String, [TVar]) -> [Type] -> Infer (TConst, [Type])
instantiateConstructorOfTypeDef (tName, tParams) cParams = do
    tVars <- mapM (const fresh) tParams
    let cParams' = map (subst (Map.fromList (zip tParams tVars))) cParams
    pure ((tName, tVars), cParams')

lookupEnvConstructor :: Id 'Big -> Infer (VariantIx, (String, [TVar]), [Type], Span)
lookupEnvConstructor (Id (WithPos pos cx)) =
    view (envCtors . to (Map.lookup cx)) >>= maybe (throwError (UndefCtor pos cx)) pure

litType :: Const -> Type
litType = \case
    Int _ -> TPrim TIntSize
    F64 _ -> TPrim TF64
    Str _ -> tStr

lookupVar :: Id 'Small -> Infer (Type, Var)
lookupVar (Id (WithPos pos x)) = do
    virt <- fmap (Map.lookup x) (view envVirtuals)
    glob <- fmap (Map.lookup x) (view envGlobDefs)
    local <- fmap (Map.lookup x) (view envLocalDefs)
    case fmap (NonVirt, ) (local <|> glob) <|> fmap (Virt, ) virt of
        Just (virt, scm) -> instantiate pos scm <&> \t -> (t, (virt, TypedVar x t))
        Nothing -> throwError (UndefVar pos x)

genParams :: Integral n => n -> ([String] -> Infer a) -> Infer a
genParams n f = do
    ps <- view (freshParams . to (take (fromIntegral n)))
    locally freshParams (drop (fromIntegral n)) (f ps)

withLocals :: [(String, Scheme)] -> Infer a -> Infer a
withLocals = augment envLocalDefs . Map.fromList

instantiate :: SrcPos -> Scheme -> Infer Type
instantiate pos (Forall params constraints t) = do
    s <- Map.fromList <$> zipWithM (fmap . (,)) (Set.toList params) (repeat fresh)
    forM_ constraints $ \c -> unifyClass pos (substClassConstraint s c)
    pure (subst s t)

generalize
    :: (MonadError TypeErr m)
    => Map String Scheme
    -> Maybe (Set ClassConstraint)
    -> Map ClassConstraint SrcPos
    -> Type
    -> m Scheme
generalize env mayGivenCs allCs t = fmap (\cs -> Forall vs cs t) constraints
  where
    -- A constraint should be included in a signature if the type variables include at least one of
    -- the signature's forall-qualified tvars, and the rest of the tvars exist in the surrounding
    -- environment. If a tvar is not from the signature or the environment, it comes from an inner
    -- definition, and should already have been included in that signature.
    --
    -- TODO: Maybe we should handle the propagation of class constraints in a better way, so that
    --       ones belonging to inner definitions no longer exist at this point.
    constraints = fmap (Set.fromList . map fst) $ flip filterM (Map.toList allCs) $ \(c, pos) ->
        let vcs = ftvClassConstraint c
            belongs =
                any (flip Set.member vs) vcs
                    && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
        in  if belongs
                then if matchesGiven c then pure True else throwError (NoClassInstance pos c)
                else pure False
    matchesGiven = case mayGivenCs of
        Just gcs -> flip Set.member gcs
        Nothing -> const True
    vs = Set.difference (ftv t) ftvEnv
    ftvEnv = Set.unions (map ftvScheme (Map.elems env))
    ftvScheme (Forall tvs _ t) = Set.difference (ftv t) tvs

substEnv :: Subst' -> Map String Scheme -> Map String Scheme
substEnv s = over (mapped . scmBody) (subst s)

ftvClassConstraint :: ClassConstraint -> Set TVar
ftvClassConstraint = mconcat . map ftv . snd

substClassConstraint :: Subst' -> ClassConstraint -> ClassConstraint
substClassConstraint sub = second (map (subst sub))

fresh :: Infer Type
fresh = fmap TVar fresh'

fresh' :: Infer TVar
fresh' = fmap TVImplicit (gets head <* modify tail)

unify :: ExpectedType -> FoundType -> Infer ()
unify e f = tell ([(e, f)], [])

unifyClass :: SrcPos -> ClassConstraint -> Infer ()
unifyClass p c = tell ([], [(p, c)])

data UnifyErr = UInfType TVar Type | UFailed Type Type

-- TODO: I actually don't really like this approach of keeping the unification solver separate from
--       the inferrer. The approach of doing it "inline" is, at least in some ways, more flexible,
--       and probably more performant. Consider this further -- maybe there's a big con I haven't
--       considered or have forgotten. Will updating the substitution map work well? How would it
--       work for nested inferDefs, compared to now?
solve :: Constraints -> Infer (Subst', Map ClassConstraint SrcPos)
solve (eqcs, ccs) = do
    sub <- lift $ lift $ lift $ solveUnis Map.empty eqcs
    ccs' <- solveClassCs (map (second (substClassConstraint sub)) ccs)
    pure (sub, ccs')
  where
    solveUnis :: Subst' -> [EqConstraint] -> Except TypeErr Subst'
    solveUnis sub1 = \case
        [] -> pure sub1
        (Expected et, Found pos ft) : cs -> do
            sub2 <- withExcept (toTypeErr pos et ft) (unifies et ft)
            solveUnis (composeSubsts sub2 sub1) (map (substConstraint sub2) cs)

    solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Map ClassConstraint SrcPos)
    solveClassCs = fmap Map.unions . mapM solveClassConstraint

    solveClassConstraint :: (SrcPos, ClassConstraint) -> Infer (Map ClassConstraint SrcPos)
    solveClassConstraint (pos, c) = case c of
        -- Virtual classes
        ("SameSize", [ta, tb]) -> sameSize (ta, tb)
        ("Cast", [ta, tb]) -> cast (ta, tb)
        ("Num", [ta]) -> case ta of
            TPrim _ -> ok
            TVar _ -> propagate
            TConst _ -> err
            TFun _ _ -> err
            TBox _ -> err
        ("Bitwise", [ta]) -> case ta of
            TPrim p | isIntegral p -> ok
            TPrim _ -> err
            TVar _ -> propagate
            TConst _ -> err
            TFun _ _ -> err
            TBox _ -> err
        ("Ord", [ta]) -> case ta of
            TPrim _ -> ok
            TVar _ -> propagate
            TConst _ -> err
            TFun _ _ -> err
            TBox _ -> err
        -- "Real classes"
        -- ... TODO
        _ -> ice $ "solveClassCs: invalid class constraint " ++ show c
      where
        ok = pure Map.empty
        propagate = pure (Map.singleton c pos)
        err = throwError (NoClassInstance pos c)
        isIntegral = \case
            TInt _ -> True
            TIntSize -> True
            TNat _ -> True
            TNatSize -> True
            _ -> False

        -- TODO: Maybe we should move the check against user-provided explicit signature from
        --       `generalize` to here. Like, we could keep the explicit scheme (if there is one) in
        --       the `Env`.
        --
        -- | As the name indicates, a predicate that is true / class that is instanced when two
        --   types are of the same size. If the size for either cannot be determined yet due to
        --   polymorphism, the constraint is propagated.
        sameSize :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
        sameSize (ta, tb) = do
            sizeof'' <- sizeof . sizeofTypeDef <$> view envTypeDefs
            case liftA2 (==) (sizeof'' ta) (sizeof'' tb) of
                _ | ta == tb -> ok
                Right True -> ok
                Right False -> err
                -- One or both of the two types are of unknown size due to polymorphism, so
                -- propagate the constraint to the scheme of the definition.
                Left _ -> propagate

        sizeofTypeDef tdefs (x, args) = case Map.lookup x tdefs of
            Just (params, Data variants) ->
                let sub = Map.fromList (zip params args)
                    datas = map (map (subst sub) . snd) variants
                in  sizeofData (sizeofTypeDef tdefs) (alignofTypeDef tdefs) datas
            Just (params, Alias _ t) ->
                let sub = Map.fromList (zip params args)
                in  sizeof (sizeofTypeDef tdefs) (subst sub t)
            Nothing -> ice $ "Infer.sizeofTypeDef: undefined type " ++ show x
        alignofTypeDef tdefs (x, args) = case Map.lookup x tdefs of
            Just (params, Data variants) ->
                let sub = Map.fromList (zip params args)
                    datas = map (map (subst sub) . snd) variants
                in  alignmentofData (alignofTypeDef tdefs) datas
            Just (params, Alias _ t) ->
                let sub = Map.fromList (zip params args)
                in  alignmentof (alignofTypeDef tdefs) (subst sub t)
            Nothing -> ice $ "Infer.sizeofTypeDef: undefined type " ++ show x

        -- | This class is instanced when the first type can be `cast` to the other.
        cast :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
        cast = \case
            (ta, tb) | ta == tb -> ok
            (TPrim _, TPrim _) -> ok
            (TVar _, _) -> propagate
            (_, TVar _) -> propagate
            (TConst _, _) -> err
            (_, TConst _) -> err
            (TFun _ _, _) -> err
            (_, TFun _ _) -> err
            (TBox _, _) -> err
            (_, TBox _) -> err

    substConstraint sub (Expected t1, Found pos t2) =
        (Expected (subst sub t1), Found pos (subst sub t2))

    toTypeErr :: SrcPos -> Type -> Type -> UnifyErr -> TypeErr
    toTypeErr pos t1 t2 = \case
        UInfType a t -> InfType pos t1 t2 a t
        UFailed t'1 t'2 -> UnificationFailed pos t1 t2 t'1 t'2

-- FIXME: Keep track of whether we've flipped the arguments. Alternatively, keep right stuff to the
--        right and vice versa. If we don't, we get confusing type errors.
unifies :: Type -> Type -> Except UnifyErr Subst'
unifies = curry $ \case
    (TPrim a, TPrim b) | a == b -> pure Map.empty
    (TConst (c0, ts0), TConst (c1, ts1)) | c0 == c1 -> if length ts0 /= length ts1
        then ice "lengths of TConst params differ in unify"
        else unifiesMany (zip ts0 ts1)
    (TVar a, TVar b) | a == b -> pure Map.empty
    (TVar a, t) | occursIn a t -> throwError (UInfType a t)
    -- Do not allow "override" of explicit (user given) type variables.
    (a@(TVar (TVExplicit _)), b@(TVar (TVImplicit _))) -> unifies b a
    (a@(TVar (TVExplicit _)), b) -> throwError (UFailed a b)
    (TVar a, t) -> pure (Map.singleton a t)
    (t, TVar a) -> unifies (TVar a) t
    (t@(TFun ts1 t2), u@(TFun us1 u2)) -> if length ts1 /= length us1
        then throwError (UFailed t u)
        else unifiesMany (zip (ts1 ++ [t2]) (us1 ++ [u2]))
    (TBox t, TBox u) -> unifies t u
    (t1, t2) -> throwError (UFailed t1 t2)
  where
    unifiesMany :: [(Type, Type)] -> Except UnifyErr Subst'
    unifiesMany = foldM
        (\s (t, u) -> fmap (flip composeSubsts s) (unifies (subst s t) (subst s u)))
        Map.empty

    occursIn :: TVar -> Type -> Bool
    occursIn a t = Set.member a (ftv t)
-- import Lens.Micro.Platform (makeLenses, over, view, mapped, to, Lens')
-- import Control.Applicative hiding (Const(..))
-- import Control.Monad.Except
-- import Control.Monad.Reader
-- import Control.Monad.State.Strict
-- import Control.Monad.Writer
-- import Data.Bifunctor
-- import Data.Functor
-- import Data.Graph (SCC(..), stronglyConnComp)
-- import Data.List hiding (span)
-- import qualified Data.Map as Map
-- import Data.Map (Map)
-- import Data.Maybe
-- import qualified Data.Set as Set
-- import Data.Set (Set)
-- import Control.Arrow ((>>>))

-- import Misc
-- import Sizeof
-- import SrcPos
-- import Name
-- import FreeVars
-- import Front.Subst
-- import Front.Err
-- import Front.Resolved
-- import qualified Front.Concrete as Cst
-- import Front.Abstract
-- import Front.Inferred
-- --import qualified Front.Inferred as Inferred
-- import Front.TypeAst


-- newtype ExpectedType = Expected Type
-- data FoundType = Found SrcPos Type

-- unFound :: FoundType -> Type
-- unFound (Found _ t) = t

-- type EqConstraint = (ExpectedType, FoundType)
-- type Constraints = ([EqConstraint], [(SrcPos, ClassConstraint)])

-- data Env = Env
--     { _envTypeDefs :: TypeDefs
--     -- Separarate global (and virtual) defs and local defs, because `generalize` only has to look
--     -- at local defs.
--     , _envVirtuals :: Map String Scheme
--     , _envGlobDefs :: Map QualName Scheme
--     , _envLocalDefs :: Map LhsName Scheme
--     -- | Maps a constructor to its variant index in the type definition it constructs, the
--     --   signature/left-hand-side of the type definition, the types of its parameters, and the span
--     --   (number of constructors) of the datatype
--     , _envCtors :: Map String (VariantIx, (String, [TVar]), [Type], Span)
--     , _freshParams :: [String]
--     , _envDeBruijn :: [(QualName, Type)]
--     }
-- makeLenses ''Env

-- type FreshTVs = [String]

-- type Infer a = WriterT Constraints (ReaderT Env (StateT FreshTVs (Except TypeErr))) a

-- -- TODO: With on-demand parsing, we shouldn't be given all these things I think. Instead, we should
-- --       take only a file name, and get all externs, globals, etc as we query them.
-- inferTopDefs :: TypeDefs -> Ctors -> Externs -> [ResolvedDef] -> Except TypeErr InferredDefs
-- inferTopDefs tdefs ctors externs defs =
--     let initEnv = Env { _envTypeDefs = tdefs
--                       , _envVirtuals = builtinVirtuals
--                       , _envGlobDefs = fmap (Forall Set.empty Set.empty) externs
--                       , _envLocalDefs = Map.empty
--                       , _envCtors = ctors
--                       , _freshParams = freshParams
--                       , _envDeBruijn = []
--                       }
--         freshTvs =
--             let ls = "abcdehjkpqrstuvxyz"
--                 ns = map show [1 :: Word .. 99]
--                 vs = [ l : n | l <- ls, n <- ns ] ++ [ l : v | l <- ls, v <- vs ]
--             in  vs
--         freshParams = map (("generated/param" ++) . show) [0 :: Word ..]
--     in  evalStateT (runReaderT (fmap fst (runWriterT (inferDefs envGlobDefs defs))) initEnv)
--                    freshTvs
--   where
--     builtinVirtuals :: Map String Scheme
--     builtinVirtuals =
--         let
--             tv0 = TVImplicit 0
--             tv1 = TVImplicit 1
--             t0 = TVar tv0
--             t1 = TVar tv1
--             arithScm =
--                 Forall (Set.fromList [tv0]) (Set.singleton ("Num", [t0])) (TFun [t0, t0] t0)
--             bitwiseScm =
--                 Forall (Set.fromList [tv0]) (Set.singleton ("Bitwise", [t0])) (TFun [t0, t0] t0)
--             relScm =
--                 Forall (Set.fromList [tv0]) (Set.singleton ("Ord", [t0])) (TFun [t0, t0] tBool)
--         in
--             Map.fromList
--                 [ ("+", arithScm)
--                 , ("-", arithScm)
--                 , ("*", arithScm)
--                 , ("div", arithScm)
--                 , ("rem", arithScm)
--                 , ("shift-l", bitwiseScm)
--                 , ("lshift-r", bitwiseScm)
--                 , ("ashift-r", bitwiseScm)
--                 , ("bit-and", bitwiseScm)
--                 , ("bit-or", bitwiseScm)
--                 , ("bit-xor", bitwiseScm)
--                 , ("=", relScm)
--                 , ("not=", relScm)
--                 , (">", relScm)
--                 , (">=", relScm)
--                 , ("<", relScm)
--                 , ("<=", relScm)
--                 , ( "transmute"
--                   , Forall (Set.fromList [tv0, tv1])
--                            (Set.singleton ("SameSize", [t0, t1]))
--                            (TFun [t0] t1)
--                   )
--                 , ("deref", Forall (Set.fromList [tv0]) Set.empty (TFun [TBox t0] t0))
--                 , ("store", Forall (Set.fromList [tv0]) Set.empty (TFun [t0, TBox t0] (TBox t0)))
--                 , ( "cast"
--                   , Forall (Set.fromList [tv0, tv1])
--                            (Set.singleton ("Cast", [t0, t1]))
--                            (TFun [t0] t1)
--                   )
--                 ]

-- checkType :: MonadError TypeErr m => (ResolvedTConst -> m Type) -> ResolvedType -> m Type
-- checkType checkTConst = go
--   where
--     go = \case
--         TVar v -> pure (TVar v)
--         TPrim p -> pure (TPrim p)
--         TConst tc -> checkTConst tc
--         TFun ps r -> liftA2 TFun (mapM go ps) (go r)
--         TBox t -> fmap TBox (go t)

-- -- TODO: Include SrcPos in Cst.Type. The `pos` we're given here likely doesn't quite make sense.
-- checkType' :: SrcPos -> ResolvedType -> Infer Type
-- checkType' pos t = do
--     tdefs <- view envTypeDefs
--     checkType (checkTConst tdefs pos) t

-- checkTConst :: MonadError TypeErr m => TypeDefs -> SrcPos -> ResolvedTConst -> m Type
-- checkTConst tdefs pos (x, args) = case Map.lookup x tdefs of
--     Nothing -> throwError (UndefType pos x)
--     Just (params, Data _) ->
--         let expectedN = length params
--             foundN = length args
--         in  if expectedN == foundN
--                 then do
--                     args' <- mapM go args
--                     pure (TConst (x, args'))
--                 else throwError (TypeInstArityMismatch pos x expectedN foundN)
--     Just (params, Alias _ u) -> subst (Map.fromList (zip params args)) <$> go u
--     where go = checkType (checkTConst tdefs pos)

-- inferDefs :: Lens' Env (Map LhsName Scheme) -> [ResolvedDef] -> Infer InferredDefs
-- inferDefs envDefs defs = do
--     checkNoDuplicateDefs Set.empty defs
--     let ordered = orderDefs defs
--     foldr
--         (\scc inferRest -> do
--             def <- inferComponent scc
--             Topo rest <- augment envDefs (Map.fromList (defSigs def)) inferRest
--             pure (Topo (def : rest))
--         )
--         (pure (Topo []))
--         ordered
--   where
--     checkNoDuplicateDefs :: Set String -> [ResolvedDef] -> Infer ()
--     checkNoDuplicateDefs already = uncons >>> fmap (first defLhs) >>> \case
--         Just (WithPos p (Ident x), ds) -> if Set.member x already
--             then throwError (ConflictingVarDef p x)
--             else checkNoDuplicateDefs (Set.insert x already) ds
--         Nothing -> pure ()

-- -- For unification to work properly with mutually recursive functions, we need to create a
-- -- dependency graph of non-recursive / directly-recursive functions and groups of mutual
-- -- functions. We do this by creating a directed acyclic graph (DAG) of strongly connected
-- -- components (SCC), where a node is a definition and an edge is a reference to another
-- -- definition. For each SCC, we infer types for all the definitions / the single definition before
-- -- generalizing.
-- orderDefs :: [ResolvedDef] -> [SCC ResolvedDef]
-- orderDefs = stronglyConnComp . graph
--     where graph = map (\d -> (d, defLhs d, Set.toList (freeVars d)))

-- inferComponent :: SCC ResolvedDef -> Infer InferredDef
-- inferComponent = \case
--     AcyclicSCC vert -> fmap VarDef (inferNonrecDef vert)
--     CyclicSCC verts -> fmap RecDefs (inferRecDefs verts)

-- inferNonrecDef :: ResolvedDef -> Infer InferredVarDef
-- inferNonrecDef = \case
--     Cst.FunDef dpos lhs mayscm params body -> do
--         t <- fresh
--         mayscm' <- checkScheme (idstr lhs) mayscm
--         (fun, cs) <- listen $ inferDef t mayscm' dpos (inferFun dpos params body)
--         (sub, ccs) <- solve cs
--         env <- view envLocalDefs
--         scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
--         let fun' = substFun sub fun
--         pure (idstr lhs, (scm, Fun fun'))
--     Cst.FunMatchDef dpos lhs mayscm cases -> do
--         t <- fresh
--         mayscm' <- checkScheme (idstr lhs) mayscm
--         (fun, cs) <- listen $ inferDef t mayscm' dpos (inferFunMatch dpos cases)
--         (sub, ccs) <- solve cs
--         env <- view envLocalDefs
--         scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
--         let fun' = substFun sub fun
--         pure (idstr lhs, (scm, Fun fun'))
--     Cst.VarDef dpos lhs mayscm body -> do
--         t <- fresh
--         mayscm' <- checkScheme (idstr lhs) mayscm
--         (body', cs) <- listen $ inferDef t mayscm' dpos (infer body)
--         -- TODO: Can't we get rid of this somehow? It makes our solution more complex and expensive
--         --       if we have to do nested solves. Also re-solves many constraints in vain.
--         --
--         --       I think we should switch to bidirectional type checking. This will be fixed then.
--         (sub, ccs) <- solve cs
--         env <- view envLocalDefs
--         scm <- generalize (substEnv sub env) (fmap _scmConstraints mayscm') ccs (subst sub t)
--         let body'' = substExpr sub body'
--         pure (idstr lhs, (scm, body''))

-- inferRecDefs :: [ResolvedDef] -> Infer InferredRecDefs
-- inferRecDefs ds = do
--     (names, mayscms', ts) <- fmap unzip3 $ forM ds $ \d -> do
--         let (name, mayscm) = first idstr $ case d of
--                 Cst.FunDef _ x s _ _ -> (x, s)
--                 Cst.FunMatchDef _ x s _ -> (x, s)
--                 Cst.VarDef _ x s _ -> (x, s)
--         t <- fresh
--         mayscm' <- checkScheme name mayscm
--         pure (name, mayscm', t)
--     let dummyDefs = Map.fromList $ zip names (map (Forall Set.empty Set.empty) ts)
--     (fs, ucs) <- listen $ augment envLocalDefs dummyDefs $ mapM (uncurry3 inferRecDef)
--                                                                 (zip3 mayscms' ts ds)
--     (sub, cs) <- solve ucs
--     env <- view envLocalDefs
--     scms <- zipWithM
--         (\s -> generalize (substEnv sub env) (fmap _scmConstraints s) cs . subst sub)
--         mayscms'
--         ts
--     let fs' = map (substFun sub) fs
--     pure (zip names (zip scms fs'))
--   where
--     inferRecDef :: Maybe Scheme -> Type -> ResolvedDef -> Infer InferredFun
--     inferRecDef mayscm t = \case
--         Cst.FunDef fpos _ _ params body -> inferDef t mayscm fpos $ inferFun fpos params body
--         Cst.FunMatchDef fpos _ _ cases -> inferDef t mayscm fpos $ inferFunMatch fpos cases
--         Cst.VarDef fpos _ _ (WithPos pos (Cst.Fun params body)) ->
--             inferDef t mayscm fpos (inferFun pos params body)
--         Cst.VarDef fpos _ _ (WithPos pos (Cst.FunMatch cs)) ->
--             inferDef t mayscm fpos (inferFunMatch pos cs)
--         Cst.VarDef _ (WithPos _ (Ident lhs)) _ _ -> throwError (RecursiveVarDef lhs)

-- inferDef :: Type -> Maybe Scheme -> SrcPos -> Infer (Type, body) -> Infer body
-- inferDef t mayscm bodyPos inferBody = do
--     whenJust mayscm $ \(Forall _ _ scmt) -> unify (Expected scmt) (Found bodyPos t)
--     (t', body') <- inferBody
--     unify (Expected t) (Found bodyPos t')
--     pure body'

-- -- | Verify that user-provided type signature schemes are valid
-- checkScheme :: String -> Maybe ResolvedScheme -> Infer (Maybe Scheme)
-- checkScheme = curry $ \case
--     ("main", Nothing) -> pure (Just (Forall Set.empty Set.empty mainType))
--     ("main", Just s@(Cst.Forall pos vs cs t))
--         | Set.size vs /= 0 || Set.size cs /= 0 || t /= mainType -> throwError (WrongMainType pos s)
--     (_, Nothing) -> pure Nothing
--     (_, Just (Cst.Forall pos vs cs t)) -> do
--         t' <- checkType' pos t
--         cs' <- mapM (secondM (mapM (uncurry checkType'))) (Set.toList cs)
--         let s1 = Forall vs (Set.fromList cs') t'
--         env <- view envLocalDefs
--         s2@(Forall vs2 _ t2) <- generalize env (Just (_scmConstraints s1)) Map.empty t'
--         if (vs, t') == (vs2, t2) then pure (Just s1) else throwError (InvalidUserTypeSig pos s1 s2)

-- infer :: ResolvedExpr -> Infer (Type, InferredExpr)
-- infer (WithPos pos e) = case e of
--     Cst.Lit l -> pure (litType l, Lit l)
--     Cst.Var (WithPos p (Ident "_")) -> throwError (FoundHole p)
--     Cst.Var x -> fmap (second Var) (lookupVar x)
--     Cst.App f as -> do
--         tas <- mapM (const fresh) as
--         tr <- fresh
--         (tf', f') <- infer f
--         case tf' of
--             TFun tps _ -> unless (length tps == length tas)
--                 $ throwError (FunArityMismatch pos (length tps) (length tas))
--             _ -> pure () -- If it's not k
--         (tas', as') <- unzip <$> mapM infer as
--         unify (Expected (TFun tas tr)) (Found (getPos f) tf')
--         forM_ (zip3 as tas tas') $ \(a, ta, ta') -> unify (Expected ta) (Found (getPos a) ta')
--         pure (tr, App f' as' tr)
--     Cst.If p c a -> do
--         (tp, p') <- infer p
--         (tc, c') <- infer c
--         (ta, a') <- infer a
--         unify (Expected tBool) (Found (getPos p) tp)
--         unify (Expected tc) (Found (getPos a) ta)
--         pure (tc, If p' c' a')
--     Cst.Let1 def body -> inferLet1 pos def body
--     Cst.Let defs body ->
--         -- FIXME: positions
--         let (def, defs') = fromJust $ uncons defs
--         in  inferLet1 pos def $ foldr (\d b -> WithPos pos (Cst.Let1 d b)) body defs'
--     Cst.LetRec defs b -> do
--         Topo defs' <- inferDefs envLocalDefs defs
--         let withDef def inferX = do
--                 (tx, x') <- withLocals (defSigs def) inferX
--                 pure (tx, Let def x')
--         foldr withDef (infer b) defs'
--     Cst.TypeAscr x t -> do
--         (tx, x') <- infer x
--         t' <- checkType' pos t
--         unify (Expected t') (Found (getPos x) tx)
--         pure (t', x')
--     Cst.Fun param body -> fmap (second Fun) (inferFun pos param body)
--     Cst.DeBruijnFun nparams body -> fmap (second Fun) (inferDeBruijnFun nparams body)
--     Cst.DeBruijnIndex ix -> do
--         args <- view envDeBruijn
--         if fromIntegral ix < length args
--             then let (x, t) = args !! fromIntegral ix in pure (t, Var x t)
--             else throwError (DeBruijnIndexOutOfRange pos ix)
--     Cst.FunMatch cases -> fmap (second Fun) (inferFunMatch pos cases)
--     Cst.Match matchee cases -> inferMatch pos matchee cases
--     Cst.Ctor c -> do
--         (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
--         (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
--         let tCtion = TConst tdefInst
--         let t = if null cParams' then tCtion else TFun cParams' tCtion
--         pure (t, Ctor variantIx cSpan tdefInst cParams')
--     Cst.Sizeof t -> fmap ((TPrim TNatSize, ) . Sizeof) (checkType' pos t)

-- inferLet1
--     :: SrcPos
--     -> Cst.DefLike ResolvedType ResolvedName
--     -> ResolvedExpr
--     -> Infer (Type, InferredExpr)
-- inferLet1 pos defl body = case defl of
--     Cst.Def def -> do
--         def' <- inferNonrecDef def
--         (t, body') <- augment1 envLocalDefs (defSig def') (infer body)
--         pure (t, Let (VarDef def') body')
--     Cst.Deconstr pat matchee -> inferMatch pos matchee [(pat, body)]

-- inferMatch
--     :: SrcPos -> ResolvedExpr -> [(ResolvedPat, ResolvedExpr)] -> Infer (Type, InferredExpr)
-- inferMatch pos matchee cases = do
--     (tmatchee, matchee') <- infer matchee
--     (tbody, cases') <- inferCases [tmatchee]
--                                   (map (first (\pat -> WithPos (getPos pat) [pat])) cases)
--     pure (tbody, Match (WithPos pos ([matchee'], cases', [tmatchee], tbody)))

-- inferFun :: SrcPos -> Cst.FunPats ResolvedName -> ResolvedExpr -> Infer (Type, InferredFun)
-- inferFun pos pats body = do
--     (tpats, tbody, case') <- inferCase pats body
--     let tpats' = map unFound tpats
--     funMatchToFun pos [case'] tpats' (unFound tbody)

-- inferDeBruijnFun :: Word -> ResolvedExpr -> Infer (Type, InferredFun)
-- inferDeBruijnFun nparams body = genParams nparams $ \paramNames -> do
--     tparams <- replicateM (fromIntegral nparams) fresh
--     let params = zip paramNames tparams
--         paramSigs = map (second (Forall Set.empty Set.empty)) params
--         args = map (uncurry TypedVar) params
--     (tbody, body') <- locallySet envDeBruijn args $ withLocals paramSigs (infer body)
--     pure (TFun tparams tbody, (params, (body', tbody)))

-- inferFunMatch :: SrcPos -> [(Cst.FunPats ResolvedName, ResolvedExpr)] -> Infer (Type, InferredFun)
-- inferFunMatch pos cases = do
--     arity <- checkCasePatternsArity
--     tpats <- replicateM arity fresh
--     (tbody, cases') <- inferCases tpats cases
--     funMatchToFun pos cases' tpats tbody
--   where
--     checkCasePatternsArity = case cases of
--         [] -> ice "inferFunMatch: checkCasePatternsArity: fun* has no cases, arity 0"
--         (pats0, _) : rest -> do
--             let arity = length (unpos pats0)
--             forM_ rest $ \(WithPos pos pats, _) -> unless
--                 (length pats == arity)
--                 (throwError (FunCaseArityMismatch pos arity (length pats)))
--             pure arity

-- funMatchToFun :: SrcPos -> Cases -> [Type] -> Type -> Infer (Type, InferredFun)
-- funMatchToFun pos cases' tpats tbody = genParams (length tpats) $ \paramNames -> do
--     let paramNames' = zipWith fromMaybe paramNames $ case cases' of
--             [(WithPos _ ps, _)] -> flip map ps $ \(Pat _ _ p) -> case p of
--                 PVar x _ -> Just x
--                 _ -> Nothing
--             _ -> repeat Nothing
--         params = zip paramNames' tpats
--         args = map (Var . (NonVirt, ) . uncurry TypedVar) params
--     pure (TFun tpats tbody, (params, (Match (WithPos pos (args, cases', tpats, tbody)), tbody)))

-- -- | All the patterns must be of the same types, and all the bodies must be of the same type.
-- inferCases
--     :: [Type] -- Type of matchee(s). Expected type(s) of pattern(s).
--     -> [(WithPos [ResolvedPat], ResolvedExpr)]
--     -> Infer (Type, Cases)
-- inferCases tmatchees cases = do
--     (tpatss, tbodies, cases') <- fmap unzip3 (mapM (uncurry inferCase) cases)
--     forM_ tpatss $ zipWithM (unify . Expected) tmatchees
--     tbody <- fresh
--     forM_ tbodies (unify (Expected tbody))
--     pure (tbody, cases')

-- inferCase
--     :: WithPos [ResolvedPat]
--     -> ResolvedExpr
--     -> Infer ([FoundType], FoundType, (WithPos [Pat], InferredExpr))
-- inferCase (WithPos pos ps) b = do
--     (tps, ps', pvss) <- fmap unzip3 (mapM inferPat ps)
--     let pvs' = map (bimap Cst.idstr (Forall Set.empty Set.empty . TVar))
--                    (Map.toList (Map.unions pvss))
--     (tb, b') <- withLocals pvs' (infer b)
--     let tps' = zipWith Found (map getPos ps) tps
--     pure (tps', Found (getPos b) tb, (WithPos pos ps', b'))

-- -- | Returns the type of the pattern; the pattern in the Pat format that the Match module wants,
-- --   and a Map from the variables bound in the pattern to fresh schemes.
-- inferPat :: ResolvedPat -> Infer (Type, Pat, Map Ident TVar)
-- inferPat pat = fmap (\(t, p, ss) -> (t, Pat (getPos pat) t p, ss)) (inferPat' pat)
--   where
--     inferPat' = \case
--         Cst.PConstruction pos c ps -> inferPatConstruction pos c ps
--         Cst.PInt _ n -> pure (TPrim TIntSize, intToPCon n 64, Map.empty)
--         Cst.PStr _ s ->
--             let span' = ice "span of Con with VariantStr"
--                 p = PCon (Con (VariantStr s) span' []) []
--             in  pure (tStr, p, Map.empty)
--         Cst.PVar (WithPos _ "_") -> do
--             tv <- fresh
--             pure (tv, PWild, Map.empty)
--         Cst.PVar x@(WithPos _ x') -> do
--             tv <- fresh'
--             pure (TVar tv, PVar (TypedVar x' (TVar tv)), Map.singleton x tv)
--         Cst.PBox _ p -> do
--             (tp', p', vs) <- inferPat p
--             pure (TBox tp', PBox p', vs)

--     intToPCon n w = PCon
--         (Con { variant = VariantIx (fromIntegral n), span = 2 ^ (w :: Integer), argTs = [] })
--         []

--     inferPatConstruction :: SrcPos -> Ident -> [ResolvedPat] -> Infer (Type, Pat', Map Ident TVar)
--     inferPatConstruction pos c cArgs = do
--         (variantIx, tdefLhs, cParams, cSpan) <- lookupEnvConstructor c
--         let arity = length cParams
--         let nArgs = length cArgs
--         unless (arity == nArgs) (throwError (CtorArityMismatch pos (idstr c) arity nArgs))
--         (tdefInst, cParams') <- instantiateConstructorOfTypeDef tdefLhs cParams
--         let t = TConst tdefInst
--         (cArgTs, cArgs', cArgsVars) <- fmap unzip3 (mapM inferPat cArgs)
--         cArgsVars' <- nonconflictingPatVarDefs cArgsVars
--         forM_ (zip3 cParams' cArgTs cArgs)
--             $ \(cParamT, cArgT, cArg) -> unify (Expected cParamT) (Found (getPos cArg) cArgT)
--         let con = Con { variant = VariantIx variantIx, span = cSpan, argTs = cArgTs }
--         pure (t, PCon con cArgs', cArgsVars')

--     nonconflictingPatVarDefs = flip foldM Map.empty $ \acc ks ->
--         case listToMaybe (Map.keys (Map.intersection acc ks)) of
--             Just (WithPos pos v) -> throwError (ConflictingPatVarDefs pos v)
--             Nothing -> pure (Map.union acc ks)

-- instantiateConstructorOfTypeDef :: (String, [TVar]) -> [Type] -> Infer (TConst, [Type])
-- instantiateConstructorOfTypeDef (tName, tParams) cParams = do
--     tVars <- mapM (const fresh) tParams
--     let cParams' = map (subst (Map.fromList (zip tParams tVars))) cParams
--     pure ((tName, tVars), cParams')

-- lookupEnvConstructor :: Ident -> Infer (VariantIx, (String, [TVar]), [Type], Span)
-- lookupEnvConstructor (WithPos pos cx) =
--     view (envCtors . to (Map.lookup cx)) >>= maybe (throwError (UndefCtor pos cx)) pure

-- litType :: Cst.Const -> Type
-- litType = \case
--     Cst.Int _ -> TPrim TIntSize
--     Cst.F64 _ -> TPrim TF64
--     Cst.Str _ -> tStr

-- lookupVar :: Ident -> Infer (QualName, Type)
-- lookupVar (WithPos pos x) = do
--     virt <- fmap (Map.lookup x) (view envVirtuals)
--     glob <- fmap (Map.lookup x) (view envGlobDefs)
--     local <- fmap (Map.lookup x) (view envLocalDefs)
--     case fmap (NonVirt, ) (local <|> glob) <|> fmap (Virt, ) virt of
--         Just (virt, scm) -> instantiate pos scm <&> \t -> (t, (virt, TypedVar x t))
--         Nothing -> throwError (UndefVar pos x)

-- genParams :: Integral n => n -> ([String] -> Infer a) -> Infer a
-- genParams n f = do
--     ps <- view (freshParams . to (take (fromIntegral n)))
--     locally freshParams (drop (fromIntegral n)) (f ps)

-- withLocals :: [(String, Scheme)] -> Infer a -> Infer a
-- withLocals = augment envLocalDefs . Map.fromList

-- instantiate :: SrcPos -> Scheme -> Infer Type
-- instantiate pos (Forall params constraints t) = do
--     s <- Map.fromList <$> zipWithM (fmap . (,)) (Set.toList params) (repeat fresh)
--     forM_ constraints $ \c -> unifyClass pos (substClassConstraint s c)
--     pure (subst s t)

-- generalize
--     :: (MonadError TypeErr m)
--     => Map String Scheme
--     -> Maybe (Set ClassConstraint)
--     -> Map ClassConstraint SrcPos
--     -> Type
--     -> m Scheme
-- generalize env mayGivenCs allCs t = fmap (\cs -> Forall vs cs t) constraints
--   where
--     -- A constraint should be included in a signature if the type variables include at least one of
--     -- the signature's forall-qualified tvars, and the rest of the tvars exist in the surrounding
--     -- environment. If a tvar is not from the signature or the environment, it comes from an inner
--     -- definition, and should already have been included in that signature.
--     --
--     -- TODO: Maybe we should handle the propagation of class constraints in a better way, so that
--     --       ones belonging to inner definitions no longer exist at this point.
--     constraints = fmap (Set.fromList . map fst) $ flip filterM (Map.toList allCs) $ \(c, pos) ->
--         let vcs = ftvClassConstraint c
--             belongs =
--                 any (flip Set.member vs) vcs
--                     && all (\vc -> Set.member vc vs || Set.member vc ftvEnv) vcs
--         in  if belongs
--                 then if matchesGiven c then pure True else throwError (NoClassInstance pos c)
--                 else pure False
--     matchesGiven = case mayGivenCs of
--         Just gcs -> flip Set.member gcs
--         Nothing -> const True
--     vs = Set.difference (ftv t) ftvEnv
--     ftvEnv = Set.unions (map ftvScheme (Map.elems env))
--     ftvScheme (Forall tvs _ t) = Set.difference (ftv t) tvs

-- substEnv :: Subst' -> Map String Scheme -> Map String Scheme
-- substEnv s = over (mapped . scmBody) (subst s)

-- ftvClassConstraint :: ClassConstraint -> Set TVar
-- ftvClassConstraint = mconcat . map ftv . snd

-- substClassConstraint :: Subst' -> ClassConstraint -> ClassConstraint
-- substClassConstraint sub = second (map (subst sub))

-- fresh :: Infer Type
-- fresh = fmap TVar fresh'

-- fresh' :: Infer TVar
-- fresh' = fmap TVImplicit (gets head <* modify tail)

-- unify :: ExpectedType -> FoundType -> Infer ()
-- unify e f = tell ([(e, f)], [])

-- unifyClass :: SrcPos -> ClassConstraint -> Infer ()
-- unifyClass p c = tell ([], [(p, c)])

-- data UnifyErr = UInfType TVar Type | UFailed Type Type

-- -- TODO: I actually don't really like this approach of keeping the unification solver separate from
-- --       the inferrer. The approach of doing it "inline" is, at least in some ways, more flexible,
-- --       and probably more performant. Consider this further -- maybe there's a big con I haven't
-- --       considered or have forgotten. Will updating the substitution map work well? How would it
-- --       work for nested inferDefs, compared to now?
-- solve :: Constraints -> Infer (Subst', Map ClassConstraint SrcPos)
-- solve (eqcs, ccs) = do
--     sub <- lift $ lift $ lift $ solveUnis Map.empty eqcs
--     ccs' <- solveClassCs (map (second (substClassConstraint sub)) ccs)
--     pure (sub, ccs')
--   where
--     solveUnis :: Subst' -> [EqConstraint] -> Except TypeErr Subst'
--     solveUnis sub1 = \case
--         [] -> pure sub1
--         (Expected et, Found pos ft) : cs -> do
--             sub2 <- withExcept (toTypeErr pos et ft) (unifies et ft)
--             solveUnis (composeSubsts sub2 sub1) (map (substConstraint sub2) cs)

--     solveClassCs :: [(SrcPos, ClassConstraint)] -> Infer (Map ClassConstraint SrcPos)
--     solveClassCs = fmap Map.unions . mapM solveClassConstraint

--     solveClassConstraint :: (SrcPos, ClassConstraint) -> Infer (Map ClassConstraint SrcPos)
--     solveClassConstraint (pos, c) = case c of
--         -- Virtual classes
--         ("SameSize", [ta, tb]) -> sameSize (ta, tb)
--         ("Cast", [ta, tb]) -> cast (ta, tb)
--         ("Num", [ta]) -> case ta of
--             TPrim _ -> ok
--             TVar _ -> propagate
--             TConst _ -> err
--             TFun _ _ -> err
--             TBox _ -> err
--         ("Bitwise", [ta]) -> case ta of
--             TPrim p | isIntegral p -> ok
--             TPrim _ -> err
--             TVar _ -> propagate
--             TConst _ -> err
--             TFun _ _ -> err
--             TBox _ -> err
--         ("Ord", [ta]) -> case ta of
--             TPrim _ -> ok
--             TVar _ -> propagate
--             TConst _ -> err
--             TFun _ _ -> err
--             TBox _ -> err
--         -- "Real classes"
--         -- ... TODO
--         _ -> ice $ "solveClassCs: invalid class constraint " ++ show c
--       where
--         ok = pure Map.empty
--         propagate = pure (Map.singleton c pos)
--         err = throwError (NoClassInstance pos c)
--         isIntegral = \case
--             TInt _ -> True
--             TIntSize -> True
--             TNat _ -> True
--             TNatSize -> True
--             _ -> False

--         -- TODO: Maybe we should move the check against user-provided explicit signature from
--         --       `generalize` to here. Like, we could keep the explicit scheme (if there is one) in
--         --       the `Env`.
--         --
--         -- | As the name indicates, a predicate that is true / class that is instanced when two
--         --   types are of the same size. If the size for either cannot be determined yet due to
--         --   polymorphism, the constraint is propagated.
--         sameSize :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
--         sameSize (ta, tb) = do
--             sizeof'' <- sizeof . sizeofTypeDef <$> view envTypeDefs
--             case liftA2 (==) (sizeof'' ta) (sizeof'' tb) of
--                 _ | ta == tb -> ok
--                 Right True -> ok
--                 Right False -> err
--                 -- One or both of the two types are of unknown size due to polymorphism, so
--                 -- propagate the constraint to the scheme of the definition.
--                 Left _ -> propagate

--         sizeofTypeDef tdefs (x, args) = case Map.lookup x tdefs of
--             Just (params, Data variants) ->
--                 let sub = Map.fromList (zip params args)
--                     datas = map (map (subst sub) . snd) variants
--                 in  sizeofData (sizeofTypeDef tdefs) (alignofTypeDef tdefs) datas
--             Just (params, Alias _ t) ->
--                 let sub = Map.fromList (zip params args)
--                 in  sizeof (sizeofTypeDef tdefs) (subst sub t)
--             Nothing -> ice $ "Infer.sizeofTypeDef: undefined type " ++ show x
--         alignofTypeDef tdefs (x, args) = case Map.lookup x tdefs of
--             Just (params, Data variants) ->
--                 let sub = Map.fromList (zip params args)
--                     datas = map (map (subst sub) . snd) variants
--                 in  alignmentofData (alignofTypeDef tdefs) datas
--             Just (params, Alias _ t) ->
--                 let sub = Map.fromList (zip params args)
--                 in  alignmentof (alignofTypeDef tdefs) (subst sub t)
--             Nothing -> ice $ "Infer.sizeofTypeDef: undefined type " ++ show x

--         -- | This class is instanced when the first type can be `cast` to the other.
--         cast :: (Type, Type) -> Infer (Map ClassConstraint SrcPos)
--         cast = \case
--             (ta, tb) | ta == tb -> ok
--             (TPrim _, TPrim _) -> ok
--             (TVar _, _) -> propagate
--             (_, TVar _) -> propagate
--             (TConst _, _) -> err
--             (_, TConst _) -> err
--             (TFun _ _, _) -> err
--             (_, TFun _ _) -> err
--             (TBox _, _) -> err
--             (_, TBox _) -> err

--     substConstraint sub (Expected t1, Found pos t2) =
--         (Expected (subst sub t1), Found pos (subst sub t2))

--     toTypeErr :: SrcPos -> Type -> Type -> UnifyErr -> TypeErr
--     toTypeErr pos t1 t2 = \case
--         UInfType a t -> InfType pos t1 t2 a t
--         UFailed t'1 t'2 -> UnificationFailed pos t1 t2 t'1 t'2

-- -- FIXME: Keep track of whether we've flipped the arguments. Alternatively, keep right stuff to the
-- --        right and vice versa. If we don't, we get confusing type errors.
-- unifies :: Type -> Type -> Except UnifyErr Subst'
-- unifies = curry $ \case
--     (TPrim a, TPrim b) | a == b -> pure Map.empty
--     (TConst (c0, ts0), TConst (c1, ts1)) | c0 == c1 -> if length ts0 /= length ts1
--         then ice "lengths of TConst params differ in unify"
--         else unifiesMany (zip ts0 ts1)
--     (TVar a, TVar b) | a == b -> pure Map.empty
--     (TVar a, t) | occursIn a t -> throwError (UInfType a t)
--     -- Do not allow "override" of explicit (user given) type variables.
--     (a@(TVar (TVExplicit _)), b@(TVar (TVImplicit _))) -> unifies b a
--     (a@(TVar (TVExplicit _)), b) -> throwError (UFailed a b)
--     (TVar a, t) -> pure (Map.singleton a t)
--     (t, TVar a) -> unifies (TVar a) t
--     (t@(TFun ts1 t2), u@(TFun us1 u2)) -> if length ts1 /= length us1
--         then throwError (UFailed t u)
--         else unifiesMany (zip (ts1 ++ [t2]) (us1 ++ [u2]))
--     (TBox t, TBox u) -> unifies t u
--     (t1, t2) -> throwError (UFailed t1 t2)
--   where
--     unifiesMany :: [(Type, Type)] -> Except UnifyErr Subst'
--     unifiesMany = foldM
--         (\s (t, u) -> fmap (flip composeSubsts s) (unifies (subst s t) (subst s u)))
--         Map.empty

--     occursIn :: TVar -> Type -> Bool
--     occursIn a t = Set.member a (ftv t)

M src/Front/Inferred.hs => src/Front/Inferred.hs +160 -143
@@ 1,143 1,160 @@
{-# LANGUAGE TemplateHaskell, DataKinds #-}

-- TODO: Can this and Checked be merged to a single, parametrized AST?

-- | Type annotated AST as a result of typechecking
module Front.Inferred (module Front.Inferred, Type, TConst, WithPos(..), TVar(..), TPrim(..), Const(..), Type' (..), TConst') where

import Data.Set (Set)
import qualified Data.Set as Set
import Data.Map (Map)
import Data.Bifunctor
import Lens.Micro.Platform (makeLenses)

import Misc
import qualified Front.Parsed as Parsed
import Front.Parsed (Type, TConst, TVar(..), Const(..))
import Front.SrcPos
import Front.TypeAst


data TypeErr
    = MainNotDefined
    | InvalidUserTypeSig SrcPos Scheme Scheme
    | CtorArityMismatch SrcPos String Int Int
    | ConflictingPatVarDefs SrcPos String
    | UndefCtor SrcPos String
    | UndefVar SrcPos String
    | InfType SrcPos Type Type TVar Type
    | UnificationFailed SrcPos Type Type Type Type
    | ConflictingTypeDef SrcPos String
    | ConflictingCtorDef SrcPos String
    | RedundantCase SrcPos
    | InexhaustivePats SrcPos String
    | ExternNotMonomorphic (Parsed.Id 'Parsed.Small) TVar
    | FoundHole SrcPos
    | RecTypeDef String SrcPos
    | UndefType SrcPos String
    | WrongMainType SrcPos Parsed.Scheme
    | RecursiveVarDef (WithPos String)
    | TypeInstArityMismatch SrcPos String Int Int
    | ConflictingVarDef SrcPos String
    | NoClassInstance SrcPos ClassConstraint
    | FunCaseArityMismatch SrcPos Int Int
    | FunArityMismatch SrcPos Int Int
    | DeBruijnIndexOutOfRange SrcPos Word
    | FreeVarsInData SrcPos TVar
    | FreeVarsInAlias SrcPos TVar
    deriving (Show)

type ClassConstraint = (String, [Type])

data Scheme = Forall
    { _scmParams :: Set TVar
    , _scmConstraints :: Set ClassConstraint
    , _scmBody :: Type
    }
    deriving (Show, Eq)
makeLenses ''Scheme

data TypedVar = TypedVar String Type
    deriving (Show, Eq, Ord)

type VariantIx = Integer

type Span = Integer

data Variant = VariantIx VariantIx | VariantStr String
    deriving (Show, Eq, Ord)

data Con = Con
    { variant :: Variant
    , span :: Span
    , argTs :: [Type]
    }
    deriving Show

data Pat'
    = PVar TypedVar
    | PWild
    | PCon Con [Pat]
    | PBox Pat
    deriving Show

data Pat = Pat SrcPos Type Pat'
    deriving Show

type Fun = ([(String, Type)], (Expr, Type))

type Cases = [(WithPos [Pat], Expr)]
type Match = WithPos ([Expr], Cases, [Type], Type)

-- | Whether a Var refers to a builtin virtual, or a global/local definition. So we don't
--   have to keep as much state about environment definitions in later passes.
data Virt = Virt | NonVirt deriving (Show, Eq)

type Var = (Virt, TypedVar)

data Expr
    = Lit Const
    | Var Var
    | App Expr [Expr] Type
    | If Expr Expr Expr
    | Let Def Expr
    | Fun Fun
    | Match Match
    | Ctor VariantIx Span TConst [Type]
    | Sizeof Type
    deriving Show

type Defs = TopologicalOrder Def
data Def = VarDef VarDef | RecDefs RecDefs deriving Show
type VarDef = (String, (Scheme, Expr))
type RecDefs = [(String, (Scheme, Fun))]
data TypeDefRhs = Data [(WithPos String, [Type])] | Alias SrcPos Type
    deriving Show
type TypeDefs = Map String ([TVar], TypeDefRhs)
type TypeAliases = Map String ([TVar], Type)
type Ctors = Map String (VariantIx, (String, [TVar]), [Type], Span)
type Externs = Map String Type


instance Eq Con where
    (==) (Con c1 _ _) (Con c2 _ _) = c1 == c2

instance Ord Con where
    compare (Con c1 _ _) (Con c2 _ _) = compare c1 c2


ftv :: Type -> Set TVar
ftv = \case
    TVar tv -> Set.singleton tv
    TPrim _ -> Set.empty
    TFun pts rt -> Set.unions (ftv rt : map ftv pts)
    TBox t -> ftv t
    TConst (_, ts) -> Set.unions (map ftv ts)

defSigs :: Def -> [(String, Scheme)]
defSigs = \case
    VarDef d -> [defSig d]
    RecDefs ds -> map defSig ds

defSig :: (String, (Scheme, a)) -> (String, Scheme)
defSig = second fst
module Front.Inferred where

import SrcPos
--import Name
--import Front.TypeAst
--import Front.Resolved (ResolvedScheme)
import Front.Abstract

type Cases = [(WithPos [Pat], InferredExpr)]
data MatchCases = MatchCases Cases [Type] Type

type InferredFun = Fun MatchCases
type InferredExpr = Expr MatchCases
type InferredVarDef = VarDef MatchCases
type InferredRecDefs = RecDefs MatchCases
type InferredDef = Def MatchCases
type InferredDefs = Defs MatchCases

-- {-# LANGUAGE TemplateHaskell #-}

-- -- TODO: Can this and Checked be merged to a single, parametrized AST?

-- -- | Type annotated AST as a result of typechecking
-- module Front.Inferred (module Front.Inferred, WithPos(..), TVar(..), TPrim(..), Const(..), Type' (..), TConst') where

-- import Data.Set (Set)
-- import qualified Data.Set as Set
-- import Data.Map (Map)
-- import Data.Bifunctor
-- import Lens.Micro.Platform (makeLenses)

-- import Misc
-- import SrcPos
-- import Name
-- import Front.Concrete (TVar (..), Const (..))
-- import Front.Resolved
-- -- import qualified Front.Parsed as Parsed
-- -- import Front.Concrete (Type, TConst, TVar(..), Const(..))
-- -- import Front.Parsed
-- import Front.TypeAst

-- type LhsName = Ident

-- type TConst = ResolvedTConst
-- type Type = ResolvedType

-- data TypeErr
--     = MainNotDefined
--     | InvalidUserTypeSig SrcPos Scheme Scheme
--     | CtorArityMismatch SrcPos String Int Int
--     | ConflictingPatVarDefs SrcPos String
--     | UndefCtor SrcPos String
--     | UndefVar SrcPos String
--     | InfType SrcPos Type Type TVar Type
--     | UnificationFailed SrcPos Type Type Type Type
--     | ConflictingTypeDef SrcPos String
--     | ConflictingCtorDef SrcPos String
--     | RedundantCase SrcPos
--     | InexhaustivePats SrcPos String
--     | ExternNotMonomorphic Ident TVar
--     | FoundHole SrcPos
--     | RecTypeDef String SrcPos
--     | UndefType SrcPos String
--     | WrongMainType SrcPos ResolvedScheme
--     | RecursiveVarDef (WithPos String)
--     | TypeInstArityMismatch SrcPos String Int Int
--     | ConflictingVarDef SrcPos String
--     | NoClassInstance SrcPos ClassConstraint
--     | FunCaseArityMismatch SrcPos Int Int
--     | FunArityMismatch SrcPos Int Int
--     | DeBruijnIndexOutOfRange SrcPos Word
--     | FreeVarsInData SrcPos TVar
--     | FreeVarsInAlias SrcPos TVar
--     deriving (Show)

-- type ClassConstraint = (QualName, [Type])

-- data Scheme = Forall
--     { _scmParams :: Set TVar
--     , _scmConstraints :: Set ClassConstraint
--     , _scmBody :: Type
--     }
--     deriving (Show, Eq)
-- makeLenses ''Scheme

-- type VariantIx = Integer

-- type Span = Integer

-- data Variant = VariantIx VariantIx | VariantStr String
--     deriving (Show, Eq, Ord)

-- data Con = Con
--     { variant :: Variant
--     , span :: Span
--     , argTs :: [Type]
--     }
--     deriving Show

-- data Pat'
--     = PVar LhsName Type
--     | PWild
--     | PCon Con [Pat]
--     | PBox Pat
--     deriving Show

-- data Pat = Pat SrcPos Type Pat'
--     deriving Show

-- type Fun = ([(LhsName, Type)], (Expr, Type))

-- type Cases = [(WithPos [Pat], Expr)]
-- type Match = WithPos ([Expr], Cases, [Type], Type)

-- data Expr
--     = Lit Const
--     | EVar QualName Type
--     | App Expr [Expr] Type
--     | If Expr Expr Expr
--     | Let Def Expr
--     | Fun Fun
--     | Match Match
--     | Ctor VariantIx Span TConst [Type]
--     | Sizeof Type
--     deriving Show

-- type Defs = TopologicalOrder Def
-- data Def = VarDef VarDef | RecDefs RecDefs deriving Show
-- type VarDef = (LhsName, (Scheme, Expr))
-- type RecDefs = [(LhsName, (Scheme, Fun))]
-- data TypeDefRhs = Data [(WithPos QualName, [Type])] | Alias SrcPos Type
--     deriving Show
-- type TypeDefs = Map QualName ([TVar], TypeDefRhs)
-- type TypeAliases = Map QualName ([TVar], Type)
-- type Ctors = Map String (VariantIx, (String, [TVar]), [Type], Span)
-- type Externs = Map String Type


-- instance Eq Con where
--     (==) (Con c1 _ _) (Con c2 _ _) = c1 == c2

-- instance Ord Con where
--     compare (Con c1 _ _) (Con c2 _ _) = compare c1 c2


-- ftv :: Type -> Set TVar
-- ftv = \case
--     TVar tv -> Set.singleton tv
--     TPrim _ -> Set.empty
--     TFun pts rt -> Set.unions (ftv rt : map ftv pts)
--     TBox t -> ftv t
--     TConst (_, ts) -> Set.unions (map ftv ts)

-- defSigs :: Def -> [(LhsName, Scheme)]
-- defSigs = \case
--     VarDef d -> [defSig d]
--     RecDefs ds -> map defSig ds

-- defSig :: (LhsName, (Scheme, a)) -> (LhsName, Scheme)
-- defSig = second fst

M src/Front/Lex.hs => src/Front/Lex.hs +52 -83
@@ 6,11 6,10 @@
--       If a parser has a variant with a "ns_" prefix, that variant does not consume succeding
--       space, while the unprefixed variant does.

module Front.Lex (lex, toplevel, tokentree) where
module Front.Lex (lex, tokentree) where

import Control.Monad
import Control.Monad.Except
import Control.Monad.State
import Data.Char (isMark, isPunctuation, isSymbol)
import Data.Functor
import Data.Maybe


@@ 19,82 18,30 @@ import Text.Megaparsec hiding (parse, match, token, Token)
import Text.Megaparsec.Char hiding (space, space1)
import qualified Text.Megaparsec.Char as Char
import qualified Text.Megaparsec.Char.Lexer as Lexer
import Data.Set (Set)
import qualified Data.Set as Set
import System.FilePath
import System.Directory
import Data.Void
import Prelude hiding (lex)

import Misc
import Front.SrcPos
import SrcPos
import Name
import Front.Lexd
import Front.Literate
import EnvVars


type Lexer = Parsec Void String

type Import = String

data TopLevel = TImport Import -- | TMacro Macro
    | TTokenTree TokenTree


lex :: FilePath -> ExceptT String IO [TokenTree]
lex :: FilePath -> ExceptT String IO Module
lex filepath = do
    modPaths <- lift modulePaths
    filepath' <- lift $ makeAbsolute filepath
    evalStateT (lexModule modPaths filepath') Set.empty

-- NOTE: For the current implementation of macros where order of definition matters, it's important
--       that we visit imports and concatenate all token trees in the correct order, which is DFS.
lexModule :: [FilePath] -> FilePath -> StateT (Set FilePath) (ExceptT String IO) [TokenTree]
lexModule modPaths f = get >>= \visiteds -> if Set.member f visiteds
    then pure []
    else do
        modify (Set.insert f)
        s <- liftIO $ readFile f <&> \s' ->
            if takeExtension f == ".org" then untangleOrg s' else s'
        (imps, tts) <- liftEither $ parse' toplevels f s
        let ps = takeDirectory f : modPaths
        let resolve m = do
                let gs = [ p </> addExtension m e | p <- ps, e <- [".carth", ".org"] ]
                gs' <- filterM (liftIO . doesFileExist) gs
                case listToMaybe gs' of
                    Nothing ->
                        throwError
                            $ ("Error: No file for module " ++ m ++ " exists.\n")
                            ++ ("Searched paths: " ++ show ps)
                    Just g' -> liftIO $ makeAbsolute g'
        impFs <- mapM resolve imps
        ttsImp <- concat <$> mapM (lexModule modPaths) impFs
        pure (ttsImp ++ tts)

toplevels :: Lexer ([Import], [TokenTree])
toplevels = do
    space
    tops <- many toplevel
    eof
    pure $ foldr
        (\top (is, tts) -> case top of
            TImport i -> (i : is, tts)
            TTokenTree tt -> (is, tt : tts)
        )
        ([], [])
        tops

toplevel :: Lexer TopLevel
toplevel = getSrcPos >>= \p ->
    parens (fmap TImport import' <|> fmap (TTokenTree . WithPos p . Parens) (many tokentree))
    where import' = andSkipSpaceAfter (string "import") *> small
    f <- lift $ makeAbsolute filepath
    s <- liftIO $ readFile f <&> \s' -> if takeExtension f == ".org" then untangleOrg s' else s'
    tts <- liftEither $ parse' (many tokentree) f s
    pure (Module tts)

tokentree :: Lexer TokenTree
tokentree = do
    p <- getSrcPos
    tt <- tokentree'
    tt' <- option tt (ellipsis $> Ellipsis (WithPos p tt))
    pure (WithPos p tt')
tokentree = withPos tokentree'
  where
    tokentree' = choice
        [ fmap Parens (parens (many tokentree))


@@ 103,14 50,12 @@ tokentree = do
        , fmap Backslashed (string "\\" *> tokentree)
        , string "#" *> ((Octothorped <$> tokentree) <|> (Octothorpe <$ space))
        , fmap Reserved reserved
        , fmap Keyword (string ":" *> small)
        , fmap Small smallSpecial
        , fmap Big bigSpecial
        , fmap Small smallNormal
        , fmap Big bigNormal
        , fmap Keyword (string ":" *> ident)
        , lexdName <&> \case
            x `LIn` LRelative -> LIdent x
            x -> LName x
        , fmap Lit lit
        ]
    ellipsis = try (string "..." *> notFollowedBy identLetter *> space)
    lit = try num <|> fmap Str strlit
    num = andSkipSpaceAfter ns_num
    ns_num = do


@@ 152,18 97,39 @@ reserved = andSkipSpaceAfter . choice $ map
    , string "type" $> Rtype
    ]

small, smallSpecial, smallNormal :: Lexer String
small = smallSpecial <|> smallNormal
smallSpecial = string "id@" *> strlit
smallNormal = andSkipSpaceAfter $ liftA2 (:) smallStart identRest
  where
    smallStart =
        lowerChar <|> otherChar <|> try (oneOf ("-+" :: String) <* notFollowedBy digitChar)
-- small, smallSpecial, smallNormal :: Lexer String
-- small = smallSpecial <|> smallNormal
-- smallSpecial = string "id@" *> strlit
-- smallNormal = andSkipSpaceAfter $ liftA2 (:) smallStart identRest
--   where
--     smallStart =
--         lowerChar <|> otherChar <|> try (oneOf ("-+" :: String) <* notFollowedBy digitChar)

-- bigSpecial, bigNormal :: Lexer String
-- bigSpecial = string "id@" *> strlit
-- bigNormal = andSkipSpaceAfter $ liftA2 (:) bigStart identRest
--     where bigStart = upperChar <|> char ':'

lexdName :: Lexer LexdName
lexdName = do
    root <- option LRelative (char '/' $> LAbsolute)
    xs <- sepBy1 ns_ident (char '/')
    pure $ foldr LIn root xs

bigSpecial, bigNormal :: Lexer String
bigSpecial = string "id@" *> strlit
bigNormal = andSkipSpaceAfter $ liftA2 (:) bigStart identRest
    where bigStart = upperChar <|> char ':'
ident :: Lexer Ident
ident = fmap Ident ident'

ns_ident :: Lexer Ident
ns_ident = fmap Ident ns_ident'

ident' :: Lexer String
ident' = andSkipSpaceAfter ns_ident'

ns_ident' :: Lexer String
ns_ident' = liftA2 (:) identStart identRest
  where
    identStart =
        letterChar <|> otherChar <|> try (oneOf ("-+" :: String) <* notFollowedBy digitChar)

identRest :: Lexer String
identRest = many identLetter


@@ 176,21 142,21 @@ otherChar = satisfy
    (\c -> and
        [ any ($ c) [isMark, isPunctuation, isSymbol]
        , c `notElem` ("()[]{}" :: String)
        , c `notElem` ("\"-+:•" :: String)
        , c `notElem` ("\"-+:•/" :: String)
        ]
    )

parens, ns_parens :: Lexer a -> Lexer a
parens = andSkipSpaceAfter . ns_parens
ns_parens = between (symbol "(") (string ")")
ns_parens = between (symbol "(") (char ')')

brackets, ns_brackets :: Lexer a -> Lexer a
brackets = andSkipSpaceAfter . ns_brackets
ns_brackets = between (symbol "[") (string "]")
ns_brackets = between (symbol "[") (char ']')

braces, ns_braces :: Lexer a -> Lexer a
braces = andSkipSpaceAfter . ns_braces
ns_braces = between (symbol "{") (string "}")
ns_braces = between (symbol "{") (char '}')

andSkipSpaceAfter :: Lexer a -> Lexer a
andSkipSpaceAfter = Lexer.lexeme space


@@ 206,3 172,6 @@ getSrcPos :: Lexer SrcPos
getSrcPos = fmap
    (\(SourcePos f l c) -> SrcPos f (fromIntegral (unPos l)) (fromIntegral (unPos c)) Nothing)
    getSourcePos

withPos :: Lexer a -> Lexer (WithPos a)
withPos = liftA2 WithPos getSrcPos

M src/Front/Lexd.hs => src/Front/Lexd.hs +11 -5
@@ 2,7 2,8 @@

module Front.Lexd where

import Front.SrcPos
import SrcPos
import Name

data Const
    = Int Int


@@ 36,17 37,22 @@ data Reserved

data TokenTree'
    = Reserved Reserved
    | Keyword String
    | Keyword Ident
    | Lit Const
    | Small String
    | Big String
    | LIdent Ident
    | LName LexdName
    | Parens [TokenTree]
    | Brackets [TokenTree]
    | Braces [TokenTree]
    | Backslashed TokenTree
    | Octothorped TokenTree
    | Octothorpe
    | Ellipsis TokenTree
    | Quote
    | Backquote
    | Unquote TokenTree
    | UnquoteSplicing TokenTree
    deriving (Eq, Show)

type TokenTree = WithPos TokenTree'

newtype Module = Module [TokenTree]

M src/Front/Macro.hs => src/Front/Macro.hs +164 -117
@@ 1,133 1,180 @@
module Front.Macro (expandMacros) where

import Control.Applicative
-- import Control.Applicative
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor
import Data.Map (Map)
import Data.Set (Set)
import qualified Data.Set as Set
import qualified Data.Map as Map
-- import Control.Monad.Reader
-- import Control.Monad.State
-- import Data.Bifunctor
-- import Data.Map (Map)
-- import Data.Set (Set)
-- import qualified Data.Set as Set
-- import qualified Data.Map as Map

import Misc
import Front.SrcPos
-- import Name
import SrcPos
import Front.Lexd
import Front.Parser
-- import Front.Parser

type Literals = Set String
type Rules = [([TokenTree], [TokenTree])]
type Macros = Map String (Literals, Rules)
type Bindings = Map String TokenTree'
type Expand = ReaderT (Bindings, Maybe SrcPos) (StateT Macros (Except (SrcPos, String)))
-- type Literals = Set Ident
-- type Rules = [([TokenTree], [TokenTree])]
-- type Macros = Map Ident (Literals, Rules)
-- type Bindings = Map Ident TokenTree'
-- type Expand = ReaderT (Bindings, Maybe SrcPos) (StateT Macros (Except (SrcPos, String)))

expandMacros :: [TokenTree] -> Except (SrcPos, String) [TokenTree]
expandMacros tts = evalStateT (runReaderT (toplevels tts) (Map.empty, Nothing)) Map.empty
expandMacros :: Module -> Except (SrcPos, String) Module
expandMacros = nyi "expandMacros"

toplevels :: [TokenTree] -> Expand [TokenTree]
toplevels = fmap concat . mapM toplevel
-- expandMacros :: [TokenTree] -> Except (SrcPos, String) [TokenTree]
-- expandMacros tts = evalStateT (runReaderT (toplevels tts) (Map.empty, Nothing)) Map.empty

toplevel :: TokenTree -> Expand [TokenTree]
toplevel = \case
    WithPos mpos (Parens (WithPos _ (Reserved Rdefmacro) : tts)) -> do
        (name, lits, rules) <- case runParser pdefmacro mpos tts of
            (result, []) -> (lift . lift . liftEither) result
            (_, messages) ->
                ice
                    $ "Macro.toplevel: There were messages when running the pdefmacro parser: "
                    ++ show messages
        modify (Map.insert name (lits, rules))
        pure []
    tt -> expand tt
-- toplevels :: [TokenTree] -> Expand [TokenTree]
-- toplevels = fmap concat . mapM toplevel

pdefmacro :: Parser (String, Literals, Rules)
pdefmacro = liftA3 (,,) small' (fmap Set.fromList (parens (many small'))) (some prule)
  where
    prule = parens $ do
        reserved Rcase
        params <- parens (many anyToken)
        template <- many anyToken
        pure (params, template)
-- toplevel :: TokenTree -> Expand [TokenTree]
-- toplevel = \case
--     WithPos mpos (Parens (WithPos _ (Reserved Rdefmacro) : tts)) -> do
--         (name, lits, rules) <- case runParser pdefmacro mpos tts of
--             (result, []) -> (lift . lift . liftEither) result
--             (_, messages) ->
--                 ice
--                     $ "Macro.toplevel: There were messages when running the pdefmacro parser: "
--                     ++ show messages
--         modify (Map.insert name (lits, rules))
--         pure []
--     tt -> expand tt

expand :: TokenTree -> Expand [TokenTree]
expand (WithPos tpos tt') = do
    (bs, expPos) <- ask
    ms <- get
    let tpos' = tpos { inExpansion = expPos }
    let tt = WithPos tpos' tt'
    let par ctor tts = fmap (pure . WithPos tpos' . ctor) (expands tts)
    case tt' of
        Lit _ -> pure [tt]
        Small x -> case Map.lookup x bs of
            Just xtt -> pure [WithPos tpos' xtt]
            Nothing -> pure [tt]
        Big _ -> pure [tt]
        Reserved _ -> pure [tt]
        Keyword _ -> pure [tt]
        Backslashed tt' -> map (WithPos tpos' . Backslashed) <$> expand tt'
        Octothorped tt' -> map (WithPos tpos' . Octothorped) <$> expand tt'
        Octothorpe -> pure [tt]
        Parens (WithPos _ (Small x) : tts1) | Just m <- Map.lookup x ms -> do
            tts2 <- expands tts1
            local (second (const (Just tpos'))) $ do
                tts3 <- uncurry (applyMacro tpos' tts2) m
                expands tts3
        Parens tts -> par Parens tts
        Brackets tts -> par Brackets tts
        Braces tts -> par Braces tts
        Ellipsis (WithPos epos (Small x)) -> case Map.lookup x bs of
            Just (Parens xtts) -> expands xtts
            Just (Brackets xtts) -> expands xtts
            Just (Braces xtts) -> expands xtts
            Just _ ->
                throwError (epos, "Cannot ellipsis splice non-sequence macro pattern variable")
            Nothing -> throwError (epos, "Unbound macro pattern variable")
        Ellipsis (WithPos epos _) ->
            throwError (epos, "Can only ellipsis splice macro pattern variable")
-- pdefmacro :: Parser (Ident, Literals, Rules)
-- pdefmacro = liftA3 (,,)
--                    (unpos <$> smallLhs)
--                    (fmap Set.fromList (parens (many (unpos <$> smallLhs))))
--                    (some prule)
--   where
--     prule = parens $ do
--         reserved Rcase
--         params <- parens (many anyToken)
--         template <- many anyToken
--         pure (params, template)

expands :: [TokenTree] -> Expand [TokenTree]
expands = fmap concat . mapM expand
-- subst :: Map Ident TokenTree -> TokenTree -> [TokenTree]
-- subst s (WithPos tpos tt) = case tt of
--     Lit _ -> _
--     -- (bs, expPos) <- ask
--     -- ms <- get
--     -- let tpos' = tpos { inExpansion = expPos }
--     -- let tt = WithPos tpos' tt'
--     -- let par ctor tts = fmap (pure . WithPos tpos' . ctor) (expands tts)
--     -- case tt' of
--     --     Lit _ -> pure [tt]
--     --     LIdent x -> case Map.lookup x bs of
--     --         Just xtt -> pure [WithPos tpos' xtt]
--     --         Nothing -> pure [tt]
--     --     -- LRelName _ -> pure [tt]
--     --     -- LQualName _ -> pure [tt]
--     --     Reserved _ -> pure [tt]
--     --     Keyword _ -> pure [tt]
--     --     Backslashed tt' -> map (WithPos tpos' . Backslashed) <$> expand tt'
--     --     Octothorped tt' -> map (WithPos tpos' . Octothorped) <$> expand tt'
--     --     Octothorpe -> pure [tt]
--     --     Parens (WithPos _ (LIdent x) : tts1) | Just m <- Map.lookup x ms -> do
--     --         tts2 <- expands tts1
--     --         local (second (const (Just tpos'))) $ do
--     --             tts3 <- uncurry (applyMacro tpos' tts2) m
--     --             expands tts3
--     --     Parens tts -> par Parens tts
--     --     Brackets tts -> par Brackets tts
--     --     Braces tts -> par Braces tts
--     --     Ellipsis (WithPos epos (LIdent x)) -> case Map.lookup x bs of
--     --         Just (Parens xtts) -> expands xtts
--     --         Just (Brackets xtts) -> expands xtts
--     --         Just (Braces xtts) -> expands xtts
--     --         Just _ ->
--     --             throwError (epos, "Cannot ellipsis splice non-sequence macro pattern variable")
--     --         Nothing -> throwError (epos, "Unbound macro pattern variable")
--     --     Ellipsis (WithPos epos _) ->
--     --         throwError (epos, "Can only ellipsis splice macro pattern variable")

applyMacro :: SrcPos -> [TokenTree] -> Literals -> Rules -> Expand [TokenTree]
applyMacro appPos args lits = \case
    [] -> throwError (appPos, "No rule matched in application of macro")
    (params, template) : rules -> case matchRule (map unpos params, args) of
        Just bindings -> local (first (Map.union bindings)) (expands template)
        Nothing -> applyMacro appPos args lits rules
  where
    matchRule :: ([TokenTree'], [TokenTree]) -> Maybe (Map String TokenTree')
    matchRule = \case
        ([], []) -> Just mempty
        (Ellipsis (WithPos _ x) : xs, ys) ->
            let ms = takeWhileJust (matchPat x) ys
                ys' = drop (length ms) ys
                -- By default, each pattern variable in an ellipsis pattern should be bound to an
                -- empty Parens, even if ys was empty
                ms' = Map.fromSet (const []) (fvPat x) : map (fmap pure) ms
                ms'' = fmap Parens (Map.unionsWith (++) ms')
            in  fmap (Map.union ms'') (matchRule (xs, ys'))
        (x : xs, y : ys) -> liftA2 (Map.union . fmap unpos) (matchPat x y) (matchRule (xs, ys))
        ([], _ : _) -> Nothing
        (_ : _, []) -> Nothing
-- expand :: TokenTree -> Expand [TokenTree]
-- expand (WithPos tpos tt') = do
--     (bs, expPos) <- ask
--     ms <- get
--     let tpos' = tpos { inExpansion = expPos }
--     let tt = WithPos tpos' tt'
--     let par ctor tts = fmap (pure . WithPos tpos' . ctor) (expands tts)
--     case tt' of
--         Lit _ -> pure [tt]
--         LIdent x -> case Map.lookup x bs of
--             Just xtt -> pure [WithPos tpos' xtt]
--             Nothing -> pure [tt]
--         -- LRelName _ -> pure [tt]
--         -- LQualName _ -> pure [tt]
--         Reserved _ -> pure [tt]
--         Keyword _ -> pure [tt]
--         Backslashed tt' -> map (WithPos tpos' . Backslashed) <$> expand tt'
--         Octothorped tt' -> map (WithPos tpos' . Octothorped) <$> expand tt'
--         Octothorpe -> pure [tt]
--         Parens (WithPos _ (LIdent x) : tts1) | Just m <- Map.lookup x ms -> do
--             tts2 <- expands tts1
--             local (second (const (Just tpos'))) $ do
--                 tts3 <- uncurry (applyMacro tpos' tts2) m
--                 expands tts3
--         Parens tts -> par Parens tts
--         Brackets tts -> par Brackets tts
--         Braces tts -> par Braces tts
--         Ellipsis (WithPos epos (LIdent x)) -> case Map.lookup x bs of
--             Just (Parens xtts) -> expands xtts
--             Just (Brackets xtts) -> expands xtts
--             Just (Braces xtts) -> expands xtts
--             Just _ ->
--                 throwError (epos, "Cannot ellipsis splice non-sequence macro pattern variable")
--             Nothing -> throwError (epos, "Unbound macro pattern variable")
--         Ellipsis (WithPos epos _) ->
--             throwError (epos, "Can only ellipsis splice macro pattern variable")

    matchPat :: TokenTree' -> TokenTree -> Maybe (Map String TokenTree)
    matchPat p (WithPos apos a) = case (p, a) of
        (Small x, _) | not (Set.member x lits) -> Just (Map.singleton x (WithPos apos a))
        (Parens xs, Parens ys) -> par xs ys
        (Brackets xs, Brackets ys) -> par xs ys
        (Braces xs, Braces ys) -> par xs ys
        (_, _) | p == a -> Just mempty
               | otherwise -> Nothing
      where
        par xs ys = if length xs == length ys
            then Map.unions <$> zipWithM matchPat (map unpos xs) ys
            else Nothing
-- expands :: [TokenTree] -> Expand [TokenTree]
-- expands = fmap concat . mapM expand

    fvPat = \case
        Small x | not (Set.member x lits) -> Set.singleton x
        Parens tts -> par tts
        Brackets tts -> par tts
        Braces tts -> par tts
        Ellipsis tt -> fvPat (unpos tt)
        _ -> Set.empty
        where par = Set.unions . map (fvPat . unpos)
-- applyMacro :: SrcPos -> [TokenTree] -> Literals -> Rules -> Expand [TokenTree]
-- applyMacro appPos args lits = \case
--     [] -> throwError (appPos, "No rule matched in application of macro")
--     (params, template) : rules -> case matchRule (map unpos params, args) of
--         -- FIXME: bindings should be replaced, not joined?
--         Just bindings -> local (first (Map.union bindings)) (expands template)
--         Nothing -> applyMacro appPos args lits rules
--   where
--     matchRule :: ([TokenTree'], [TokenTree]) -> Maybe (Map String TokenTree')
--     matchRule = \case
--         ([], []) -> Just mempty
--         (Ellipsis (WithPos _ x) : xs, ys) ->
--             let ms = takeWhileJust (matchPat x) ys
--                 ys' = drop (length ms) ys
--                 -- By default, each pattern variable in an ellipsis pattern should be bound to an
--                 -- empty Parens, even if ys was empty
--                 ms' = Map.fromSet (const []) (fvPat x) : map (fmap pure) ms
--                 ms'' = fmap Parens (Map.unionsWith (++) ms')
--             in  fmap (Map.union ms'') (matchRule (xs, ys'))
--         (x : xs, y : ys) -> liftA2 (Map.union . fmap unpos) (matchPat x y) (matchRule (xs, ys))
--         ([], _ : _) -> Nothing
--         (_ : _, []) -> Nothing

--     matchPat :: TokenTree' -> TokenTree -> Maybe (Map String TokenTree)
--     matchPat p (WithPos apos a) = case (p, a) of
--         -- (Small x, _) | not (Set.member x lits) -> Just (Map.singleton x (WithPos apos a))
--         (Parens xs, Parens ys) -> par xs ys
--         (Brackets xs, Brackets ys) -> par xs ys
--         (Braces xs, Braces ys) -> par xs ys
--         (_, _) | p == a -> Just mempty
--                | otherwise -> Nothing
--       where
--         par xs ys = if length xs == length ys
--             then Map.unions <$> zipWithM matchPat (map unpos xs) ys
--             else Nothing

--     fvPat = \case
--         -- Small x | not (Set.member x lits) -> Set.singleton x
--         Parens tts -> par tts
--         Brackets tts -> par tts
--         Braces tts -> par tts
--         Ellipsis tt -> fvPat (unpos tt)
--         _ -> Set.empty
--         where par = Set.unions . map (fvPat . unpos)

M src/Front/Match.hs => src/Front/Match.hs +1 -1
@@ 22,7 22,7 @@ import Lens.Micro.Platform (makeLenses, view, to)

import Misc hiding (augment)
import Pretty
import Front.SrcPos
import SrcPos
import Front.Err
import qualified Front.Inferred as Inferred
import Front.Inferred (Pat (..), Pat'(..), Variant(..))

M src/Front/Monomorphic.hs => src/Front/Monomorphic.hs +11 -10
@@ 21,20 21,21 @@ import Data.Void

import FreeVars
import Misc
import Front.Checked (VariantIx, Span, Access (..), Virt (..))
import Front.Parsed (Const(..))
import Front.Abstract
import Front.Checked
import Front.Concrete (Const(..))
import Front.TypeAst

type TConst = TConst' Void
type Type = Type' Void
type MonoType = Type' Void

data TypedVar = TypedVar
    { tvName :: String
    , tvType :: Type
    , tvType :: MonoType
    }
    deriving (Show, Eq, Ord)

type VariantTypes = [Type]
type VariantTypes = [MonoType]

type VarBindings = [(TypedVar, Access)]



@@ 45,7 46,7 @@ data DecisionTree
    deriving (Show, Eq)

type Ction = (VariantIx, Span, TConst, [Expr])
type Fun = ([TypedVar], (Expr, Type))
type Fun = ([TypedVar], (Expr, MonoType))

data Expr
    = Lit Const


@@ 56,18 57,18 @@ data Expr
    | Let Def Expr
    | Match [Expr] DecisionTree
    | Ction Ction
    | Sizeof Type
    | Absurd Type
    | Sizeof MonoType
    | Absurd MonoType
    deriving (Show, Eq)

type Defs = TopologicalOrder Def
data Def = VarDef VarDef | RecDefs RecDefs deriving (Show, Eq)
type Inst = [Type]
type Inst = [MonoType]
type VarDef = (TypedVar, (Inst, Expr))
type RecDefs = [FunDef]
type FunDef = (TypedVar, (Inst, Fun))
type Datas = Map TConst [(String, VariantTypes)]
type Externs = [(String, Type)]
type Externs = [(String, MonoType)]

data Program = Program Defs Datas Externs
    deriving Show

M src/Front/Monomorphize.hs => src/Front/Monomorphize.hs +1 -1
@@ 40,7 40,7 @@ instance Monoid DefInsts where
monomorphize :: Checked.Program -> Program
monomorphize (Checked.Program (Topo defs) datas externs) =
    let
        callMain = Checked.Var (NonVirt, Checked.TypedVar "main" Checked.mainType)
        callMain = Checked.Var (NonVirt, Checked.TypedVar "main" mainType)
        monoExterns = mapM (\(x, t) -> fmap (x, ) (monotype t)) (Map.toList externs)
        monoDefs = foldr (\d1 md2s -> fmap (uncurry (++)) (monoLet' d1 md2s))
                         (mono callMain $> [])

M src/Front/Parse.hs => src/Front/Parse.hs +68 -82
@@ 5,40 5,46 @@ module Front.Parse (parse, c_validIdent, c_validIdentFirst, c_validIdentRest, c_
import Control.Arrow
import Control.Applicative hiding (many, some)
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Combinators
import Control.Monad.Writer
import Data.Char
import Data.Maybe
import qualified Data.Set as Set
import Text.Read (readMaybe)
import Lens.Micro.Platform (set, view)

import Misc
import Front.SrcPos
import Front.Lexd hiding (Big, Small)
import SrcPos
import Name
import Front.Lexd hiding (Module, Lit)
import qualified Front.Lexd as Lexd
import Front.TypeAst
--import qualified Front.Concrete as Concrete
import Front.Parser
import Front.Parsed hiding (Lit)
import qualified Front.Parsed as Parsed
import Front.Concrete
import Front.Parsed
-- import qualified Front.Concrete as Cst

parse :: [TokenTree] -> (Either (SrcPos, String) Program, [Message])
parse = runParser' (fmap (\(ds, ts, es) -> Program ds ts es) toplevels)
parse :: Lexd.Module -> (Either (SrcPos, String) ParsedModule, [Message])
parse (Lexd.Module tts) =
    runParser' (fmap (\(is, ds, ts, es) -> Module is ds ts es) toplevels) tts

toplevels :: Parser ([Def], [TypeDef], [Extern])
toplevels :: Parser ([ParsedImport], [ParsedDef], [ParsedTypeDef], [ParsedExtern])
toplevels = fmap mconcat (manyTill toplevel end)
  where
    toplevel = tryParens' $ \topPos -> choice
        [ fmap (\d -> ([d], [], [])) (def topPos)
        , fmap (\t -> ([], [t], [])) data_
        , fmap (\t -> ([], [t], [])) typealias
        , fmap (\e -> ([], [], [e])) extern
        [ fmap (\i -> ([i], [], [], [])) import_
        , fmap (\d -> ([], [d], [], [])) (def topPos)
        , fmap (\t -> ([], [], [t], [])) data_
        , fmap (\t -> ([], [], [t], [])) typealias
        , fmap (\e -> ([], [], [], [e])) extern
        ]

extern :: Parser Extern
import_ :: Parser ParsedImport
import_ = reserved Rimport *> option () (keyword "macros") *> smallName

extern :: Parser ParsedExtern
extern = do
    reserved Rextern
    x@(Id (WithPos pos x')) <- small
    x@(WithPos pos (Ident x')) <- smallLhs
    unless (c_validIdent x') $ scribe
        messages
        [ Warning


@@ 47,35 53,35 @@ extern = do
        ]
    Extern x <$> type_

data_ :: Parser TypeDef
data_ :: Parser ParsedTypeDef
data_ = do
    _ <- reserved Rdata
    let onlyName = fmap (, []) big
    let nameAndSome = parens . liftA2 (,) big . some
    (name, params) <- onlyName <|> nameAndSome small
    let onlyName = fmap (, []) bigLhs
    let nameAndSome = parens . liftA2 (,) bigLhs . some
    (name, params) <- onlyName <|> nameAndSome smallLhs
    constrs <- many (onlyName <|> nameAndSome type_)
    pure (TypeDef name params (ConstructorDefs constrs))

typealias :: Parser TypeDef
typealias :: Parser ParsedTypeDef
typealias = do
    _ <- reserved Rtype
    let onlyName = fmap (, []) big
    let nameAndSome = parens . liftA2 (,) big . some
    (name, params) <- onlyName <|> nameAndSome small
    let onlyName = fmap (, []) bigLhs
    let nameAndSome = parens . liftA2 (,) bigLhs . some
    (name, params) <- onlyName <|> nameAndSome smallLhs
    TypeAlias name params <$> type_

def :: SrcPos -> Parser Def
def :: SrcPos -> Parser ParsedDef
def topPos = (reserved Rdefun *> funDef) <|> (reserved Rdef *> varDef)
  where
    body inner = do
        ds <- many (tryParens' def)
        if null ds then expr else fmap (\b -> WithPos (getPos b) (LetRec ds b)) inner
    varDef = do
        name <- small
        name <- smallLhs
        scm <- option Nothing (fmap Just (keyword "of" *> scheme))
        VarDef topPos name scm <$> body expr
    funDef = do
        name <- small
        name <- smallLhs
        (<|>)
            (do
                params <- withPos $ brackets (some pat)


@@ 89,25 95,22 @@ def topPos = (reserved Rdefun *> funDef) <|> (reserved Rdef *> varDef)
                pure $ FunMatchDef topPos name scm cases'
            )

expr :: Parser Expr
expr :: Parser ParsedExpr
expr = withPos expr'

data BindingLhs
    = VarLhs (Id 'Small)
    | CaseVarLhs Pat
    = VarLhs LhsName
    | CaseVarLhs ParsedPat

expr' :: Parser Expr'
expr' :: Parser ParsedExpr'
expr' = choice [var, lit, eConstructor, etuple, deBruijnFun, deBruijnIndex, pexpr]
  where
    lit = token "constant literal" $ const $ \case
        Lit c -> Just (Parsed.Lit c)
        Lexd.Lit c -> Just (Lit c)
        _ -> Nothing
    eConstructor = fmap Ctor big
    -- FIXME: These positions are completely wack. Gotta get a separate variant in the AST for
    --        pairs. Similar to Box.
    etuple = fmap unpos $ tuple expr (\p -> WithPos p (Ctor (Id (WithPos p "Unit")))) $ \l r ->
        let p = getPos l in WithPos p (App (WithPos p (Ctor (Id (WithPos p "Cons")))) [l, r])
    var = fmap Var small
    eConstructor = fmap Ctor bigName
    etuple = ETuple <$> tuple expr
    var = fmap Var smallName
    pexpr = parens' $ \p -> choice [match, if', fun, let1 p, let', letrec, typeAscr, sizeof, app]
    match = reserved Rmatch
        *> liftA2 Match expr (many (parens (reserved Rcase *> liftA2 (,) pat expr)))


@@ 128,7 131,7 @@ expr' = choice [var, lit, eConstructor, etuple, deBruijnFun, deBruijnIndex, pexp
        scribe deBruijnIndices [i]
        pure (DeBruijnIndex i)
    index = token "integral index" $ const $ \case
        Lit (Int n) | n >= 0 -> Just (fromIntegral n :: Word)
        Lexd.Lit (Int n) | n >= 0 -> Just (fromIntegral n :: Word)
        _ -> Nothing
    let1 p = reserved Rlet1 *> (varLhs <|> caseVarLhs) >>= \case
        VarLhs lhs -> liftA2 (Let1 . Def) (varBinding p lhs) expr


@@ 148,7 151,7 @@ expr' = choice [var, lit, eConstructor, etuple, deBruijnFun, deBruijnIndex, pexp
        binding p = varLhs >>= \case
            VarLhs lhs -> varBinding p lhs
            CaseVarLhs _ -> ice "letrec binding: CaseVarLhs"
    varLhs = fmap VarLhs small
    varLhs = fmap VarLhs smallLhs
    caseVarLhs = fmap CaseVarLhs pat
    varBinding pos lhs = VarDef pos lhs Nothing <$> expr
    typeAscr = reserved Rcolon *> liftA2 TypeAscr expr type_


@@ 158,25 161,24 @@ expr' = choice [var, lit, eConstructor, etuple, deBruijnFun, deBruijnIndex, pexp
        rands <- some expr
        pure (App rator rands)

pat :: Parser Pat
pat :: Parser ParsedPat
pat = choice [patInt, patStr, patCtor, patVar, patTuple, ppat]
  where
    patInt = token "integer literal" $ \p -> \case
        Lit (Int x) -> Just (PInt p x)
        Lexd.Lit (Int x) -> Just (PInt p x)
        _ -> Nothing
    patStr = liftA2 PStr getSrcPos strlit
    strlit = token "string literal" $ const $ \case
        Lit (Str s) -> Just s
        Lexd.Lit (Str s) -> Just s
        _ -> Nothing
    patCtor = fmap (\x -> PConstruction (getPos x) x []) big
    patVar = fmap PVar small
    patTuple = tuple pat (\p -> PConstruction p (Id (WithPos p "Unit")) [])
        $ \l r -> let p = getPos l in PConstruction p (Id (WithPos p "Cons")) [l, r]
    patCtor = fmap (\x -> PConstruction (getPos x) x []) bigName
    patVar = fmap PVar smallLhs
    patTuple = liftA2 PTuple getSrcPos (tuple pat)
    ppat = parens' $ \pos -> choice [patBox pos, patCtion pos]
    patBox pos = reserved RBox *> fmap (PBox pos) pat
    patCtion pos = liftM3 PConstruction (pure pos) big (some pat)
    patCtion pos = liftM3 PConstruction (pure pos) bigName (some pat)

scheme :: Parser Scheme
scheme :: Parser ParsedScheme
scheme = do
    pos <- getSrcPos
    let wrap = fmap (Forall pos Set.empty Set.empty)


@@ 186,43 188,21 @@ scheme = do
        constrs = parens (reserved Rwhere *> fmap Set.fromList (some (parens tapp)))
    wrap nonptype <|> parens (universal <|> wrap ptype)

type_ :: Parser Type
type_ :: Parser ParsedType
type_ = nonptype <|> parens ptype

nonptype :: Parser Type
nonptype = choice [fmap TPrim tprim, fmap TVar tvar, fmap (TConst . (, []) . idstr) big, ttuple]
  where
    tprim = token "primitive type" $ const $ \case
        Lexd.Big ('N' : 'a' : 't' : s) | isWord s -> Just (TNat (read s))
        Lexd.Big ('I' : 'n' : 't' : s) | isWord s -> Just (TInt (read s))
        Lexd.Big "Nat" -> Just TNatSize
        Lexd.Big "Int" -> Just TIntSize
        Lexd.Big "F32" -> Just TF32
        Lexd.Big "F64" -> Just TF64
        _ -> Nothing
    ttuple = tuple type_ (const (TConst ("Unit", []))) $ \l r -> TConst ("Cons", [l, r])

-- | FIXME: Positions in here are kind of bad
tuple :: Parser a -> (SrcPos -> a) -> (a -> a -> a) -> Parser a
tuple p unit f = brackets $ do
    a <- p
    as <- many (try p)
    let ls = a : as
    pos <- gets stOuterPos
    r <- option (unit pos) (try (reserved Rdot *> p))
    pure $ foldr f r ls
nonptype :: Parser ParsedType
nonptype = choice [fmap PTVar tvar, fmap (`PTConst` []) bigName, PTTuple <$> tuple type_]

ptype :: Parser Type
ptype = choice [tfun, tbox, fmap (TConst . second (map snd)) tapp]
  where
    tfun = reserved RFun *> liftA2 TFun (brackets (some type_)) type_
    tbox = reserved RBox *> fmap TBox type_
ptype :: Parser ParsedType
ptype = tfun <|> fmap (uncurry PTConst . second (map snd)) tapp
    where tfun = reserved RFun *> liftA2 PTFun (brackets (some type_)) type_

tapp :: Parser (String, [(SrcPos, Type)])
tapp = liftA2 ((,) . idstr) big (some (liftA2 (,) getSrcPos type_))
tapp :: Parser (ParsedName, [(SrcPos, ParsedType)])
tapp = liftA2 (,) bigName (some (liftA2 (,) getSrcPos type_))

tvar :: Parser TVar
tvar = fmap TVExplicit small
tvar = fmap TVExplicit smallLhs

backslashed' :: (SrcPos -> Parser a) -> Parser a
backslashed' = sexpr "`\\`" $ \case


@@ 239,8 219,14 @@ octothorpe = token "`#`" $ const $ \case
    Octothorpe -> Just ()
    _ -> Nothing

isWord :: String -> Bool
isWord s = isJust (readMaybe s :: Maybe Word)
tuple :: Parser a -> Parser (Tuple a)
tuple p = brackets $ do
    as <- many1 (try p)
    b <- option Nothing (fmap Just (try (reserved Rdot *> p)))
    pure $ Tuple as b

-- isWord :: String -> Bool
-- isWord s = isJust (readMaybe s :: Maybe Word)

-- | Valid identifiers in the C language according to the C11 standard "ISO/IEC 9899:2011",
--   excluding "other implementation-defined characters".

M src/Front/Parsed.hs => src/Front/Parsed.hs +19 -166
@@ 1,170 1,23 @@
{-# LANGUAGE DataKinds #-}
module Front.Parsed where

module Front.Parsed (module Front.Parsed, Const (..), TPrim(..), Type' (..), TConst') where

import qualified Data.Set as Set
import Data.Set (Set)
import Control.Arrow ((>>>))
import Data.Bifunctor

import Front.SrcPos
import FreeVars
import Name
import SrcPos
import Front.Concrete
import Front.TypeAst
import Front.Lexd (Const (..))

data Message = Warning SrcPos String
    deriving Show

data IdCase = Big | Small

newtype Id (case' :: IdCase) = Id (WithPos String)
type ParsedName = WithPos LexdName
data ParsedType
    = PTVar TVar
    | PTConst ParsedName [ParsedType]
    | PTFun [ParsedType] ParsedType
    | PTTuple (Tuple ParsedType)
    deriving (Show, Eq, Ord)

data TVar
    = TVExplicit (Id 'Small)
    | TVImplicit String
    deriving (Show, Eq, Ord)

type TConst = TConst' TVar
type Type = Type' TVar

type ClassConstraint = (String, [(SrcPos, Type)])

data Scheme = Forall SrcPos (Set TVar) (Set ClassConstraint) Type
    deriving (Show, Eq)

data Pat
    = PConstruction SrcPos (Id 'Big) [Pat]
    | PInt SrcPos Int
    | PStr SrcPos String
    | PVar (Id 'Small)
    | PBox SrcPos Pat
    -- TODO: Add special pattern for Lazy
    deriving Show

type FunPats = WithPos [Pat]

data Expr'
    = Lit Const
    | Var (Id 'Small)
    | App Expr [Expr]
    | If Expr Expr Expr
    | Let1 DefLike Expr
    | Let [DefLike] Expr
    | LetRec [Def] Expr
    | TypeAscr Expr Type
    | Match Expr [(Pat, Expr)]
    | Fun FunPats Expr
    | DeBruijnFun Word Expr
    | DeBruijnIndex Word
    | FunMatch [(FunPats, Expr)]
    | Ctor (Id 'Big)
    | Sizeof Type
    deriving (Show, Eq)

type Expr = WithPos Expr'

data Def = VarDef SrcPos (Id 'Small) (Maybe Scheme) Expr
         | FunDef SrcPos (Id 'Small) (Maybe Scheme) FunPats Expr
         | FunMatchDef SrcPos (Id 'Small) (Maybe Scheme) [(FunPats, Expr)]
    deriving (Show, Eq)

data DefLike = Def Def | Deconstr Pat Expr
    deriving (Show, Eq)

newtype ConstructorDefs = ConstructorDefs [(Id 'Big, [Type])]
    deriving (Show, Eq)

data TypeDef
    = TypeDef (Id 'Big) [Id 'Small] ConstructorDefs
    | TypeAlias (Id 'Big) [Id 'Small] Type
    deriving (Show, Eq)

data Extern = Extern (Id 'Small) Type
    deriving (Show, Eq)

data Program = Program [Def] [TypeDef] [Extern]
    deriving (Show, Eq)


instance Eq Pat where
    (==) = curry $ \case
        (PConstruction _ x ps, PConstruction _ x' ps') -> x == x' && ps == ps'
        (PVar x, PVar x') -> x == x'
        _ -> False

instance FreeVars Def (Id 'Small) where
    freeVars = \case
        VarDef _ _ _ rhs -> freeVars rhs
        FunDef _ _ _ pats rhs ->
            Set.difference (freeVars rhs) (Set.unions (map bvPat (unpos pats)))
        FunMatchDef _ _ _ cs -> fvCases (map (first unpos) cs)

instance FreeVars DefLike (Id 'Small) where
    freeVars = \case
        Def d -> freeVars d
        Deconstr _ matchee -> freeVars matchee

instance FreeVars Expr (Id 'Small) where
    freeVars = fvExpr

instance HasPos (Id a) where
    getPos (Id x) = getPos x

instance HasPos Pat where
    getPos = \case
        PConstruction p _ _ -> p
        PInt p _ -> p
        PStr p _ -> p
        PVar v -> getPos v
        PBox p _ -> p


fvExpr :: Expr -> Set (Id 'Small)
fvExpr = unpos >>> fvExpr'
  where
    fvExpr' = \case
        Lit _ -> Set.empty
        Var x -> Set.singleton x
        App f as -> fvApp f as
        If p c a -> fvIf p c a
        Let1 b e -> Set.union (freeVars b) (Set.difference (freeVars e) (bvDefLike b))
        Let bs e -> foldr
            (\b fvs -> Set.union (freeVars b) (Set.difference fvs (bvDefLike b)))
            (freeVars e)
            bs
        LetRec ds e -> fvLet (unzip (map (\d -> (defLhs d, d)) ds)) e
        TypeAscr e _t -> freeVars e
        Match e cs -> fvMatch e cs
        Fun (WithPos _ pats) e -> Set.difference (freeVars e) (Set.unions (map bvPat pats))
        DeBruijnFun _ body -> freeVars body
        DeBruijnIndex _ -> Set.empty
        FunMatch cs -> fvCases (map (first unpos) cs)
        Ctor _ -> Set.empty
        Sizeof _t -> Set.empty
    bvDefLike = \case
        Def d -> Set.singleton (defLhs d)
        Deconstr pat _ -> bvPat pat

defLhs :: Def -> Id 'Small
defLhs = \case
    VarDef _ lhs _ _ -> lhs
    FunDef _ lhs _ _ _ -> lhs
    FunMatchDef _ lhs _ _ -> lhs

fvMatch :: Expr -> [(Pat, Expr)] -> Set (Id 'Small)
fvMatch e cs = Set.union (freeVars e) (fvCases (map (first pure) cs))

fvCases :: [([Pat], Expr)] -> Set (Id 'Small)
fvCases = Set.unions . map (\(ps, e) -> Set.difference (freeVars e) (Set.unions (map bvPat ps)))

bvPat :: Pat -> Set (Id 'Small)
bvPat = \case
    PConstruction _ _ ps -> Set.unions (map bvPat ps)
    PInt _ _ -> Set.empty
    PStr _ _ -> Set.empty
    PVar x -> Set.singleton x
    PBox _ p -> bvPat p

idstr :: Id a -> String
idstr (Id (WithPos _ x)) = x
type ParsedImport = ParsedName
type ParsedScheme = Scheme ParsedType ParsedName
type ParsedPat = Pat ParsedName
type ParsedExpr' = Expr' ParsedType ParsedName
type ParsedExpr = Expr ParsedType ParsedName
type ParsedDef = Def ParsedType ParsedName
type ParsedTypeDef = TypeDef ParsedType
type ParsedExtern = Extern ParsedType
type ParsedModule = Module ParsedType ParsedName

M src/Front/Parser.hs => src/Front/Parser.hs +38 -12
@@ 2,7 2,8 @@

module Front.Parser where

import Control.Applicative hiding (many, some)
import Control.Applicative hiding (some)
import Data.Char
import Control.Monad
import Control.Monad.State.Strict
import Control.Monad.Writer


@@ 15,9 16,11 @@ import qualified Data.Set as Set
import Lens.Micro.Platform (makeLenses, view)

import Misc
import Front.SrcPos
import Front.Lexd
import SrcPos
import Name
import Front.Concrete
import Front.Parsed
import Front.Lexd
import Pretty

data Err = Err


@@ 46,6 49,9 @@ data St = St
    }
    deriving Show

data Message = Warning SrcPos String
    deriving Show

data Out = Out
    { _deBruijnIndices :: [Word]
    , _messages :: [Message]


@@ 174,19 180,36 @@ trySexpr expected extract f = do
    modify (\st -> st { stOuterPos = pOld, stInput = ttsOld })
    pure a

big :: Parser (Id 'Front.Parsed.Big)
big = token "big identifier" $ \p -> \case
    Front.Lexd.Big x -> Just (Id (WithPos p x))
-- small' :: Parser String
-- small' = fmap (idstr . unpos) small

bigLhs :: Parser LhsName
bigLhs = token "big identifier" $ \p -> \case
    Front.Lexd.LIdent x@(Ident (c : _)) | isUpper c -> Just (WithPos p x)
    _ -> Nothing

smallLhs :: Parser LhsName
smallLhs = token "small identifier" $ \p -> \case
    Front.Lexd.LIdent x@(Ident (c : _)) | not (isUpper c) -> Just (WithPos p x)
    _ -> Nothing

bigName :: Parser ParsedName
bigName = bigName' <|> fmap (mapPosd (`LIn` LRelative)) bigLhs

bigName' :: Parser ParsedName
bigName' = token "big name" $ \p -> \case
    Front.Lexd.LName x@(Ident (c : _) `LIn` _) | isUpper c -> Just (WithPos p x)
    _ -> Nothing

small' :: Parser String
small' = fmap idstr small
smallName :: Parser ParsedName
smallName = smallName' <|> fmap (mapPosd (`LIn` LRelative)) smallLhs

small :: Parser (Id 'Front.Parsed.Small)
small = token "small identifier" $ \p -> \case
    Front.Lexd.Small x -> Just (Id (WithPos p x))
smallName' :: Parser ParsedName
smallName' = token "small name" $ \p -> \case
    Front.Lexd.LName x@(Ident (c : _) `LIn` _) | not (isUpper c) -> Just (WithPos p x)
    _ -> Nothing


reserved :: Reserved -> Parser ()
reserved k = token (pretty k) $ const $ \case
    Reserved k' | k == k' -> Just ()


@@ 194,5 217,8 @@ reserved k = token (pretty k) $ const $ \case

keyword :: String -> Parser ()
keyword k = token k $ const $ \case
    Keyword k' | k == k' -> Just ()
    Keyword (Ident k') | k == k' -> Just ()
    _ -> Nothing

many1 :: Parser a -> Parser [a]
many1 p = liftM2 (:) p (many p)

A src/Front/Resolve.hs => src/Front/Resolve.hs +1 -0
@@ 0,0 1,1 @@
module Front.Resolve () where

A src/Front/Resolved.hs => src/Front/Resolved.hs +15 -0
@@ 0,0 1,15 @@
module Front.Resolved where

import Name
import SrcPos
import Front.Concrete
import Front.TypeAst

type ResolvedName = WithPos QualName
type ResolvedTConst = TConst' TVar
type ResolvedType = Type' TVar
type ResolvedPat = Pat ResolvedName
type ResolvedScheme = Scheme ResolvedType ResolvedName
type ResolvedExpr = Expr ResolvedType ResolvedName
type ResolvedDef = Def ResolvedType ResolvedName
type ResolvedModule = Module ResolvedType ResolvedName

M src/Front/Subst.hs => src/Front/Subst.hs +13 -18
@@ 5,61 5,56 @@ import Data.Map (Map)
import Data.Bifunctor
import Data.Maybe

import Front.SrcPos
import SrcPos
import Front.TypeAst
import Front.Abstract
import Front.Inferred

-- | Map of substitutions from type-variables to more specific types
type Subst = TVar -> Maybe Type
type Subst' = Map TVar Type

substDef :: Subst -> Def -> Def
substDef :: Subst -> InferredDef -> InferredDef
substDef s = \case
    VarDef d -> VarDef (second (second (substExpr' s)) d)
    RecDefs ds -> RecDefs (map (second (second (substFun' s))) ds)

substExpr :: Map TVar Type -> Expr -> Expr
substExpr :: Map TVar Type -> InferredExpr -> InferredExpr
substExpr s = substExpr' (flip Map.lookup s)

substExpr' :: Subst -> Expr -> Expr
substExpr' :: Subst -> InferredExpr -> InferredExpr
substExpr' s expr = case expr of
    Lit c -> Lit c
    Var v -> Var (second (substTypedVar s) v)
    EVar x t -> EVar x (subst' s t)
    App f as rt -> App (substExpr' s f) (map (substExpr' s) as) (subst' s rt)
    If p c a -> If (substExpr' s p) (substExpr' s c) (substExpr' s a)
    Let def body -> Let (substDef s def) (substExpr' s body)
    Fun f -> Fun (substFun' s f)
    Match m -> Match (substMatch' s m)
    Match ms (MatchCases cs tps tb) -> Match
        (map (substExpr' s) ms)
        (MatchCases (substCases s cs) (map (subst' s) tps) (subst' s tb))
    Ctor i span' (tx, tts) ps -> Ctor i span' (tx, map (subst' s) tts) (map (subst' s) ps)
    Sizeof t -> Sizeof (subst' s t)

substFun :: Map TVar Type -> Fun -> Fun
substFun :: Map TVar Type -> InferredFun -> InferredFun
substFun s = substFun' (flip Map.lookup s)

substFun' :: Subst -> Fun -> Fun
substFun' :: Subst -> InferredFun -> InferredFun
substFun' s (ps, b) = (map (second (subst' s)) ps, bimap (substExpr' s) (subst' s) b)

substMatch' :: Subst -> Match -> Match
substMatch' s = mapPosd
    (\(ms, cs, tps, tb) ->
        (map (substExpr' s) ms, substCases s cs, map (subst' s) tps, subst' s tb)
    )

substCases :: Subst -> Cases -> Cases
substCases s = map (bimap (mapPosd (map (substPat s))) (substExpr' s))

substPat :: Subst -> Pat -> Pat
substPat s (Pat pos t pat) = Pat pos (subst' s t) $ case pat of
    PWild -> PWild
    PVar v -> PVar (substTypedVar s v)
    PVar x t -> PVar x (subst' s t)
    PBox p -> PBox (substPat s p)
    PCon c ps -> PCon (substCon s c) (map (substPat s) ps)

substCon :: Subst -> Con -> Con
substCon s (Con ix sp ts) = Con ix sp (map (subst' s) ts)

substTypedVar :: Subst -> TypedVar -> TypedVar
substTypedVar s (TypedVar x t) = TypedVar x (subst' s t)

subst :: Map TVar Type -> Type -> Type
subst s = subst' (flip Map.lookup s)


M src/Front/TypeAst.hs => src/Front/TypeAst.hs +33 -25
@@ 5,6 5,9 @@ module Front.TypeAst where

import Data.Word

import Name
import SrcPos

data TPrim
    = TNat Word32
    | TNatSize


@@ 14,7 17,12 @@ data TPrim
    | TF64
    deriving (Show, Eq, Ord)

type TConst' var = (String, [Type' var])
type TConst' var = (QualName, [Type' var])

data TVar
    = TVExplicit (WithPos Ident)
    | TVImplicit String
    deriving (Show, Eq, Ord)

data Type' var
    = TVar var


@@ 31,29 39,29 @@ unTconst _ = Nothing
mainType :: Type' var
mainType = TFun [tUnit] tUnit

tByte :: Type' var
tByte = TPrim (TNat 8)

tBox' :: Type' var -> TConst' var
tBox' t = ("Box", [t])

tStr :: Type' var
tStr = TConst tStr'

tStr' :: TConst' var
tStr' = ("Str", [])

tArray :: Type' var -> Type' var
tArray a = TConst ("Array", [a])

--tByte :: Type' var
--tByte = TPrim (TNat 8)
--
--tBox' :: Type' var -> TConst' var
--tBox' t = ("Box", [t])
--
--tStr :: Type' var
--tStr = TConst tStr'
--
--tStr' :: TConst' var
--tStr' = ("Str", [])
--
--tArray :: Type' var -> Type' var
--tArray a = TConst ("Array", [a])
--
tUnit :: Type' var
tUnit = TConst tUnit'

tUnit' :: TConst' var
tUnit' = ("Unit", [])

tUnit = TConst (QName "Unit" QMBuiltin, [])
--
-- tUnit' :: TConst' var
--tUnit' =
--
tBool :: Type' var
tBool = TConst ("Bool", [])

tBool' :: TConst' var
tBool' = ("Bool", [])
tBool = TConst (QName "Bool" QMBuiltin, [])
--
--tBool' :: TConst' var
--tBool' = ("Bool", [])

A src/Name.hs => src/Name.hs +30 -0
@@ 0,0 1,30 @@
module Name where

import Data.String

newtype Ident = Ident {idstr :: String }
    deriving (Show, Eq, Ord)

data LexdName
    = LIn Ident LexdName
    | LAbsolute
    | LRelative
    deriving (Show, Eq, Ord)

-- | A fully qualified module name, e.g. `/std/collections/hashmap`
--
--   Stored in "reversed" order. So "/std/collections/hashmap" becomes
--   `(QMName "hashmap" (QMName "collections" (QMPkg "std")))`.
data QualModuleName
    = QMName Ident QualModuleName
    | QMPkg Ident
    | QMSelf
    | QMBuiltin
    | QMLocal
    deriving (Show, Eq, Ord)

data QualName = QName Ident QualModuleName
    deriving (Show, Eq, Ord)

instance IsString Ident where
    fromString = Ident

M src/Pretty.hs => src/Pretty.hs +51 -34
@@ 1,6 1,6 @@
{-# LANGUAGE UndecidableInstances #-}

module Pretty (pretty, Pretty(..), prettyTConst) where
module Pretty (pretty, prettyTConst, Pretty(..)) where

import Prelude hiding (showChar)
import Data.Bifunctor


@@ 13,10 13,12 @@ import qualified Prettyprinter as Prettyprint

import Misc
import Front.TypeAst
import Front.SrcPos
import SrcPos
import Name
import qualified Front.Lexd as Lexd
import qualified Front.Parsed as Parsed
import qualified Front.Inferred as Inferred
import qualified Front.Abstract as Ast
import qualified Front.Concrete as Cst
import Front.Resolved


-- Pretty print starting at some indentation depth


@@ 57,23 59,38 @@ instance Pretty Lexd.Reserved where
        Lexd.Rdefmacro -> "defmacro"
        Lexd.Rtype -> "type"

instance Pretty var => Pretty (Type' var) where
instance (Pretty var) => Pretty (Type' var) where
    pretty' _ = prettyType
instance Pretty TPrim where
    pretty' _ = prettyTPrim

instance Pretty Parsed.Scheme where
    pretty' _ (Parsed.Forall _ ps cs t) =
        prettyScheme ps (map (second (map snd)) (Set.toList cs)) t
instance Pretty Parsed.TVar where
instance Pretty ResolvedScheme where
    pretty' _ (Cst.Forall _ ps cs t) =
        prettyScheme ps (map (bimap unpos (map snd)) (Set.toList cs)) t
instance Pretty Ast.Scheme where
    pretty' _ (Ast.Forall ps cs t) = prettyScheme ps (Set.toList cs) t
instance Pretty TVar where
    pretty' _ = prettyTVar
instance Pretty (Parsed.Id a) where
    pretty' _ = Parsed.idstr
instance Pretty Ident where
    pretty' _ = idstr
instance Pretty QualName where
    pretty' _ = prettyQName

instance Pretty Void where
    pretty' _ = absurd

prettyType :: Pretty var => Type' var -> String
prettyQName :: QualName -> String
prettyQName (QName x m) = prettyQMName m ++ "/" ++ pretty x

prettyQMName :: QualModuleName -> String
prettyQMName = \case
    QMBuiltin -> "/builtin"
    QMSelf -> "/self"
    QMPkg p -> "/" ++ pretty p
    QMLocal -> ""
    QMName x m -> prettyQMName m ++ "/" ++ pretty x

prettyType :: (Pretty var) => Type' var -> String
prettyType = \case
    TVar tv -> pretty tv
    TPrim c -> pretty c


@@ 81,48 98,48 @@ prettyType = \case
    TBox t -> prettyTBox t
    TConst tc -> prettyTConst tc

prettyScheme :: (Pretty p, Pretty var) => Set p -> [(String, [Type' var])] -> Type' var -> String
prettyScheme :: (Pretty var) => Set TVar -> [(QualName, [Type' var])] -> Type' var -> String
prettyScheme ps cs t = concat
    [ "(forall (" ++ spcPretty (Set.toList ps) ++ ") "
    , "(where " ++ unwords (map prettyTConst cs) ++ ") "
    , pretty t ++ ")"
    ]

prettyTConst :: (Pretty var) => (String, [Type' var]) -> String
prettyTConst :: (Pretty var) => (QualName, [Type' var]) -> String
prettyTConst = \case
    ("Cons", [t1, t2]) -> "[" ++ pretty t1 ++ prettyConses t2
    ("Cons", []) -> ice "prettyTConst: Cons hasn't two types"
    (c, []) -> c
    (c, ts) -> concat ["(", c, " ", spcPretty ts, ")"]
    (QName "Cons" QMBuiltin, [t1, t2]) -> "[" ++ pretty t1 ++ prettyConses t2
    (QName "Cons" QMBuiltin, []) -> ice "prettyTConst: Cons hasn't two types"
    (c, []) -> pretty c
    (c, ts) -> concat ["(", pretty c, " ", spcPretty ts, ")"]
  where
    prettyConses t = case unTconst t of
        Just ("Cons", [t1, t2]) -> " " ++ pretty t1 ++ prettyConses t2
        Just ("Unit", _) -> "]"
        Just (QName "Cons" QMBuiltin, [t1, t2]) -> " " ++ pretty t1 ++ prettyConses t2
        Just (QName "Unit" QMBuiltin, _) -> "]"
        _ -> " . " ++ pretty t ++ "]"


prettyTBox :: Pretty t => t -> String
prettyTBox t = "(Box " ++ pretty t ++ ")"

prettyTFun :: Pretty var => [Type' var] -> Type' var -> String
prettyTFun :: (Pretty t) => [t] -> t -> String
prettyTFun as b = concat ["(Fun [", spcPretty as, "] ", pretty b, ")"]

prettyTPrim :: Parsed.TPrim -> String
prettyTPrim :: Cst.TPrim -> String
prettyTPrim = \case
    Parsed.TNat w -> "Nat" ++ show w
    Parsed.TNatSize -> "Nat"
    Parsed.TInt w -> "Int" ++ show w
    Parsed.TIntSize -> "Int"
    Parsed.TF32 -> "F32"
    Parsed.TF64 -> "F64"

prettyTVar :: Parsed.TVar -> String
    Cst.TNat w -> "Nat" ++ show w
    Cst.TNatSize -> "Nat"
    Cst.TInt w -> "Int" ++ show w
    Cst.TIntSize -> "Int"
    Cst.TF32 -> "F32"
    Cst.TF64 -> "F64"

prettyTVar :: TVar -> String
prettyTVar = \case
    Parsed.TVExplicit v -> Parsed.idstr v
    Parsed.TVImplicit v -> "•" ++ v
    TVExplicit v -> idstr (unpos v)
    TVImplicit v -> "•" ++ v

instance Pretty Inferred.Scheme where
    pretty' _ (Inferred.Forall ps cs t) = prettyScheme ps (Set.toList cs) t
-- instance Pretty Inferred.Scheme where
--     pretty' _ (Inferred.Forall ps cs t) = prettyScheme ps (Set.toList cs) t

instance Pretty Module where
    pretty' _ = show . Prettyprint.pretty

A src/Query.hs => src/Query.hs +15 -0
@@ 0,0 1,15 @@
module Query where

import Data.Map (Map)

import Name
import qualified Front.Lexd as Lexd
import Front.Parsed
import Front.Resolved

data Cache = Cache
    { lexdCache :: Map QualModuleName Lexd.Module
    , expandedCache :: Map QualModuleName Lexd.Module
    , parsedCache :: Map QualModuleName ParsedModule
    , resolvedCache :: Map QualModuleName ResolvedModule
    }

M src/Sizeof.hs => src/Sizeof.hs +2 -1
@@ 7,7 7,8 @@ import qualified Data.Vector as Vec
import Data.Vector (Vector)

import Misc
import Front.Inferred (Span)
import Front.Abstract
-- import Front.Inferred (Span)
import Front.TypeAst

type SizeofConst tvar = TConst' tvar -> Either tvar Word

R src/Front/SrcPos.hs => src/SrcPos.hs +1 -1
@@ 1,4 1,4 @@
module Front.SrcPos where
module SrcPos where

import Text.Megaparsec.Pos