{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE OverloadedStrings #-}

{- | Remove unnecessary Par expressions from the IR

This pass detects unnecessary par expressions and then replaces them with equivalent sequential expressions.
-}
module IR.OptimizePar (
  optimizePar,
) where

import Common.Compiler
import qualified Common.Compiler as Compiler
import Control.Monad.State.Lazy (
  MonadState,
  StateT (..),
  evalStateT,
  gets,
  modify,
 )
import IR.IR (Literal (LitIntegral))
import qualified IR.IR as I


-- | Optimization environment: the state carried by the 'OptParFn' monad.
data OptParCtx = OptParCtx
  { numPars :: Int
  -- ^ 'numPars' the number of par nodes in the input program's IR.
  , numBadPars :: Int
  -- ^ 'numBadPars' the number of "bad" par nodes in the input program's IR.
  }


{- | OptPar Monad

A state monad carrying an 'OptParCtx' layered over 'Compiler.Pass'.
Every instance is borrowed from the underlying 'StateT' stack through
@DerivingVia@, so the newtype behaves exactly like the transformer it wraps.
-}
newtype OptParFn a = LiftFn (StateT OptParCtx Compiler.Pass a)
  deriving (Functor) via (StateT OptParCtx Compiler.Pass)
  deriving (Applicative) via (StateT OptParCtx Compiler.Pass)
  deriving (Monad) via (StateT OptParCtx Compiler.Pass)
  deriving (MonadFail) via (StateT OptParCtx Compiler.Pass)
  deriving (MonadError Compiler.Error) via (StateT OptParCtx Compiler.Pass)
  deriving (MonadState OptParCtx) via (StateT OptParCtx Compiler.Pass)


{- | Example func to delete later! Demonstrates how to extract a value from the OptParFn Monad

Reads the 'numPars' field of the current 'OptParCtx' without modifying the state.
-}
getNumberOfPars :: OptParFn Int
getNumberOfPars = gets numPars


{- | Example func to delete later! Demonstrates how to modify a value in the OptParFn Monad

Overwrites the 'numPars' field of the current 'OptParCtx' with @num@;
every other field of the context is left untouched.
-}
updateNumberOfPars :: Int -> OptParFn ()
updateNumberOfPars num = modify $ \st -> st{numPars = num}


{- | Run an 'OptParFn' computation.

Unwraps the 'LiftFn' constructor and evaluates the underlying state
computation starting from a context in which both counters are zero.
Only the result is returned; the final 'OptParCtx' is discarded.
-}
runLiftFn :: OptParFn a -> Compiler.Pass a
runLiftFn (LiftFn m) =
  evalStateT
    m
    OptParCtx
      { numPars = 0
      , numBadPars = 0
      }


{- | Entry-point to Par Optimization.

Maps 'optimizeParTop' over the program's top-level definitions, in order,
and returns the program with its 'I.programDefs' replaced by the results.
All other fields of the program are passed through unchanged.
-}
optimizePar :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
optimizePar p = runLiftFn $ do
  optimizedDefs <- mapM optimizeParTop $ I.programDefs p
  return $ p{I.programDefs = optimizedDefs}


{- | Given a top-level definition, detect + replace unnecessary par expressions

The right-hand side is threaded through 'detectReplaceBadPar', then
'countPars', then 'countBadPars'; the counts returned by the latter two are
discarded. The binder is returned unchanged alongside the final expression.
-}
optimizeParTop :: (I.Binder I.Type, I.Expr I.Type) -> OptParFn (I.Binder I.Type, I.Expr I.Type)
optimizeParTop (nm, rhs) = do
  rhs' <- detectReplaceBadPar rhs
  (rhs'', _) <- countPars rhs' -- calling this so we don't get an "unused" warning
  (rhs''', _) <- countBadPars rhs'' -- calling this so we don't get an "unused" warning
  -- uncomment the line below to test countPars
  -- (_, result) <- countPars rhs
  -- _ <- fail (show nm ++ ": Number of Par Exprs: " ++ show result)
  -- uncomment the two lines below to test countBadPars
  -- (_, result') <- countBadPars rhs
  -- _ <- fail (show nm ++ ": Number of Bad Par Exprs: " ++ show result')
  return (nm, rhs''')


{- | Detect Unnecessary Par Expressions + Replace With Equivalent Sequential Expression

Currently a placeholder: the expression is returned exactly as given and the
state is not touched.
-}
detectReplaceBadPar :: I.Expr I.Type -> OptParFn (I.Expr I.Type)
detectReplaceBadPar e = pure e -- for now, just return the same thing (don't do anything)


{- | 1) Count Par Nodes

Practice Exercise to Delete Later!

Traverse the IR representation of the body of a top level definition,
and count the number of par expressions present.
Return the body unchanged, as well as the count numPars.

NOTE: this is still a stub. It does not traverse the expression; it returns
whatever 'numPars' value is currently stored in the state and writes that
same value back.
-}
countPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int)
countPars e = do
  -- currently a stub
  -- PUT YOUR IMPLEMENTATION HERE
  x <- getNumberOfPars
  updateNumberOfPars (x + 0) -- calling this so we don't get an "unused" warning
  return (e, x)


{- | 1.5) Implement IsBad Predicate

Suggested by John during Monday Meeting.
Returns true if par expr contains only instantaneous expressions as arguments.
False otherwise.
Useful for exercise 2.

NOTE: this is still a stub. The argument is ignored and the result is
always 'False'.
-}
isBad :: I.Expr I.Type -> Bool
isBad _ = False -- currently a stub


{- | 2) Count Bad Par Nodes

Practice Exercise to Delete Later!

Traverse the IR representation of the body of a top level definition,
and count the number of BAD par expressions present.
Use the helper predicate "isBad" in your implementation.
Return the body unchanged, as well as the count numBadPars.

NOTE: this is still a stub. It does not traverse the expression; it applies
'isBad' to a dummy integer literal that carries the body's type annotation
and returns that result converted with 'fromEnum'.
-}
countBadPars :: I.Expr I.Type -> OptParFn (I.Expr I.Type, Int)
countBadPars e = do
  -- currently a stub
  let y = isBad (I.Lit (LitIntegral 5) (I.extract e)) -- calling this so we don't get an "unused" warning
  return (e, fromEnum y)