~jack/misc

792d1a6fb90d4bd33438b6e5e72a62737f8589d3 — Jack Kelly 18 days ago 9e971cb master
stg: implement the eval/apply machine
A stg/src/Language/STG/EAMachine.hs => stg/src/Language/STG/EAMachine.hs +232 -0
@@ 0,0 1,232 @@
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}

module Language.STG.EAMachine where

import Control.Lens hiding (uncons)
import Language.STG.EvalApplyAST
import Relude hiding (Alt)
import qualified Text.Show as S

data S = S
  { expr :: Expr,
    stack :: [Cont],
    heap :: Heap,
    gensyms :: Gensyms
  }
  deriving (Eq, Show)

type Heap = Map Var Object

data Cont
  = CCase Alts
  | CUpdate Var
  | CApply [Atom]
  deriving (Eq, Show)

newtype Gensyms = Gensyms [Var]

instance Show Gensyms where
  show _ = "Gensyms"

instance Eq Gensyms where
  Gensyms xs == Gensyms ys = viaNonEmpty head xs == viaNonEmpty head ys

$(makePrisms ''Cont)

gNext :: Gensyms -> (Gensyms, Var)
gNext (Gensyms (v : vs)) = (Gensyms vs, v)
gNext _ = error "gNext: impossible"

start :: Prog -> S
start p =
  S
    { expr = EAtom $ V "main",
      stack = [],
      heap = p,
      gensyms = Gensyms $ ("_g" <>) . show <$> ([0 ..] :: [Int])
    }

step :: S -> S
step s = case catMaybes $ sequenceA rules s of
  [s'] -> s'
  ss ->
    error $
      "Should only have one matching rule (got: " <> show (length ss) <> ")"
  where
    rules =
      [ rLet,
        rCaseConAny,
        rCase,
        rRet,
        rThunk,
        rUpdate,
        rKnownCall,
        rPrimOp,
        rExact,
        rCallKPap2,
        rTcall,
        rPcall,
        rRetFun
      ]

rCase :: S -> Maybe S
rCase s@S {..} = do
  (e, alts) <- expr ^? _ECase
  guard $ case e of
    EAtom a -> not $ case a of
      L {} -> True
      V v -> heap ^? ix v . to objectIsValue == Just True
    _ -> True
  Just $
    s
      { expr = e,
        stack = CCase alts : stack
      }

rCaseConAny :: S -> Maybe S
rCaseConAny s@S {..} = do
  (EAtom a, alts) <- expr ^? _ECase
  guard $ case a of
    L {} -> True
    V v -> heap ^? ix v . to objectIsValue == Just True
  case (a, alts) of
    (L {}, PrimAlt (Default x e)) ->
      Just $ s {expr = stgSub x a e}
    (V var, ConAlts cas def) ->
      do
        val <- heap ^? ix var
        case val of
          Con c args ->
            let loop :: [ConAlt] -> Maybe S
                loop [] =
                  def <&> \(Default v body) ->
                    s {expr = stgSub v a body}
                loop (ConPat con vars body : cs)
                  | c == con =
                    Just $
                      s
                        { expr = foldl' (flip $ uncurry stgSub) body $ zip vars args
                        }
                  | otherwise = loop cs
             in loop cas
          _ -> def <&> \(Default v body) -> s {expr = stgSub v a body}
    _ -> Nothing

rRet :: S -> Maybe S
rRet s@S {..} = do
  a <- expr ^? _EAtom
  guard $ case a of
    L {} -> True
    V v -> heap ^? ix v . to objectIsValue == Just True
  (CCase alts, s') <- uncons stack
  Just $
    s
      { expr = ECase (EAtom a) alts,
        stack = s'
      }

rLet :: S -> Maybe S
rLet s@S {..} = do
  (x, o, e) <- expr ^? _ELet
  let (g', x') = gNext gensyms
  Just $
    s
      { expr = stgSub x (V x') e,
        heap = heap & at x' ?~ o,
        gensyms = g'
      }

rThunk :: S -> Maybe S
rThunk s@S {..} = do
  v <- expr ^? _EAtom . _V
  e <- heap ^? ix v . _Thunk
  Just $
    s
      { expr = e,
        stack = CUpdate v : stack,
        heap = heap & at v ?~ Blackhole
      }

rUpdate :: S -> Maybe S
rUpdate s@S {..} = do
  V y <- expr ^? _EAtom
  (CUpdate x, s') <- uncons stack
  hy <- heap ^? ix y
  guard $ objectIsValue hy
  Just $ s {stack = s', heap = heap & at x ?~ hy}

rKnownCall :: S -> Maybe S
rKnownCall s@S {..} = do
  (f, Known n, args) <- expr ^? _EFunCall
  guard $ n == length args
  (vars, body) <- heap ^? ix f . _Fun
  Just $ s {expr = foldl' (flip $ uncurry stgSub) body $ zip vars args}

rPrimOp :: S -> Maybe S
rPrimOp s@S {..} = do
  (p, args) <- expr ^? _EPrimOp
  case (p, args) of
    (IAdd, [L (I x), L (I y)]) -> Just $ s {expr = EAtom . L . I $ x + y}
    (ISub, [L (I x), L (I y)]) -> Just $ s {expr = EAtom . L . I $ x - y}
    (IMul, [L (I x), L (I y)]) -> Just $ s {expr = EAtom . L . I $ x * y}
    (DAdd, [L (D x), L (D y)]) -> Just $ s {expr = EAtom . L . D $ x + y}
    (DSub, [L (D x), L (D y)]) -> Just $ s {expr = EAtom . L . D $ x - y}
    (DMul, [L (D x), L (D y)]) -> Just $ s {expr = EAtom . L . D $ x * y}
    (DDiv, [L (D x), L (D y)]) -> Just $ s {expr = EAtom . L . D $ x / y}
    _ -> Nothing

rExact :: S -> Maybe S
rExact s@S {..} = do
  (f, Unknown, args) <- expr ^? _EFunCall
  (vars, body) <- heap ^? ix f . _Fun
  guard $ length vars == length args
  Just $ s {expr = foldl' (flip $ uncurry stgSub) body $ zip vars args}

rCallKPap2 :: S -> Maybe S
rCallKPap2 s@S {..} = do
  (f, Known _, args) <- expr ^? _EFunCall
  (vars, body) <- heap ^? ix f . _Fun
  case compare (length args) (length vars) of
    EQ -> Nothing
    LT ->
      let (g', p) = gNext gensyms
       in Just $
            s
              { expr = EAtom (V p),
                heap = heap & at p ?~ Pap f args,
                gensyms = g'
              }
    GT ->
      let (applied, leftover) = splitAt (length vars) args
       in Just $
            s
              { expr = foldl' (flip $ uncurry stgSub) body $ zip vars applied,
                stack = CApply leftover : stack
              }

rTcall :: S -> Maybe S
rTcall s@S {..} = do
  (f, Unknown, args) <- expr ^? _EFunCall
  e <- heap ^? ix f . _Thunk
  Just $ s {expr = e, stack = CApply args : stack}

rPcall :: S -> Maybe S
rPcall s@S {..} = do
  (f, Unknown, args) <- expr ^? _EFunCall
  (g, rest) <- heap ^? ix f . _Pap
  Just $ s {expr = EFunCall g Unknown $ args ++ rest}

rRetFun :: S -> Maybe S
rRetFun s@S {..} = do
  V f <- expr ^? _EAtom
  hf <- heap ^? ix f
  guard $ case hf of
    Fun {} -> True
    Pap {} -> True
    _ -> False
  (CApply rest, s') <- uncons stack
  Just $ s {expr = EFunCall f Unknown rest, stack = s'}

M stg/src/Language/STG/Eval.hs => stg/src/Language/STG/Eval.hs +1 -1
@@ 64,7 64,7 @@ initialState (Prog (Binds m)) =
      sHeap = heap,
      sTopOfHeap = topOfHeap,
      sGlobalEnv = globals,
      sGenSyms = Gensyms $ Var . ("_g" <>) . show <$> [0 ..]
      sGenSyms = Gensyms $ Var . ("_g" <>) . show <$> ([0 ..] :: [Int])
    }
  where
    (heap, topOfHeap, globals) =

A stg/src/Language/STG/EvalApplyAST.hs => stg/src/Language/STG/EvalApplyAST.hs +173 -0
@@ 0,0 1,173 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}

-- | AST for Eval/Apply variant of STG, in the style of
-- https://simonmar.github.io/bib/papers/eval-apply.pdf
module Language.STG.EvalApplyAST where

import Control.Lens.TH
import Relude hiding (Alt)

type Prog = Binds

type Binds = Map Var Object

type Con = Text

type Var = Text

data Literal = I Int | D Double deriving (Eq, Show)

data Atom = L Literal | V Var deriving (Eq, Show)

data Arity = Unknown | Known Int deriving (Eq, Show)

data Alts
  = ConAlts [ConAlt] (Maybe DefAlt)
  | PrimAlt DefAlt
  deriving (Eq, Show)

data ConAlt = ConPat Con [Var] Expr deriving (Eq, Show)

data DefAlt = Default Var Expr deriving (Eq, Show)

data PrimOp = IAdd | ISub | IMul | DAdd | DSub | DMul | DDiv
  deriving (Eq, Show)

data Object
  = Fun [Var] Expr
  | Pap Var [Atom]
  | Con Con [Atom]
  | Thunk Expr
  | Blackhole
  deriving (Eq, Show)

data Expr
  = EAtom Atom
  | EFunCall Var Arity [Atom]
  | EPrimOp PrimOp [Atom]
  | ELet Var Object Expr
  | ECase Expr Alts
  deriving (Eq, Show)

$(makePrisms ''Atom)
$(makePrisms ''Object)
$(makePrisms ''Expr)

stgSub :: Var -> Atom -> Expr -> Expr
stgSub v a e =
  let subArgs :: [Atom] -> [Atom]
      subArgs = map $ \arg -> case arg of
        V v' | v == v' -> a
        _ -> arg
   in case e of
        EAtom (V v') | v == v' -> EAtom a
        EAtom _ -> e
        EFunCall f ar args -> EFunCall f' ar $ subArgs args
          where
            f' =
              if v == f
                then case a of
                  V fv -> fv
                  L _ -> error "cannot sub literal over a function call"
                else f
        EPrimOp p args -> EPrimOp p $ subArgs args
        ELet x o letBody
          | x == v -> e
          | otherwise -> ELet x o' $ stgSub v a letBody
          where
            o' = case o of
              Fun args body
                | v `elem` args -> o
                | otherwise -> Fun args $ stgSub v a body
              Pap {} -> o
              Con c args -> Con c $ subArgs args
              Thunk expr -> Thunk $ stgSub v a expr
              Blackhole -> Blackhole
        ECase c alts -> ECase (stgSub v a c) alts'
          where
            alts' = case alts of
              ConAlts cas def ->
                let mapConAlts = map $ \alt -> case alt of
                      ConPat con args body
                        | v `elem` args -> alt
                        | otherwise -> ConPat con args $ stgSub v a body
                 in ConAlts (mapConAlts cas) $ mapDefAlt <$> def
              PrimAlt def -> PrimAlt $ mapDefAlt def

            mapDefAlt d@(Default v' body)
              | v == v' = d
              | otherwise = Default v' $ stgSub v a body

objectIsValue :: Object -> Bool
objectIsValue = \case
  Fun {} -> True
  Pap {} -> True
  Con {} -> True
  Thunk {} -> False
  Blackhole -> False

stgProg :: Prog
stgProg =
  [ ("intMul", stgIntMul),
    ("main", stgMain),
    ("map", stgMap),
    ("nil", stgNil),
    ("one", Con "I" [L (I 1)]),
    ("two", Con "I" [L (I 2)]),
    ("three", Con "I" [L (I 3)])
  ]

stgIntMul :: Object
stgIntMul =
  Fun ["x", "y"]
    . ECase (EAtom $ V "x")
    $ ConAlts
      [ ConPat "I" ["x#"]
          . ECase (EAtom $ V "y")
          $ ConAlts
            [ ConPat "I" ["y#"]
                . ECase (EPrimOp IMul [V "x#", V "y#"])
                . PrimAlt
                . Default "r#"
                . ELet "r" (Con "I" [V "r#"])
                . EAtom
                $ V "r"
            ]
            Nothing
      ]
      Nothing

stgMain :: Object
stgMain =
  Thunk $
    ELet
      "double"
      (Thunk $ EFunCall "intMul" (Known 2) [V "two"])
      . ELet "t3" (Con "Nil" [])
      . ELet "t2" (Con "Cons" [V "three", V "t3"])
      . ELet "t1" (Con "Cons" [V "two", V "t2"])
      . ELet "list" (Con "Cons" [V "one", V "t1"])
      $ EFunCall "map" (Known 2) [V "double", V "list"]

stgMap :: Object
stgMap =
  Fun ["f", "xs"] $
    ECase
      (EAtom (V "xs"))
      ( ConAlts
          [ ConPat "Nil" [] . EAtom $ V "nil",
            ConPat "Cons" ["y", "ys"]
              . ELet "h" (Thunk $ EFunCall "f" Unknown [V "y"])
              . ELet "t" (Thunk $ EFunCall "map" (Known 2) [V "f", V "ys"])
              . ELet "r" (Con "Cons" [V "h", V "t"])
              . EAtom
              $ V "r"
          ]
          Nothing
      )

stgNil :: Object
stgNil = Con "Nil" []

M stg/stg.cabal => stg/stg.cabal +10 -4
@@ 16,15 16,21 @@ extra-source-files: CHANGELOG.md

library
  build-depends:
    , base        ^>=4.13
    , containers  ^>=0.6.2.1
    , pretty-show ^>=1.10
    , relude      ^>=0.6.0.0
    , base         ^>=4.13
    , bytestring   ^>=0.10.10.0
    , containers   ^>=0.6.2.1
    , language-c   ^>=0.8.3
    , lens         ^>=4.18.1
    , pretty-show  ^>=1.10
    , relude       ^>=0.6.0.0

  ghc-options:        -Wall
  hs-source-dirs:     src
  default-extensions: NoImplicitPrelude
  default-language:   Haskell2010
  exposed-modules:
    Language.STG
    Language.STG.AST
    Language.STG.EAMachine
    Language.STG.Eval
    Language.STG.EvalApplyAST

M stg/stg.nix => stg/stg.nix +6 -2
@@ 1,9 1,13 @@
{ mkDerivation, base, containers, pretty-show, relude, stdenv }:
{ mkDerivation, base, bytestring, containers, language-c, lens
, pretty-show, relude, stdenv
}:
mkDerivation {
  pname = "stg";
  version = "0.1.0.0";
  src = ./.;
  libraryHaskellDepends = [ base containers pretty-show relude ];
  libraryHaskellDepends = [
    base bytestring containers language-c lens pretty-show relude
  ];
  description = "Exploring the STG";
  license = stdenv.lib.licenses.agpl3Plus;
}