{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}

{- | Lift lambda definitions into the global scope.

This pass is responsible for moving nested lambda definitions into the global
scope and performing necessary callsite adjustments.
-}
module IR.LambdaLift (
  liftProgramLambdas,
) where

import qualified Common.Compiler as Compiler
import Common.Identifiers
import qualified IR.IR as I
import qualified IR.MangleNames as I
import qualified IR.Pretty ()
import qualified IR.Types as I

import Control.Monad (forM, forM_, unless)
import Control.Monad.Except (MonadError (..))
import Control.Monad.State.Lazy (
  MonadState (..),
  StateT (..),
  gets,
  modify,
 )

import Data.Bifunctor (Bifunctor (..))
import Data.Generics (everywhere, mkT)
import Data.List (intersperse, tails)
import Data.Map ((\\))
import qualified Data.Map as M
import Data.Maybe (mapMaybe, maybeToList)


-- | Collect the @(name, type)@ pairs introduced by a list of binders.
--
-- Only 'I.BindVar' binders introduce a name; every other binder form is
-- skipped. The relative order of the named binders is preserved.
binderVars :: [I.Binder I.Type] -> [(I.VarId, I.Type)]
binderVars bs = [(v, t) | I.BindVar v t <- bs]


-- | Lifting environment: the state threaded through the lambda-lifting pass.
data LiftCtx = LiftCtx
  { -- | 'globalScope' maps every top-level identifier to its type. All scopes,
    -- regardless of depth, have access to these identifiers.
    globalScope :: M.Map I.VarId I.Type
  , -- | 'currentScope' maps the identifiers available in the current scope to
    -- their types.
    currentScope :: M.Map I.VarId I.Type
  , -- | 'currentTrail' traces the surrounding scopes (innermost first) in terms
    -- of relevant identifiers. It is used for creating unique identifiers for
    -- lifted lambdas.
    currentTrail :: [I.VarId]
  , -- | Free variables (with their types) encountered during the current
    -- descent. These need to be added as arguments to lifted closures, and
    -- applied at the original site of the closure.
    currentFreeVars :: M.Map I.VarId I.Type
  , -- | 'lifted' is the list of lifted lambdas created while descending into a
    -- top-level definition, most recently lifted first.
    lifted :: [(I.Binder I.Type, I.Expr I.Type)]
  , -- | 'symTable' is used to generate fresh variable names, using 'I.pickId'.
    symTable :: I.SymTable I.Type
  }


-- | Lift monad: 'LiftCtx' state layered over the compiler pass monad, which
-- also supplies error reporting via 'Compiler.Error'.
--
-- Every instance is borrowed from the underlying 'StateT' stack.
newtype LiftFn a = LiftFn (StateT LiftCtx Compiler.Pass a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFail
    , MonadError Compiler.Error
    , MonadState LiftCtx
    )
    via (StateT LiftCtx Compiler.Pass)


-- | Unwrap the lift monad, exposing the underlying state transformer so that it
-- can be run (e.g. with 'runStateT').
unLiftFn :: LiftFn a -> StateT LiftCtx Compiler.Pass a
unLiftFn (LiftFn action) = action


-- | Generate a fresh name from an origin, the context trail, and the sym table.
--
-- The candidate name is the current trail (outermost scope first) followed by
-- @origin@, joined with underscores. 'I.pickId' derives the final name from it
-- using the symbol table, and that name is then registered in the symbol table
-- with @origin@ as its origin and @t@ as its type.
genName :: I.VarId -> I.Type -> LiftFn I.VarId
genName origin t = do
  trail <- gets currentTrail
  syms <- gets symTable
  let -- The trail is stored innermost-first; reverse it to read outside-in.
      candidate = mconcat $ intersperse "_" $ reverse $ origin : trail
      name = I.pickId syms candidate
      info = I.SymInfo{I.symOrigin = fromId origin, I.symType = t}
  modify $ \ctx -> ctx{symTable = M.insert name info syms}
  return name


-- | Update lift environment before entering a new scope (e.g. non-recursive let
-- definition, match arm) that stays inside the current lifted scope.
--
-- The names bound by the binders are added to 'currentScope' (shadowing any
-- outer binding of the same name), and the trail is extended with the given
-- identifier, if any. Both are restored after the action runs. Free variables
-- recorded inside the action are deliberately left in 'currentFreeVars' so the
-- enclosing lifted scope still sees them.
withEnclosingScope :: Maybe I.VarId -> [I.Binder I.Type] -> LiftFn a -> LiftFn a
withEnclosingScope t (binderVars -> s) m = do
  (scope, trail) <- (,) <$> gets currentScope <*> gets currentTrail
  modify $ \st ->
    st
      { currentScope = M.fromList s `M.union` scope -- Add bindings to inner scope
      , currentTrail = maybeToList t <> trail -- Possibly extend trail
      }
  a <- m
  -- Restore the outer scope and trail.
  modify $ \st -> st{currentScope = scope, currentTrail = trail}
  return a


-- | Run an action inside a scope that will be lifted to the global scope.
--
-- While the action runs, 'currentScope' holds only the names bound by the given
-- binders, because the surrounding local scope will not exist once the inner
-- scope is lifted. The trail is extended with the given identifier, if any,
-- and 'currentFreeVars' starts out empty.
--
-- Returns the action's result paired with every free variable (and its type)
-- recorded while it ran. Afterwards the outer scope and trail are restored, and
-- the recorded free variables that are not bound in the outer scope are merged
-- into the outer 'currentFreeVars', since they are free there as well.
withLiftedScope ::
  Maybe I.VarId ->
  [I.Binder I.Type] ->
  LiftFn a ->
  LiftFn (a, [(I.VarId, I.Type)])
withLiftedScope t (binderVars -> s) m = do
  (scope, trail, free) <-
    (,,) <$> gets currentScope <*> gets currentTrail <*> gets currentFreeVars
  modify $ \st ->
    st
      { currentFreeVars = M.empty -- Reset accounting of free variables encountered
      , currentScope = M.fromList s -- Only the lifted scope's own bindings are visible
      , currentTrail = maybeToList t <> trail -- Possibly extend trail
      }
  a <- m
  free' <- gets currentFreeVars
  modify $ \st ->
    st -- Restore state, propagating variables still free in the outer scope
      { currentFreeVars = free `M.union` (free' \\ scope)
      , currentScope = scope
      , currentTrail = trail
      }
  return (a, M.toList free')


-- | Store a new lifted lambda to later add to the program's top level definitions.
--
-- The binder is named @name@ and typed with the type carried by the lambda
-- expression itself. Entries are prepended, so 'lifted' holds the most recently
-- stored lambda first.
tellLifted :: I.VarId -> I.Expr I.Type -> LiftFn ()
tellLifted name lam = modify $ \st -> st{lifted = entry : lifted st}
 where
  entry = (I.BindVar (fromId name) (I.extract lam), lam)


-- | Context management for lifting top level lambda definitions.
--
-- Removes every lifted lambda accumulated so far from the context and returns
-- them in the order they were stored by 'tellLifted'. The read and the reset
-- happen in a single state update.
extractLifted :: LiftFn [(I.Binder I.Type, I.Expr I.Type)]
extractLifted = reverse <$> state (\st -> (lifted st, st{lifted = []}))


{- | Entry-point to lambda lifting.

Maps over top level definitions and lifts out lambda definitions to create a new
Program with the relative order of user definitions preserved. The lambdas
lifted out of a top-level lambda definition are placed immediately before it;
top-level definitions whose right-hand side is not a lambda are kept as they
are. The returned program's symbol table includes the names generated during
lifting.
-}
liftProgramLambdas :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
liftProgramLambdas p@I.Program{I.programDefs = defs, I.symTable = syms} = do
  (defs', symTable -> syms') <- runStateT (unLiftFn $ mapM liftTop defs) initCtx
  return p{I.programDefs = concat defs', I.symTable = syms'}
 where
  -- Initial context: only the named top-level definitions are in scope.
  initCtx :: LiftCtx
  initCtx =
    LiftCtx
      { globalScope = M.fromList globals
      , currentScope = M.empty
      , currentFreeVars = M.empty
      , currentTrail = []
      , lifted = []
      , symTable = syms
      }

  -- Names and types of the top-level definitions that bind a variable.
  globals :: [(I.VarId, I.Type)]
  globals = mapMaybe (extractBindVar . second I.extract) defs

  extractBindVar :: (I.Binder a, b) -> Maybe (I.VarId, b)
  extractBindVar (I.binderToVar -> Just v, t) = Just (v, t)
  extractBindVar _ = Nothing

  -- Lift the lambdas nested inside one top-level definition. The lifted
  -- definitions are emitted before the rewritten original definition.
  liftTop ::
    (I.Binder I.Type, I.Expr I.Type) ->
    LiftFn [(I.Binder I.Type, I.Expr I.Type)]
  liftTop (v, lam@I.Lambda{}) = do
    let (bs, body) = I.unfoldLambda lam
    body' <- withEnclosingScope (fromId <$> I.binderToVar v) bs $ liftLambdas body
    liftedLambdas <- extractLifted
    return $ liftedLambdas ++ [(v, I.foldLambda bs body')]
  liftTop topDef = return [topDef]


{- | Lifting logic for IR expressions.

As we traverse over IR expressions, we note down any bindings we encounter so
that we can detect free variables. For lambda definitions, we use free
variables to create a new top-level lifted equivalent and then adjust the
callsite by partially-applying the new lifted lambda with those free variables
from the surrounding scope.
-}
liftLambdas :: I.Expr I.Type -> LiftFn (I.Expr I.Type)
liftLambdas :: Expr Type -> LiftFn (Expr Type)
liftLambdas n :: Expr Type
n@(I.Var VarId
v Type
t) = do
  Map VarId Type
scope <- Map VarId Type -> Map VarId Type -> Map VarId Type
forall k a. Ord k => Map k a -> Map k a -> Map k a
M.union (Map VarId Type -> Map VarId Type -> Map VarId Type)
-> LiftFn (Map VarId Type)
-> LiftFn (Map VarId Type -> Map VarId Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (LiftCtx -> Map VarId Type) -> LiftFn (Map VarId Type)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets LiftCtx -> Map VarId Type
currentScope LiftFn (Map VarId Type -> Map VarId Type)
-> LiftFn (Map VarId Type) -> LiftFn (Map VarId Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (LiftCtx -> Map VarId Type) -> LiftFn (Map VarId Type)
forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets LiftCtx -> Map VarId Type
globalScope
  Bool -> LiftFn () -> LiftFn ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (VarId
v VarId -> Map VarId Type -> Bool
forall k a. Ord k => k -> Map k a -> Bool
`M.member` Map VarId Type
scope) (LiftFn () -> LiftFn ()) -> LiftFn () -> LiftFn ()
forall a b. (a -> b) -> a -> b
$ do
    -- @v@ appears free in the current scope; make a note of it so we know to
    -- apply it when lifting the enclosing closure.
    (LiftCtx -> LiftCtx) -> LiftFn ()
forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify ((LiftCtx -> LiftCtx) -> LiftFn ())
-> (LiftCtx -> LiftCtx) -> LiftFn ()
forall a b. (a -> b) -> a -> b
$ \LiftCtx
st -> LiftCtx
st{currentFreeVars :: Map VarId Type
currentFreeVars = VarId -> Type -> Map VarId Type -> Map VarId Type
forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert VarId
v Type
t (Map VarId Type -> Map VarId Type)
-> Map VarId Type -> Map VarId Type
forall a b. (a -> b) -> a -> b
$ LiftCtx -> Map VarId Type
currentFreeVars LiftCtx
st}
  Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Expr Type
n
liftLambdas (I.App Expr Type
e1 Expr Type
e2 Type
t) = Expr Type -> Expr Type -> Type -> Expr Type
forall t. Expr t -> Expr t -> t -> Expr t
I.App (Expr Type -> Expr Type -> Type -> Expr Type)
-> LiftFn (Expr Type) -> LiftFn (Expr Type -> Type -> Expr Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
e1 LiftFn (Expr Type -> Type -> Expr Type)
-> LiftFn (Expr Type) -> LiftFn (Type -> Expr Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
e2 LiftFn (Type -> Expr Type) -> LiftFn Type -> LiftFn (Expr Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> LiftFn Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure Type
t
liftLambdas (I.Prim Primitive
p [Expr Type]
exprs Type
t) = Primitive -> [Expr Type] -> Type -> Expr Type
forall t. Primitive -> [Expr t] -> t -> Expr t
I.Prim Primitive
p ([Expr Type] -> Type -> Expr Type)
-> LiftFn [Expr Type] -> LiftFn (Type -> Expr Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Expr Type -> LiftFn (Expr Type))
-> [Expr Type] -> LiftFn [Expr Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Expr Type -> LiftFn (Expr Type)
liftLambdas [Expr Type]
exprs LiftFn (Type -> Expr Type) -> LiftFn Type -> LiftFn (Expr Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> LiftFn Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure Type
t
liftLambdas lam :: Expr Type
lam@I.Lambda{} = do
  (Expr Type
lam', VarId
liftName, Expr Type
liftLam) <- Expr Type
-> [Binder Type] -> VarId -> LiftFn (Expr Type, VarId, Expr Type)
liftLambda Expr Type
lam [] VarId
"__lambda"
  VarId -> Expr Type -> LiftFn ()
tellLifted VarId
liftName Expr Type
liftLam
  Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Expr Type
lam'
liftLambdas (I.Let [(Binder Type, Expr Type)]
ds Expr Type
b Type
t)
  | ((Binder Type, Expr Type) -> Bool)
-> [(Binder Type, Expr Type)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Expr Type -> Bool
forall t. Expr t -> Bool
isLambda (Expr Type -> Bool)
-> ((Binder Type, Expr Type) -> Expr Type)
-> (Binder Type, Expr Type)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Binder Type, Expr Type) -> Expr Type
forall a b. (a, b) -> b
snd) [(Binder Type, Expr Type)]
ds = do
    -- e.g.,  let f x = ...
    --            g y = ...
    --        ...
    -- Bindings (e.g., f, g) might be recursive and apppear inside definitions.
    let binders :: [Binder Type]
binders = ((Binder Type, Expr Type) -> Binder Type)
-> [(Binder Type, Expr Type)] -> [Binder Type]
forall a b. (a -> b) -> [a] -> [b]
map (Binder Type, Expr Type) -> Binder Type
forall a b. (a, b) -> a
fst [(Binder Type, Expr Type)]
ds
    [((Binder Type, Expr Type), (VarId, Expr Type))]
dsls' <- [(Binder Type, Expr Type)]
-> ((Binder Type, Expr Type)
    -> LiftFn ((Binder Type, Expr Type), (VarId, Expr Type)))
-> LiftFn [((Binder Type, Expr Type), (VarId, Expr Type))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Binder Type, Expr Type)]
ds (((Binder Type, Expr Type)
  -> LiftFn ((Binder Type, Expr Type), (VarId, Expr Type)))
 -> LiftFn [((Binder Type, Expr Type), (VarId, Expr Type))])
-> ((Binder Type, Expr Type)
    -> LiftFn ((Binder Type, Expr Type), (VarId, Expr Type)))
-> LiftFn [((Binder Type, Expr Type), (VarId, Expr Type))]
forall a b. (a -> b) -> a -> b
$ \(Binder Type
x, Expr Type
d) -> do
      (Expr Type
d', VarId
x', Expr Type
lam) <- Expr Type
-> [Binder Type] -> VarId -> LiftFn (Expr Type, VarId, Expr Type)
liftLambda Expr Type
d [Binder Type]
binders (VarId -> LiftFn (Expr Type, VarId, Expr Type))
-> VarId -> LiftFn (Expr Type, VarId, Expr Type)
forall a b. (a -> b) -> a -> b
$ VarId -> (VarId -> VarId) -> Maybe VarId -> VarId
forall b a. b -> (a -> b) -> Maybe a -> b
maybe VarId
"__let_underscore" VarId -> VarId
forall a b. (Identifiable a, Identifiable b) => a -> b
fromId (Maybe VarId -> VarId) -> Maybe VarId -> VarId
forall a b. (a -> b) -> a -> b
$ Binder Type -> Maybe VarId
forall a. Binder a -> Maybe VarId
I._binderId Binder Type
x
      ((Binder Type, Expr Type), (VarId, Expr Type))
-> LiftFn ((Binder Type, Expr Type), (VarId, Expr Type))
forall (m :: * -> *) a. Monad m => a -> m a
return ((Binder Type
x, Expr Type
d'), (VarId
x', Expr Type
lam))

    let ([(Binder Type, Expr Type)]
xd, [(VarId, Expr Type)]
xl) = [((Binder Type, Expr Type), (VarId, Expr Type))]
-> ([(Binder Type, Expr Type)], [(VarId, Expr Type)])
forall a b. [(a, b)] -> ([a], [b])
unzip [((Binder Type, Expr Type), (VarId, Expr Type))]
dsls'
        xdMap :: Map VarId (Expr Type)
xdMap = [(VarId, Expr Type)] -> Map VarId (Expr Type)
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VarId, Expr Type)] -> Map VarId (Expr Type))
-> [(VarId, Expr Type)] -> Map VarId (Expr Type)
forall a b. (a -> b) -> a -> b
$ ((Binder Type, Expr Type) -> Maybe (VarId, Expr Type))
-> [(Binder Type, Expr Type)] -> [(VarId, Expr Type)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (Binder Type, Expr Type) -> Maybe (VarId, Expr Type)
makeMapping [(Binder Type, Expr Type)]
xd

        makeMapping :: (I.Binder I.Type, I.Expr I.Type) -> Maybe (I.VarId, I.Expr I.Type)
        makeMapping :: (Binder Type, Expr Type) -> Maybe (VarId, Expr Type)
makeMapping (I.BindVar VarId
x Type
_, Expr Type
d') = (VarId, Expr Type) -> Maybe (VarId, Expr Type)
forall a. a -> Maybe a
Just (VarId
x, Expr Type
d')
        makeMapping (Binder Type, Expr Type)
_ = Maybe (VarId, Expr Type)
forall a. Maybe a
Nothing -- Should never be reachable, consider throwing error
        mapRecBinds :: I.Expr I.Type -> I.Expr I.Type
        mapRecBinds :: Expr Type -> Expr Type
mapRecBinds (I.App (I.Var VarId
x Type
xt) Expr Type
args Type
at) = case VarId -> Map VarId (Expr Type) -> Maybe (Expr Type)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup VarId
x Map VarId (Expr Type)
xdMap of
          Just Expr Type
d' -> Expr Type -> Expr Type -> Type -> Expr Type
forall t. Expr t -> Expr t -> t -> Expr t
I.App Expr Type
d' Expr Type
args Type
at
          Maybe (Expr Type)
Nothing -> Expr Type -> Expr Type -> Type -> Expr Type
forall t. Expr t -> Expr t -> t -> Expr t
I.App (VarId -> Type -> Expr Type
forall t. VarId -> t -> Expr t
I.Var VarId
x Type
xt) Expr Type
args Type
at
        mapRecBinds Expr Type
e = Expr Type
e

    [(VarId, Expr Type)]
-> ((VarId, Expr Type) -> LiftFn ()) -> LiftFn ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(VarId, Expr Type)]
xl (((VarId, Expr Type) -> LiftFn ()) -> LiftFn ())
-> ((VarId, Expr Type) -> LiftFn ()) -> LiftFn ()
forall a b. (a -> b) -> a -> b
$ \(VarId
x', Expr Type
lam) -> do
      -- replace every instance of Var x to Var x' in lam
      let lam' :: Expr Type
lam' = (forall a. Data a => a -> a) -> Expr Type -> Expr Type
(forall a. Data a => a -> a) -> forall a. Data a => a -> a
everywhere ((Expr Type -> Expr Type) -> a -> a
forall a b. (Typeable a, Typeable b) => (b -> b) -> a -> a
mkT Expr Type -> Expr Type
mapRecBinds) Expr Type
lam
      VarId -> Expr Type -> LiftFn ()
tellLifted VarId
x' Expr Type
lam'

    Expr Type
b' <- Maybe VarId
-> [Binder Type] -> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a. Maybe VarId -> [Binder Type] -> LiftFn a -> LiftFn a
withEnclosingScope Maybe VarId
forall a. Maybe a
Nothing [Binder Type]
binders (LiftFn (Expr Type) -> LiftFn (Expr Type))
-> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
b
    Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr Type -> LiftFn (Expr Type))
-> Expr Type -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ [(Binder Type, Expr Type)] -> Expr Type -> Type -> Expr Type
forall t. [(Binder t, Expr t)] -> Expr t -> t -> Expr t
I.Let [(Binder Type, Expr Type)]
xd Expr Type
b' Type
t
  | [(Binder Type, Expr Type)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Binder Type, Expr Type)]
ds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 = do
    -- e.g.,  let x = ...
    --        ...
    -- Binding is not recursive, so x cannot appear in definition.
    let (Binder Type
x, Expr Type
d) = [(Binder Type, Expr Type)] -> (Binder Type, Expr Type)
forall a. [a] -> a
head [(Binder Type, Expr Type)]
ds
    Expr Type
d' <- Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
d
    Expr Type
e' <- Maybe VarId
-> [Binder Type] -> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a. Maybe VarId -> [Binder Type] -> LiftFn a -> LiftFn a
withEnclosingScope Maybe VarId
forall a. Maybe a
Nothing [Binder Type
x] (LiftFn (Expr Type) -> LiftFn (Expr Type))
-> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
b
    Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr Type -> LiftFn (Expr Type))
-> Expr Type -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ [(Binder Type, Expr Type)] -> Expr Type -> Type -> Expr Type
forall t. [(Binder t, Expr t)] -> Expr t -> t -> Expr t
I.Let [(Binder Type
x, Expr Type
d')] Expr Type
e' Type
t
  | Bool
otherwise = String -> LiftFn (Expr Type)
forall a. HasCallStack => String -> a
error (String -> LiftFn (Expr Type)) -> String -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ String
"Let expressions should only bind a list of values, or a single non-value " String -> String -> String
forall a. [a] -> [a] -> [a]
++ [(Binder Type, Expr Type)] -> String
forall a. Show a => a -> String
show [(Binder Type, Expr Type)]
ds
 where
  isLambda :: Expr t -> Bool
isLambda I.Lambda{} = Bool
True
  isLambda Expr t
_ = Bool
False
liftLambdas (I.Match Expr Type
s [(Alt Type, Expr Type)]
as Type
t) = do
  Expr Type
s' <- Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
s
  [(Alt Type, Expr Type)]
as' <- [(Alt Type, Expr Type)]
-> ((Alt Type, Expr Type) -> LiftFn (Alt Type, Expr Type))
-> LiftFn [(Alt Type, Expr Type)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Alt Type, Expr Type)]
as (((Alt Type, Expr Type) -> LiftFn (Alt Type, Expr Type))
 -> LiftFn [(Alt Type, Expr Type)])
-> ((Alt Type, Expr Type) -> LiftFn (Alt Type, Expr Type))
-> LiftFn [(Alt Type, Expr Type)]
forall a b. (a -> b) -> a -> b
$ \(Alt Type
a, Expr Type
e) -> do
    Expr Type
e' <- Maybe VarId
-> [Binder Type] -> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a. Maybe VarId -> [Binder Type] -> LiftFn a -> LiftFn a
withEnclosingScope Maybe VarId
forall a. Maybe a
Nothing (Alt Type -> [Binder Type]
forall t. Alt t -> [Binder t]
I.altBinders Alt Type
a) (LiftFn (Expr Type) -> LiftFn (Expr Type))
-> LiftFn (Expr Type) -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ Expr Type -> LiftFn (Expr Type)
liftLambdas Expr Type
e
    (Alt Type, Expr Type) -> LiftFn (Alt Type, Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Alt Type
a, Expr Type
e')
  Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return (Expr Type -> LiftFn (Expr Type))
-> Expr Type -> LiftFn (Expr Type)
forall a b. (a -> b) -> a -> b
$ Expr Type -> [(Alt Type, Expr Type)] -> Type -> Expr Type
forall t. Expr t -> [(Alt t, Expr t)] -> t -> Expr t
I.Match Expr Type
s' [(Alt Type, Expr Type)]
as' Type
t
liftLambdas lit :: Expr Type
lit@I.Lit{} = Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Expr Type
lit
liftLambdas dat :: Expr Type
dat@I.Data{} = Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Expr Type
dat
liftLambdas e :: Expr Type
e@I.Exception{} = Expr Type -> LiftFn (Expr Type)
forall (m :: * -> *) a. Monad m => a -> m a
return Expr Type
e


-- | Lift a single lambda expression out to the top level.
--
-- Arguments:
--
--   * @lam@: the lambda to lift.
--   * @letBinds@: extra binders put in scope, alongside the lambda's own
--     parameters, while its body is processed. NOTE(review): presumably these
--     are the names of an enclosing (recursive) @let@ group; confirm at call
--     sites.
--   * @originName@: the name the lambda was bound to. It seeds both the lifted
--     scope and the fresh top-level name.
--
-- The body is processed with 'liftLambdas' inside 'withLiftedScope'. That call
-- also returns the variables the body captures, each with its type. The
-- lambda is then closed over those variables by prepending them as leading
-- parameters.
--
-- Returns a triple:
--
--   * the expression that replaces @lam@ at its original site: the lifted name
--     applied to every captured variable, in the same order as the parameters;
--   * the freshly generated name of the lifted definition;
--   * the lifted, closed lambda itself.
liftLambda :: I.Expr I.Type -> [I.Binder I.Type] -> I.VarId -> LiftFn (I.Expr I.Type, I.VarId, I.Expr I.Type)
liftLambda lam letBinds originName = do
  let (bs, body) = I.unfoldLambda lam

  -- Lift nested lambdas in the body. The second component lists the variables
  -- captured from enclosing scopes, with their types.
  (body', free) <- withLiftedScope (Just originName) (letBinds ++ bs) $
    liftLambdas body

  let -- Type of the original (unclosed) lambda.
      lamType :: I.Type
      lamType = I.extract lam

      -- Prepend argument types to a function type.
      prependArrow :: [I.Type] -> I.Type -> I.Type
      prependArrow ts t' =
        let (ats', rt') = I.unfoldArrow t' in I.foldArrow (ts ++ ats', rt')

      -- Types of the captured variables. They become the leading parameters
      -- of the lifted lambda.
      liftedLamArgTypes :: [I.Type]
      liftedLamArgTypes = map snd free

      -- Element i lists the captured-variable types that remain after the
      -- first i+1 captured arguments have been applied. 'tails' always yields
      -- a non-empty list, so dropping its head is total.
      intermediateTypes :: [[I.Type]]
      intermediateTypes = drop 1 (tails liftedLamArgTypes)

      liftedLamType :: I.Type
      liftedLamType = prependArrow liftedLamArgTypes lamType

      -- Call-site arguments. Each is paired with the type of the partial
      -- application it produces, as expected by 'I.foldApp'.
      freeActuals :: [(I.Expr I.Type, I.Type)]
      freeActuals =
        [ (I.Var v' t', prependArrow ts lamType)
        | ((v', t'), ts) <- zip free intermediateTypes
        ]

  -- Generate a fresh name for the lifted lambda.
  liftedName <- genName originName liftedLamType

  let -- Replace the lambda with a call to the lifted top-level lambda, applied
      -- to all captured variables.
      replacement :: I.Expr I.Type
      replacement = I.foldApp (I.Var (fromId liftedName) liftedLamType) freeActuals

      -- The closed lambda: captured variables first, then the original
      -- parameters.
      liftedLambda :: I.Expr I.Type
      liftedLambda = I.foldLambda (map (uncurry I.BindVar) free ++ bs) body'

  return (replacement, liftedName, liftedLambda)