{-# LANGUAGE DerivingVia #-}

{- | Simple Inlining Optimization Pass

Performs preinline and postinline unconditionally.
TODO: Callsite Inline
-}
module IR.Simplify (
  simplifyProgram,
) where

import qualified Common.Compiler as Compiler
import Control.Monad.Except (MonadError (..))
import Control.Monad.State.Lazy (
  MonadState,
  StateT (..),
  evalStateT,
  gets,
  modify,
 )
import Data.Bifunctor (second)
import Data.Generics (Typeable)
import qualified Data.Map as M
import qualified Data.Maybe as Ma
import IR.IR (unfoldLambda)
import qualified IR.IR as I


-- | A variable as it appears in the simplifier's input.
type InVar = I.VarId

-- Reserved for callsite inlining (not yet needed):
-- type OutVar = I.VarId

-- | An expression as it appears in the simplifier's input (not yet simplified).
type InExpr = I.Expr I.Type

-- | An expression produced by the simplifier.
type OutExpr = I.Expr I.Type

-- | In-scope set; currently a placeholder string that 'simplExpr' only passes along.
type InScopeSet = String

-- | Simplification context; currently a placeholder string that 'simplExpr' only passes along.
type Context = String


-- | A substitution: input variables mapped to what should replace them.
type Subst = M.Map InVar SubstRng

-- | The right-hand side of a substitution entry.
data SubstRng
  = -- | An expression of the simplifier's output type ('OutExpr').
    DoneEx OutExpr
  | -- | An input expression ('InExpr') paired with the substitution to
    -- simplify it under when it is eventually used.
    SuspEx InExpr Subst
  deriving (Typeable, Show)


{- | Occurrence information recorded for each binding.

Only some categories are assigned by the helpers 'addOccVar' and
'updateOccVar' ('Dead', 'OnceSafe', 'Never', 'ConstructorFunc'); the
others describe the full inliner design and are not produced by those helpers.
-}
data OccInfo
  = -- | The binder does not appear at all.
    Dead
  | -- | Chosen to break the dependency between mutually recursive
    -- definitions (TBD).
    LoopBreaker
  | -- | Appears exactly once, NOT inside a lambda.
    OnceSafe
  | -- | Occurs at most once in each of several distinct case branches, and
    -- none of these occurrences is inside a lambda.
    MultiSafe
  | -- | Occurs exactly once, but inside a lambda.
    OnceUnsafe
  | -- | May occur many times, including inside lambdas. Variables exported
    -- from the module are also marked this way.
    MultiUnsafe
  | -- | Never inline; used to develop the inliner incrementally.
    Never
  | -- | Assigned by 'addOccVar' when a binder is registered a second time.
    -- NOTE(review): presumably marks constructor functions — confirm.
    ConstructorFunc
  deriving (Show, Typeable)


-- | Simplifier environment: the state threaded through 'SimplFn'.
data SimplEnv = SimplEnv
  { occInfo :: M.Map I.VarId OccInfo
  -- ^ 'occInfo' maps an identifier to its occurrence category
  , subst :: M.Map InVar SubstRng
  -- ^ 'subst' maps an identifier to its substitution
  , runs :: Int
  -- ^ 'runs' is meant to count how many times the simplifier has run so far
  -- (NOTE(review): initialised to 0 by 'runSimplFn'; no update is visible
  -- in the reviewed code — confirm it is maintained)
  , countLambda :: Int
  -- ^ 'countLambda': how many lambdas the occurrence analyzer is inside
  , countMatch :: Int
  -- ^ 'countMatch': how many matches the occurrence analyzer is inside
  }
  deriving (Show)
  deriving (Typeable)


{- | Simplifier monad: 'SimplEnv' state layered over 'Compiler.Pass'.

Every instance is borrowed from the underlying 'StateT' transformer.
-}
newtype SimplFn a = SimplFn (StateT SimplEnv Compiler.Pass a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFail
    , MonadError Compiler.Error
    , MonadState SimplEnv
    )
    via (StateT SimplEnv Compiler.Pass)
  deriving (Typeable)


-- | Run a 'SimplFn' computation from a fresh environment: empty occurrence
-- and substitution maps, and all counters at zero. The final state is discarded.
runSimplFn :: SimplFn a -> Compiler.Pass a
runSimplFn (SimplFn action) = evalStateT action initialEnv
 where
  initialEnv =
    SimplEnv
      { occInfo = M.empty
      , subst = M.empty
      , runs = 0
      , countLambda = 0
      , countMatch = 0
      }


{- | Register a binder in 'occInfo' with category 'Dead'; 'updateOccVar'
upgrades it when occurrences are found.

If the binder is already registered, it is overwritten with 'ConstructorFunc'
instead. NOTE(review): presumably binders registered twice are constructor
functions — confirm against the occurrence analyzer's callers.
-}
addOccVar :: I.VarId -> SimplFn ()
addOccVar binder = modify $ \env ->
  let occs = occInfo env
      category
        | binder `M.member` occs = ConstructorFunc
        | otherwise = Dead
   in env{occInfo = M.insert binder category occs}


{- | Record that an occurrence of the binder was just spotted.

* A binder still marked 'Dead' becomes 'OnceSafe'. If the occurrence is inside
  a lambda or a match, it becomes 'Never' instead (only 'OnceSafe' is handled
  for now; inside a lambda it would really be 'OnceUnsafe').
* A binder with any other category becomes 'Never'.
* An unregistered binder makes the stored occurrence map an 'error'.
-}
updateOccVar :: I.VarId -> SimplFn ()
updateOccVar binder = do
  occs <- gets occInfo
  nested <- (||) <$> insideLambda <*> insideMatch
  let mark category = M.insert binder category occs
      occs' = case M.lookup binder occs of
        Just Dead
          | nested -> mark Never
          | otherwise -> mark OnceSafe
        Just _ -> mark Never
        Nothing ->
          error
            ( "UDPATE: We should already know about this binder "
                ++ show binder
                ++ " :"
                ++ show occs
                ++ "!"
            )
  modify $ \env -> env{occInfo = occs'}


{- | Add (or overwrite) an entry in the substitution stored in the simplifier state.

To replace @x@ by the unsimplified expression @y@, call
@insertSubst x (SuspEx y M.empty)@.
-}
insertSubst :: I.VarId -> SubstRng -> SimplFn ()
insertSubst binder rng =
  modify $ \env -> env{subst = M.insert binder rng (subst env)}


-- | Record that the occurrence analyzer is entering a lambda (increments 'countLambda').
recordEnteringLambda :: SimplFn ()
recordEnteringLambda =
  modify $ \env -> env{countLambda = countLambda env + 1}


-- | Record that the occurrence analyzer is leaving a lambda (decrements 'countLambda').
recordExitingLambda :: SimplFn ()
recordExitingLambda =
  modify $ \env -> env{countLambda = subtract 1 (countLambda env)}


-- | Whether the occurrence analyzer is currently inside at least one lambda
-- (i.e. 'countLambda' is non-zero).
insideLambda :: SimplFn Bool
insideLambda = gets ((/= 0) . countLambda)


-- | Record that the occurrence analyzer is entering a match (increments 'countMatch').
recordEnteringMatch :: SimplFn ()
recordEnteringMatch =
  modify $ \env -> env{countMatch = countMatch env + 1}


-- | Record that the occurrence analyzer is leaving a match (decrements 'countMatch').
recordExitingMatch :: SimplFn ()
recordExitingMatch =
  modify $ \env -> env{countMatch = subtract 1 (countMatch env)}


-- | Whether the occurrence analyzer is currently inside at least one match
-- (i.e. 'countMatch' is non-zero).
insideMatch :: SimplFn Bool
insideMatch = gets ((/= 0) . countMatch)


{- | Entry point of the simplifier.

Runs the occurrence analyzer over the program (discarding its log string), then
simplifies every top-level definition and returns the program with the
simplified definitions.
-}
simplifyProgram :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
simplifyProgram prog = runSimplFn simplifyAll
 where
  simplifyAll = do
    -- To inspect the occurrence analyzer, fail with its log instead:
    --   info <- runOccAnal prog
    --   _ <- Compiler.unexpected $ show info
    _ <- runOccAnal prog
    defs' <- traverse simplTop (I.programDefs prog)
    pure prog{I.programDefs = defs'}


-- | Simplify a top-level definition: the binder is kept, and the right-hand
-- side is simplified starting from an empty substitution.
simplTop :: (I.Binder I.Type, I.Expr I.Type) -> SimplFn (I.Binder I.Type, I.Expr I.Type)
simplTop (name, rhs) = do
  rhs' <- simplExpr M.empty "inscopeset" rhs "context"
  pure (name, rhs')


{- | Simplify an IR expression.

The in-scope set and context arguments are currently ignored and only passed
along. The 'Subst' argument is threaded through, but variable lookups go
through the substitution stored in the 'SimplEnv' state (see the 'I.Var' case).

How each node is handled:
- Var VarId t                           DONE (except callsite inline)
- Data DConId t                         Default case
- Lit Literal t                         Default case
- App (Expr t) (Expr t) t               DONE
- Let [(Binder, Expr t)] (Expr t) t     DONE (except callsite inline)
- Lambda Binder (Expr t) t              DONE
- Match (Expr t) [(Alt, Expr t)] t      DONE (TODO: match arm elimination)
- Prim Primitive [Expr t] t             DONE
-}
simplExpr :: Subst -> InScopeSet -> InExpr -> Context -> SimplFn OutExpr
-- Primitive: simplify each argument, then rebuild the 'I.Prim' node.
-- For example, @5 + g + h + (6 + v)@ is the primitive "plus" applied to the
-- arguments 5, g, h and (6 + v); each argument is simplified in order.
simplExpr sub ins (I.Prim prim args t) cont = do
  args' <- mapM (\arg -> simplExpr sub ins arg cont) args
  pure (I.Prim prim args' t)
-- Match: simplify the scrutinee, then each arm's right-hand side in order;
-- the patterns are kept unchanged.
simplExpr sub ins (I.Match scrutinee arms t) cont = do
  scrutinee' <- simplExpr sub ins scrutinee cont
  arms' <- mapM (\(pat, rhs) -> (,) pat <$> simplExpr sub ins rhs cont) arms
  pure (I.Match scrutinee' arms' t)
-- Application: simplify the function, then the argument.
simplExpr sub ins (I.App lhs rhs t) cont = do
  lhs' <- simplExpr sub ins lhs cont
  rhs' <- simplExpr sub ins rhs cont
  pure (I.App lhs' rhs' t)
-- Lambda: simplify the body; the binder is kept unchanged.
simplExpr sub ins (I.Lambda binder body t) cont = do
  body' <- simplExpr sub ins body cont
  pure (I.Lambda binder body' t)
-- Variable: consult the substitution held in the simplifier state (the
-- 'Subst' argument is not used here).
simplExpr _ ins var@(I.Var v _) cont = do
  m <- gets subst
  case M.lookup v m of
    Nothing -> pure var -- callsite inline, future work
    -- Suspended input expression: simplify it under its captured substitution.
    Just (SuspEx e s) -> simplExpr s ins e cont
    -- Stored output expression: simplified once more with an empty argument
    -- substitution, so a stored variable that itself has a state substitution
    -- is replaced as well.
    Just (DoneEx e) -> simplExpr M.empty ins e cont
-- Let: for each binding, either drop it (dead), pre-inline it (used once, not
-- under a lambda or match), post-inline it (its RHS simplifies to a literal
-- or variable), or keep it with a simplified RHS.
simplExpr sub ins (I.Let binders body t) cont = do
  simplified <- mapM simplBinder binders
  let (kept, subs) = unzip simplified
      binders' = Ma.catMaybes kept
      -- Same left-biased union as folding with '<>', but total: a 'Let' with
      -- no bindings no longer crashes in 'foldr1'.
      subs' = M.unions subs
  body' <- simplExpr subs' ins body cont
  pure $ if null binders' then body' else I.Let binders' body' t
 where
  simplBinder
    :: (I.Binder t, I.Expr I.Type)
    -> SimplFn (Maybe (I.Binder t, I.Expr I.Type), Subst)
  simplBinder (binder, rhs) = do
    m <- gets occInfo
    case binder of
      I.BindVar v _ -> case M.lookup v m of
        Just Dead -> pure (Nothing, sub) -- get rid of this dead binding
        Just OnceSafe -> do
          -- preinline test PASSES: substitute the unsimplified RHS at the
          -- single use site and drop the binding.
          insertSubst v $ SuspEx rhs M.empty
          pure (Nothing, M.singleton v (SuspEx rhs sub))
        _ -> do
          -- preinline test FAILS, so do post inline unconditionally.
          rhs' <- simplExpr sub ins rhs cont -- process the RHS exactly once
          if isTrivial rhs'
            then do
              -- PASSES postinline: substitute the simplified RHS, drop binding.
              insertSubst v $ DoneEx rhs'
              pure (Nothing, M.empty)
            else
              -- FAILS postinline; someday callsite inline. Keep the binding
              -- with the RHS already simplified above (it used to be
              -- simplified a second time, doubling work per nesting level).
              pure (Just (binder, rhs'), sub)
      _ -> do
        rhs' <- simplExpr sub ins rhs cont
        pure (Just (binder, rhs'), sub) -- can't inline wildcards

  -- Right-hand sides that are always post-inlined.
  isTrivial :: I.Expr a -> Bool
  isTrivial I.Lit{} = True
  isTrivial I.Var{} = True
  isTrivial _ = False
-- Everything else (data constructors, literals, ...) is returned unchanged.
simplExpr _ _ e _ = pure e


{- | Run the occurrence analyser over every top-level definition.

Works in two passes over the program's definitions:

1. For every definition, register its name (and, for a lambda, each of its
   curried parameters) with 'addOccVar'. The lambda is then replaced by its
   body.
2. Run 'occAnalExpr' over each resulting body and build a one-line report.

The occurrence state is not reset between definitions here, so each report
shows the map accumulated up to that point.

Returns the 'show'n list of per-definition reports, as logging output.
-}
runOccAnal :: I.Program I.Type -> SimplFn String
runOccAnal I.Program{I.programDefs = topDefs} = do
  swallowed <- traverse swallowArgs topDefs
  reports <- traverse (reportDef . second occAnalExpr) swallowed
  pure $ show reports
 where
  {- Register a top-level definition's binders and strip its lambdas.

  For a lambda RHS, the function name and then every curried parameter are
  added to the occurrence state, in that order. The pair returned carries the
  innermost body in place of the lambda. Any other RHS is returned untouched
  and registers nothing.
  -}
  swallowArgs :: (I.Binder t, I.Expr t) -> SimplFn (I.Binder t, I.Expr t)
  swallowArgs def@(funcName, rhs) = case rhs of
    I.Lambda{} -> do
      let (params, body) = unfoldLambda rhs
      mapM_ registerBinder (funcName : params)
      pure (funcName, body)
    _ -> pure def
   where
    -- Binders without a variable (binderToVar gives Nothing) are ignored.
    registerBinder :: I.Binder b -> SimplFn ()
    registerBinder = maybe (pure ()) addOccVar . I.binderToVar

  -- Debug helper: run a definition's analysis and wrap the resulting
  -- occurrence map in start/end markers tagged with the definition's binder.
  reportDef :: (I.Binder I.Type, SimplFn (I.Expr I.Type, String)) -> SimplFn String
  reportDef (binder, analysis) = do
    (_, occs) <- analysis
    pure $ concat ["START topdef ", show binder, " has OccInfo: ", show occs, " END"]


{- | Run the occurrence analyser over an IR expression node.

The expression is returned unchanged. The analysis only updates the occurrence
map held in the 'SimplFn' state. The accompanying 'String' is the 'show'n
occurrence map after the whole node, children included, has been visited; it
is intended as debug output.

How each node is handled:
- Var VarId t                           DONE
- Data DConId t                         Default case (TODO: trivial constructor argument invariant)
- Lit Literal t                         Default case
- App (Expr t) (Expr t) t               DONE
- Let [(Binder, Expr t)] (Expr t) t     DONE
- Lambda Binder (Expr t) t              DONE
- Match (Expr t) [(Alt, Expr t)] t      DONE (arm patterns go through 'occAnalAlt')
- Prim Primitive [Expr t] t             DONE
-}
occAnalExpr :: I.Expr I.Type -> SimplFn (I.Expr I.Type, String)
occAnalExpr expr = do
  walk expr
  occs <- gets occInfo
  pure (expr, show occs)
 where
  -- Visit a node purely for its effect on the occurrence state.
  walk :: I.Expr I.Type -> SimplFn ()
  -- Let: each binding's RHS is walked and then its binder is registered.
  -- The body is walked once every binding has been processed.
  -- NOTE(review): a binder is registered only after its own RHS has been
  -- walked. If let bindings can be recursive, self-references are therefore
  -- seen before 'addOccVar' runs for them; confirm this is intended.
  walk (I.Let bindings body _) = do
    mapM_ walkBinding bindings
    walk body
  -- Var: hand the variable to 'updateOccVar'.
  walk (I.Var v _) = updateOccVar v
  -- Lambda: the body is walked between recordEnteringLambda and
  -- recordExitingLambda. A named (BindVar) parameter is registered first.
  walk (I.Lambda binder body _) = do
    recordEnteringLambda
    case binder of
      I.BindVar v _ -> addOccVar v
      _ -> pure ()
    walk body
    recordExitingLambda
  -- App: the function position first, then the argument.
  walk (I.App fn arg _) = walk fn >> walk arg
  -- Prim: every operand, left to right.
  walk (I.Prim _ operands _) = mapM_ walk operands
  -- Match: all arm patterns are registered before any arm body is walked.
  -- NOTE(review): the scrutinee is walked after 'recordEnteringMatch', so its
  -- occurrences are seen in the match context even though it is not inside
  -- any arm; confirm this is intended.
  walk (I.Match scrutinee arms _) = do
    recordEnteringMatch
    walk scrutinee
    mapM_ (occAnalAlt . fst) arms
    mapM_ (walk . snd) arms
    recordExitingMatch
  -- Data, Lit and any other node: nothing to record.
  walk _ = pure ()

  -- Walk one let binding's RHS, then register its binder if it names a variable.
  walkBinding :: (I.Binder I.Type, I.Expr I.Type) -> SimplFn ()
  walkBinding (binder, rhs) = do
    walk rhs
    maybe (pure ()) addOccVar (I.binderToVar binder)


{- | Run the occurrence analyser over a match-arm pattern.

Each variable bound by the pattern is registered with 'addOccVar', in
left-to-right order. This includes variables nested inside constructor
('I.AltData') sub-patterns. Any other kind of pattern registers nothing.

Returns the pattern unchanged, paired with the 'show'n occurrence map after
registration (debug output).
-}
occAnalAlt :: I.Alt t -> SimplFn (I.Alt t, String)
occAnalAlt alt = do
  registerPattern alt
  occs <- gets occInfo
  pure (alt, show occs)
 where
  -- Effect-only traversal of the pattern tree.
  registerPattern :: I.Alt a -> SimplFn ()
  registerPattern (I.AltBinder binder) =
    maybe (pure ()) addOccVar (I.binderToVar binder)
  registerPattern (I.AltData _ subPatterns _) =
    mapM_ registerPattern subPatterns
  registerPattern _ = pure ()