{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE OverloadedStrings #-}
module IR.OptimizePar (
optimizePar,
) where
import Common.Compiler
import qualified Common.Compiler as Compiler
import Control.Monad.State.Lazy (
MonadState,
StateT (..),
evalStateT,
gets,
modify,
)
import IR.IR (Literal (LitIntegral))
import qualified IR.IR as I
-- | State threaded through the par-optimization pass.
--
-- Both fields are strict. The pass runs in the lazy 'StateT' from
-- "Control.Monad.State.Lazy", so lazy counter fields would accumulate
-- chains of unevaluated record updates instead of plain 'Int's.
data OptParCtx = OptParCtx
  { numPars :: !Int
  -- ^ Counter read by 'getNumberOfPars' and overwritten by
  -- 'updateNumberOfPars'. Presumably the number of @par@ expressions
  -- seen so far (TODO confirm once counting is implemented).
  , numBadPars :: !Int
  -- ^ Counter presumably meant for "bad" @par@ expressions.
  -- NOTE(review): initialised to 0 by 'runLiftFn' but not read or written
  -- by any other code visible in this module.
  }
-- | Monad in which the par-optimization pass runs.
--
-- It is 'OptParCtx' state layered over 'Compiler.Pass'. Every instance is
-- borrowed wholesale from the underlying @StateT OptParCtx Compiler.Pass@
-- via @DerivingVia@, so 'LiftFn' is a pure type-level wrapper.
newtype OptParFn a = LiftFn (StateT OptParCtx Compiler.Pass a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFail
    , MonadError Compiler.Error
    , MonadState OptParCtx
    )
    via (StateT OptParCtx Compiler.Pass)
-- | Fetch the current value of the 'numPars' counter from the pass state.
getNumberOfPars :: OptParFn Int
getNumberOfPars = gets $ \OptParCtx{numPars = n} -> n
-- | Replace the 'numPars' counter with the given value.
-- Every other field of the pass state is left as it was.
updateNumberOfPars :: Int -> OptParFn ()
updateNumberOfPars num = modify setCount
 where
  setCount ctx = ctx{numPars = num}
-- | Execute an 'OptParFn' computation inside 'Compiler.Pass'.
--
-- The computation starts from a context with both counters at zero. The
-- final state is discarded and only the result is returned.
runLiftFn :: OptParFn a -> Compiler.Pass a
runLiftFn (LiftFn m) = evalStateT m initialCtx
 where
  initialCtx = OptParCtx{numPars = 0, numBadPars = 0}
-- | Entry point of the par-optimization pass.
--
-- Each top-level definition is rewritten with 'optimizeParTop', in order.
-- The program is returned with its definition list replaced. All
-- definitions of one program share a single 'OptParCtx', which starts
-- fresh on every call via 'runLiftFn'.
optimizePar :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
optimizePar p = runLiftFn $ do
  defs' <- traverse optimizeParTop (I.programDefs p)
  pure p{I.programDefs = defs'}
-- | Run the per-definition pipeline on one top-level binding.
--
-- The steps are 'detectReplaceBadPar', then 'countPars', then
-- 'countBadPars'. The binder is passed through unchanged. The integer
-- produced by each counting step is currently ignored.
optimizeParTop ::
  (I.Binder I.Type, I.Expr I.Type) -> OptParFn (I.Binder I.Type, I.Expr I.Type)
optimizeParTop (nm, rhs) = do
  (counted, _parCount) <- detectReplaceBadPar rhs >>= countPars
  (checked, _badCount) <- countBadPars counted
  pure (nm, checked)
-- | Hook for detecting and rewriting "bad" @par@ expressions.
--
-- NOTE(review): currently a stub. The expression is returned unchanged and
-- no detection or replacement happens yet.
detectReplaceBadPar :: I.Expr I.Type -> OptParFn (I.Expr I.Type)
detectReplaceBadPar = pure
-- | Presumably meant to count @par@ expressions in @e@ (TODO confirm).
--
-- NOTE(review): currently a stub, and the expression is never inspected.
--
-- * It reads the 'numPars' counter.
-- * It writes back that same value plus 0.
-- * It returns @e@ unchanged, paired with the value read before the write.
countPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int)
countPars e = do
  before <- getNumberOfPars
  let after = before + 0
  updateNumberOfPars after
  pure (e, before)
-- | Predicate presumably meant to flag a "bad" @par@ expression.
--
-- NOTE(review): stub. The argument is ignored and the answer is always
-- 'False'.
isBad :: I.Expr I.Type -> Bool
isBad = const False
-- | Presumably meant to count "bad" @par@ expressions in @e@ (TODO confirm).
--
-- NOTE(review): currently a stub.
--
-- * 'isBad' is applied to a freshly built integer literal (value 5, whose
--   annotation is @I.extract e@), not to @e@ itself.
-- * That 'Bool' is converted to 0/1 with 'fromEnum'.
-- * @e@ is returned unchanged.
-- * The 'numBadPars' counter is neither read nor updated.
countBadPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int)
countBadPars e = pure (e, badFlag)
 where
  probe = I.Lit (LitIntegral 5) (I.extract e)
  badFlag = fromEnum (isBad probe)