{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}

{- | Description : Insert reference counting primitives

This inserts @dup@ and @drop@ primitives according to a caller @dup@,
callee @drop@ policy.  The value returned by a function should be
passed back referenced (ownership transfers from the callee back to
the caller).

The @dup : a -> a@ primitive behaves like the identity function,
evaluating and returning its first argument and increasing the
reference count on the result.  It is meant to be wrapped around
function arguments.

The @drop : a -> b -> a@ primitive evaluates and returns its first
argument.  It decrements the reference count of its second argument
after it has evaluated its first argument.  It is meant to be wrapped
around function bodies that need to use and then de-reference their
arguments.

Thus, something like

> add a b = a + b

becomes

@
add a b =
  drop
    (drop
       ((dup a) + (dup b))
       b)
    a
@

Arguments @a@ and @b@ to the @+@ primitive are duplicated and the
result of @+@ is duplicated internally, so @add@ does not need to
duplicate its result.  Both arguments @a@ and @b@ are dropped.

Try running @sslc --dump-ir-final@ on an example to see the inserted
@dup@ and @drop@ constructs.

Our approach was inspired by Perceus
<https://www.microsoft.com/en-us/research/publication/perceus-garbage-free-reference-counting-with-reuse/>
-}
module IR.InsertRefCounting (insertRefCounting) where

import qualified Common.Compiler as Compiler
import Common.Identifiers
import Control.Monad.State.Lazy (
  MonadState (..),
  StateT (..),
  forM,
  modify,
 )
import qualified IR.MangleNames as I
import qualified IR.IR as I
import qualified IR.Types as I
import qualified Data.Map as M


-- * The external interface


-- \$external

{- | Insert dup and drop primitives throughout a program.

Runs 'insertTop' over every top-level definition, threading the program's
symbol table through as state so that any fresh variables created along the
way are recorded.  The updated symbol table and the rewritten definitions
are stored back into the returned program; all other fields are untouched.
-}
insertRefCounting :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
insertRefCounting p@I.Program{I.symTable = symTable, I.programDefs = defs} = do
  -- The symbol table is the initial state of the 'Fresh' monad.
  (defs', symTable') <- runStateT (mapM insertTop defs) symTable
  return $ p{I.symTable = symTable', I.programDefs = defs'}

-- * Module internals, not intended for use outside this module


--
-- \$internal

-- | Monad for creating fresh variables: the compiler pass with the symbol table layered on top as state
type Fresh = StateT (I.SymTable I.Type) Compiler.Pass


{- | Create a fresh variable name with the given name seed.

The candidate name is @__dupdrop_anon_@ followed by the seed.  The current
symbol table is consulted (via 'I.pickId') to choose the identifier, and the
chosen identifier is then inserted into the symbol table with the candidate
name as its origin and the supplied type, so the table stays up to date.
-}
getFresh
  :: String
  -- ^ Seed appended to the @__dupdrop_anon_@ prefix
  -> I.Type
  -- ^ Type recorded for the new variable in the symbol table
  -> Fresh I.VarId
  -- ^ The identifier that was picked and registered
getFresh seed t = do
  symTable <- get
  let str = fromString $ "__dupdrop_anon_" <> seed
      i = I.pickId symTable str
  -- Register the new name so later calls see it in the table.
  modify $ M.insert i $ I.SymInfo{I.symOrigin = str, I.symType = t}
  return i


-- | Make a dup primitive whose type annotation is that of its argument
makeDup
  :: I.Expr I.Type
  -- ^ The expression (typically a variable) to duplicate and return
  -> I.Expr I.Type
  -- ^ The @dup@ call
makeDup e = I.Prim I.Dup [e] $ I.extract e


{- | Make a drop primitive annotated with unit type.

Note the argument order: the value to drop comes /first/ and the expression
to evaluate comes /second/, which is the reverse of the order in which they
appear in the generated @drop@ primitive.  This order makes the function
convenient to use with 'foldr' over a list of variables.

NOTE(review): the module header describes @drop : a -> b -> a@, yet the
primitive built here is annotated with 'I.Unit' rather than the type of the
evaluated expression.  Confirm whether later passes rely on this annotation.
-}
makeDrop
  :: I.Expr I.Type
  -- ^ The variable to drop afterwards (second argument of @drop@)
  -> I.Expr I.Type
  -- ^ The expression to evaluate and return (first argument of @drop@)
  -> I.Expr I.Type
  -- ^ The @drop@ call
makeDrop r e = I.Prim I.Drop [e, r] I.Unit

-- \$internal

{- | Insert reference counting for a top-level definition.

Applies 'insertExpr' to the body of a top-level declaration and leaves its
binder unchanged.
-}
insertTop :: (I.Binder I.Type, I.Expr I.Type) -> Fresh (I.Binder I.Type, I.Expr I.Type)
insertTop (var, expr) = (var,) <$> insertExpr expr


{- | Insert reference counting into an expression

This is the main workhorse of this module.

* __Literals__ are unchanged, e.g.,

> 42

remains

> 42

* __Data constructors__ are unchanged, e.g.,

> True

remains

> True

because they are functions whose results are returned with an existing reference

* A __variable reference__ becomes a call to dup because it introduces
  another reference to the named object, e.g.,

> v

becomes

> dup v

* __Application__ recurses (inserts dups and drops) on both the
  function being applied and its argument e.g.,

> add x y

becomes

> (dup add) (dup x) (dup y)

* __Primitive__ function application inserts dups and drops on its
  arguments, e.g.,

> (+) x y

becomes

> (+) (dup x) (dup y)

* __Let__ introduces new names whose values are dropped after the body
  is evaluated; @let _ =@ are given names so they can be dropped.

> let a = Foo 42
>     _ = a
> 17

becomes

> let a = Foo 42
> drop
>   (let anon1_underscore = dup a
>    drop
>      17
>      anon1_underscore)
>   a

* Nested __Lambda__ expressions are handled by collecting them into a
  single expression with multiple arguments, adding dups and drops to
  the body, and adding drops around the body for each argument (which
  the caller should have duped)

> add a b = a + b

desugars to

> add = fun a (fun b (a + b))

and becomes

@
add = fun a (
        fun b (
           drop (
              drop (
                (dup a) + (dup b)
              ) b
           ) a
@

* __Matches__ that operate on a variable are modified by inserting
 dups and drops into the arms (but not the scrutinee)

@
match v
  Foo x = x + 1
  Bar = 42
@

desugars to

@
match v
  Foo pat_anon0 = let x = pat_anon0
                    x + 1
  Bar = 42
@

then becomes

@
match v
  Foo pat_anon0 = drop (
                    let _ = dup pat_anon0
                    let x = dup pat_anon0
                    drop (
                       dup x + 1
                    ) x
                  ) pat_anon0
  Bar = 42
@

* __Matches__ that scrutinize an expression lift the scrutinee
  into a @let@ then insert dups and drops on the whole thing

@
  match add x y
    10 = 5
    _ = 3
@

becomes

@
  let anon0_scrutinee = (dup add) (dup x) (dup y)
  drop (
    match anon0_scrutinee
      10 = 5
      _ = 3
  ) anon0_scrutinee
@
-}

--
insertExpr :: I.Expr I.Type -> Fresh (I.Expr I.Type)
-- Data constructors and literals are returned unchanged.
insertExpr dcon@I.Data{} = return dcon
insertExpr lit@I.Lit{} = return lit
-- Every variable reference is wrapped in a dup.
insertExpr var@I.Var{} = return $ makeDup var
-- Applications and primitives recurse into their sub-expressions and keep
-- their original type annotation.
insertExpr (I.App f x t) = I.App <$> insertExpr f <*> insertExpr x <*> pure t
insertExpr (I.Prim p es typ) = I.Prim p <$> mapM insertExpr es <*> pure typ
insertExpr (I.Let bins expr typ) = do
  -- Name every binder (anonymous ones get a fresh name) and rewrite each
  -- definition, in binding order, before rewriting the body.
  bins' <- forM bins droppedBinder
  expr' <- insertExpr expr
  -- Wrap the body in one drop per binder; the first binder ends up outermost.
  return $ I.Let (map defFromBind bins') (foldr (makeDrop . varFromBind) expr' bins') typ
 where
  -- A reference to the bound variable, used as the thing to drop.
  varFromBind (v, t, _) = I.Var v t
  -- Rebuild the (binder, definition) pair for the resulting let.
  defFromBind (v, t, d) = (I.BindVar v t, d)

  -- Produce (name, type, rewritten definition) for one let binding.
  droppedBinder (I.BindAnon t, d) = do
    -- @let _ = ...@ is given a name so that its value can be dropped.
    temp <- getFresh "underscore" t
    d' <- insertExpr d
    return (temp, t, d')
  droppedBinder (I.BindVar v t, d) = do
    d' <- insertExpr d
    return (v, t, d')
insertExpr lam@I.Lambda{} = do
  -- Collect nested lambdas into a single argument list and innermost body.
  let (args, body) = I.unfoldLambda lam
  -- Anonymous arguments get fresh names so they can be dropped as well.
  args' <- forM args $ \case
    I.BindAnon t -> do
      v <- getFresh "arg" t
      return (v, t)
    I.BindVar v t -> return (v, t)
  let argBinders = uncurry I.BindVar <$> args'
      argVars = uncurry I.Var <$> args'
  body' <- insertExpr body
  -- Drop every argument around the body, then rebuild the nested lambdas.
  return $ I.foldLambda argBinders $ foldr makeDrop body' argVars
insertExpr (I.Match v@I.Var{} alts typ) = do
  -- The scrutinee is already a variable: only the arms are rewritten, and
  -- the scrutinee itself is deliberately left without a dup.
  alts' <- forM alts insertAlt
  return $ I.Match v alts' typ
insertExpr (I.Match scrutExpr alts typ) = do
  -- Lift a non-variable scrutinee into a fresh let-bound variable, then
  -- rewrite the resulting let (which reaches the variable-scrutinee case
  -- above and drops the fresh variable after the match).
  let scrutType = I.extract scrutExpr
  scrutVar <- getFresh "scrutinee" scrutType
  insertExpr $
    I.Let
      [(I.BindVar scrutVar scrutType, scrutExpr)]
      (I.Match (I.Var scrutVar scrutType) alts typ)
      typ
-- Exceptions are returned unchanged.
insertExpr e@(I.Exception _ _) = return e


{- | Insert dups and drops into pattern match arms

The body of binder and literal patterns is simply recursed upon.

For a data constructor pattern, the body is recursed upon and then every
named variable bound by the pattern is duped (in an anonymous @let@) before
the body and dropped after it, e.g.,

@
match v
  Foo x = expr
@

becomes

@
match v
  Foo x = drop
            (let _ = dup x
             expr'
            ) x
@

where @expr'@ is @expr@ after 'insertExpr'.  Anonymous binders in the
pattern are left alone.
-}
insertAlt :: (I.Alt I.Type, I.Expr I.Type) -> Fresh (I.Alt I.Type, I.Expr I.Type)
insertAlt (I.AltBinder v, e) = (I.AltBinder v,) <$> insertExpr e
insertAlt (I.AltLit l t, e) = (I.AltLit l t,) <$> insertExpr e
insertAlt (I.AltData dcon binds t, body) = do
  body' <- insertExpr body
  -- NOTE: we don't recurse here because we assume alts are already desugared, i.e., flat
  return (I.AltData dcon binds t, foldr dropDupLet body' $ concatMap I.altBinders binds)
 where
  -- Wrap an arm body with the dup/drop pair for one pattern binder.
  dropDupLet (I.BindAnon _) e = e
  dropDupLet (I.BindVar v t') e =
    let varExpr = I.Var v t'
        dupExpr = makeDup varExpr
     in -- The dup is bound to an anonymous binder purely for its effect;
        -- the matching drop happens after the arm body has been evaluated.
        makeDrop varExpr (I.Let [(I.BindAnon $ I.extract dupExpr, dupExpr)] e (I.extract e))