{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ViewPatterns #-}
module IR.SegmentLets (
segmentLets,
segmentDefs,
segmentDefs',
) where
import qualified Common.Compiler as Compiler
import Common.Identifiers (
HasFreeVars (..),
VarId (..),
genId,
showId,
ungenId,
)
import qualified IR.IR as I
import Control.Monad.State (
State,
evalState,
gets,
modify,
)
import Data.Bifunctor (Bifunctor (..))
import Data.Generics.Aliases (mkT)
import Data.Generics.Schemes (everywhere)
import Data.Graph (
SCC (..),
stronglyConnComp,
)
import qualified Data.Set as S
-- | Shorthand for the IR type annotation used by this pass.
type T = I.Type
-- | A single definition: a binder paired with its defining expression.
type Def t = (I.Binder t, I.Expr t)
-- | Apply 'segmentLetExpr' to every expression reachable from the program's
-- top-level definitions.
--
-- The traversal uses 'everywhere', which rewrites bottom-up, so nested @let@s
-- are segmented before the expressions that enclose them. Only the
-- 'I.programDefs' field is changed; the rest of the program record is
-- returned untouched.
segmentLets :: I.Program T -> Compiler.Pass (I.Program T)
segmentLets p =
  return p{I.programDefs = everywhere (mkT segmentLetExpr) $ I.programDefs p}
-- | Replace a single @let@ by a chain of nested @let@s, one per group
-- returned by 'segmentDefs'.
--
-- The first group becomes the outermost @let@ and the original body ends up
-- innermost. Every generated @let@ reuses the annotation @t@ of the original
-- node. If the definition list is empty, no groups are produced and the body
-- is returned without any enclosing @let@. Any expression that is not a
-- @let@ is returned unchanged.
segmentLetExpr :: I.Expr T -> I.Expr T
segmentLetExpr (I.Let ds b t) = foldr ilet b $ segmentDefs ds
 where
  -- Wrap an already-built inner expression in a let for one group.
  ilet d' b' = I.Let d' b' t
segmentLetExpr e = e
-- | Partition definitions into strongly connected components of their
-- dependency graph.
--
-- Each definition is a node keyed by its binder's variable, with an edge to
-- every free variable of its defining expression. Free variables that do not
-- name any definition in the list are ignored by 'stronglyConnComp'. The
-- components are returned in the order 'stronglyConnComp' yields them
-- (reverse topological: a group appears after the groups it depends on).
segmentDefs :: [Def t] -> [[Def t]]
segmentDefs = map fromSCC . stronglyConnComp . (`evalState` 0) . mapM toGraph
 where
  -- Turn a definition into a graph node. A binder without an identifier
  -- needs some key, so a fresh placeholder is drawn from the counter held in
  -- the State monad (the counter is incremented before it is read, so the
  -- first placeholder is numbered 1).
  toGraph :: Def t -> State Int (Def t, I.VarId, [I.VarId])
  toGraph d@(I._binderId -> b, e) = do
    v <- maybe (modify (+ 1) >> gets genVar) return b
    return (d, v, S.toList $ freeVars e)

  -- Flatten one component back to a list of definitions, passing every
  -- binder through 'ungenVar'.
  fromSCC :: SCC (Def t) -> [Def t]
  fromSCC = map (first ungenVar) . from
   where
    from (AcyclicSCC d) = [d]
    from (CyclicSCC ds) = ds

  -- Placeholder key for the n-th identifier-less binder: "_wild" followed by
  -- the number, passed through 'genId'.
  genVar :: Int -> I.VarId
  genVar = genId . ("_wild" <>) . showId

  -- Reset the identifier of a variable binder to whatever 'ungenId' returns
  -- for it; other binders are left alone.
  -- NOTE(review): the graph nodes carry the original definitions, so the
  -- placeholder keys made by 'genVar' are never stored in a binder here.
  -- Confirm what 'ungenId' yields for identifiers that were not generated.
  ungenVar :: I.Binder t -> I.Binder t
  ungenVar b@(I.BindVar v _) = b{I._binderId = ungenId v}
  ungenVar b = b
-- | A definition as consumed by 'segmentDefs'': the defined variable, two
-- extra components that segmentation carries along without inspecting, and
-- the defining expression.
type Def' t a b = (VarId, a, b, I.Expr t)
-- | Partition 'Def'' tuples into strongly connected components of their
-- dependency graph.
--
-- Each tuple is a node keyed by its variable, with an edge to every free
-- variable of its expression; the two middle components are carried along
-- unchanged. Unlike 'segmentDefs', every definition already has a variable,
-- so no placeholder keys are generated and the tuples are returned exactly as
-- given. Groups come back in the order produced by 'stronglyConnComp'.
segmentDefs' :: [Def' t a b] -> [[Def' t a b]]
segmentDefs' = map fromSCC . stronglyConnComp . map toGraph
 where
  -- Turn a definition into a graph node keyed by its own variable.
  toGraph :: Def' t a b -> (Def' t a b, I.VarId, [I.VarId])
  toGraph d@(v, _, _, e) = (d, v, S.toList $ freeVars e)

  -- Flatten one component back to a list of definitions.
  fromSCC :: SCC (Def' t a b) -> [Def' t a b]
  fromSCC = from
   where
    from (AcyclicSCC d) = [d]
    from (CyclicSCC ds) = ds