{-# LANGUAGE DerivingVia #-}

{- | Turns non-nullary data constructors into calls to constructor functions.

Worked example of ADT definition and corresponding constructor functions:

> type Shape
>   Square Int
>   Rect Int Int

Let's turn data constructor `Square` into constructor function `__Square`:

> __Square arg0 : Int -> Shape =  Square arg0

The difference is that `Square` cannot be partially-applied, whereas `__Square` can.

Next, `Rect` turns into this:

> __Rect arg0 arg1 : Int -> Int -> Shape =  Rect arg0 arg1

The difference is that `Rect` cannot be partially-applied, whereas `__Rect` can.


Representing constructor functions in the IR:

Every top-level function has the form (I.Binder I.Type, I.Expr I.Type) = (functionBinder, functionBody),
where the binder carries the function's name and type.
The function body is a lambda expression representing a call to the fully applied data constructor.

Let's turn the top-level func for `Square` into IR:

@
(__Square, body)
body = fun arg0 { App L R t } : Int -> Shape
 where L = Square : Int -> Shape
       R = arg0 : type Int
       t = Shape, because the type of a fully applied data constructor
           is its type constructor
@

Next `Rect` turns into this:

@
(__Rect, body)
body = fun arg0 { fun arg1 { App L R t } : Int -> Shape } : Int -> Int -> Shape
 where L = App L2 R2 t2
       R = arg1 : Int
       t = Shape, because the type of a fully applied data constructor
           is its type constructor
        where L2 = Rect : Int -> Int -> Shape
              R2 = arg0 : Int
              t2 = Int -> Shape, because at this point in the inner App,
                   Rect is partially applied with only 1 arg.
@
-}
module IR.DConToFunc (
  dConToFunc,
) where

import qualified Common.Compiler as Compiler

import Common.Compiler (MonadError)
import Common.Identifiers (
  fromId,
  fromString,
  ident,
 )

import Control.Monad.Reader (
  MonadReader,
  ReaderT (..),
  asks,
 )
import Data.Bifunctor (Bifunctor (..))
import Data.Generics.Aliases (mkM)
import Data.Generics.Schemes (everywhereM)
import Data.List (inits)
import qualified Data.Map as M
import Data.Maybe (mapMaybe)
import qualified IR.IR as I
import qualified IR.Types as I


-- | Environment mapping each data constructor ('I.DConId') to its arity, as
-- computed by 'I.variantFields' over the constructor's type variant.
-- An arity of 0 marks a nullary constructor, which is left untouched.
type ArityEnv = M.Map I.DConId Int


{- | Arity Reader Monad

  A 'Compiler.Pass' computation with read-only access to an 'ArityEnv'.
  Every instance is borrowed (via @DerivingVia@) from the underlying
  @ReaderT ArityEnv Compiler.Pass@, so 'fail' and 'MonadError' errors are
  those of 'Compiler.Pass'.
-}
newtype ArityFn a = ArityFn (ReaderT ArityEnv Compiler.Pass a)
  deriving (Functor) via (ReaderT ArityEnv Compiler.Pass)
  deriving (Applicative) via (ReaderT ArityEnv Compiler.Pass)
  deriving (Monad) via (ReaderT ArityEnv Compiler.Pass)
  deriving (MonadFail) via (ReaderT ArityEnv Compiler.Pass)
  deriving (MonadError Compiler.Error) via (ReaderT ArityEnv Compiler.Pass)
  deriving (MonadReader ArityEnv) via (ReaderT ArityEnv Compiler.Pass)


{- | Run a computation within an arity environment.

  The environment maps every data constructor declared in the given type
  definitions to its arity, as computed by 'I.variantFields'.
-}
runArityFn :: [(I.TConId, I.TypeDef)] -> ArityFn a -> Compiler.Pass a
runArityFn tds (ArityFn m) = runReaderT m $ M.fromList env
 where
  -- (data constructor, arity) for every variant of every type definition
  env :: [(I.DConId, Int)]
  env = concatMap (map (second I.variantFields) . I.variants . snd) tds


{- | 'dConToFunc' modifies programDefs and traverses the IR to accomplish two tasks:

  1. Add top-level constructor functions for each non-nullary 'DCon' to programDefs.
  2. Turn non-nullary data constructors into calls to top-level constructor funcs.

  The generated constructor functions are placed before the user-defined
  functions. A user-defined function whose binder equals a generated one is
  dropped (see the note on @defs'@ below).
-}
dConToFunc :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
dConToFunc p@I.Program{I.programDefs = defs, I.typeDefs = tDefs} =
  runArityFn tDefs $ do
    defs'' <- defs' -- user defined functions
    return p{I.programDefs = tDefs' ++ defs''} -- constructor funcs ++ user funcs
 where
  -- top-level constructor functions, one per non-nullary data constructor
  tDefs' :: [(I.Binder I.Type, I.Expr I.Type)]
  tDefs' = concatMap createFuncs tDefs

  -- binders of the generated constructor functions
  ctorBinders :: [I.Binder I.Type]
  ctorBinders = fst <$> tDefs'

  -- User-defined functions, minus those colliding with a generated
  -- constructor function, with non-nullary data constructors rewritten into
  -- constructor-function calls. Filtering happens before the rewrite, so a
  -- dropped definition is never traversed (and cannot make the pass fail).
  defs' :: ArityFn [(I.Binder I.Type, I.Expr I.Type)]
  defs' =
    everywhereM (mkM dataToApp) $
      filter ((`notElem` ctorBinders) . fst) defs
  {- Why defs are filtered: name collisions between user-defined functions
  and the constructor functions inserted by this pass.

  Given the SSLANG program
  @
  type Color
    White
    Black
    RGB Int Int Int

  main ( cout : & Int ) -> () = ()
  @
  dConToFunc will produce
  @
  type Color
    White
    Black
    RGB Int Int Int

  __RGB (__arg0 : Int) (__arg1 : Int) (__arg2 : Int) -> Color =
  RGB __arg0 __arg1 __arg2

  main ( cout : & Int ) -> () = ()
  @
  This is okay.
  Now, given the SSLANG program
  @
  type Color
    White
    Black
    RGB Int Int Int

  __RGB (r : Int) (g : Int) (b : Int) -> Color =
  RGB r (g+2) b

  main ( cout : & Int ) -> () = ()
  @
  dConToFunc (without filtering defs!) would produce
  @
  type Color
    White
    Black
    RGB Int Int Int

  __RGB (__arg0 : Int) (__arg1 : Int) (__arg2 : Int) -> Color =
  RGB __arg0 __arg1 __arg2

  __RGB (r : Int) (g : Int) (b : Int) -> Color =
  RGB r (g+2) b

  main ( cout : & Int ) -> () = ()
  @
  This is not okay, because there are two functions named __RGB.
  To prevent this duplicate, we remove from defs every user-defined function
  whose binder equals the binder of a newly inserted constructor function:
  @
  // defs   = [ (__RGB, ...), (main, ...) ]
  // tDefs' = [ (__RGB, ...) ]
  // defs'  = [ (main, ...) ]
  @
  dConToFunc (with filtering of defs) therefore produces
  @
  type Color
    White
    Black
    RGB Int Int Int

  __RGB (__arg0 : Int) (__arg1 : Int) (__arg2 : Int) -> Color =
  RGB __arg0 __arg1 __arg2

  main ( cout : & Int ) -> () = ()
  @
  i.e. the generated constructor function wins and the user's __RGB is
  discarded.

  NOTE(review): the check uses equality on whole binders, which presumably
  includes the binder's type annotation; a user function with the same name
  but a different type would then survive the filter. Confirm against the
  Eq instance of I.Binder.
  -}

  -- constructor functions for every non-nullary variant of one type definition
  createFuncs :: (I.TConId, I.TypeDef) -> [(I.Binder I.Type, I.Expr I.Type)]
  createFuncs (tconid, I.TypeDef{I.variants = vars}) =
    mapMaybe (createFunc tconid) vars


{- | Replace data constructor application with function application

  Turns @I.Data dcon t@ into @I.Var __dcon t@ (the name built by 'nameFunc'),
  keeping the same type annotation @t@. Nullary data constructors (arity 0)
  and every other expression are returned unchanged. Applied over the user
  definitions with 'everywhereM' in 'dConToFunc'.

  Fails (via 'fail') when @dcon@ has no entry in the arity environment.
-}
dataToApp :: I.Expr I.Type -> ArityFn (I.Expr I.Type)
dataToApp a@(I.Data dconid t) = do
  arity <- asks (M.lookup dconid)
  case arity of
    Nothing ->
      fail $ "dataToApp: no arity recorded for data constructor " ++ ident dconid
    Just 0 -> return a -- leave nullary data constructors alone
    Just _ -> return $ I.Var (nameFunc dconid) t
dataToApp a = pure a


{- | Create a top level function for each data constructor

  Returns Nothing for nullary data constructors,
  which don't need top-level constructor functions.

  For a constructor @D@ of type constructor @tcon@ with fields
  @x0 : T0, ..., xn : Tn@, the result pairs the binder @__D@ with the lambda
  @fun x0 { ... fun xn { D x0 ... xn } }@ of type @T0 -> ... -> Tn -> tcon@.
  Unnamed fields are first named @__arg0@, @__arg1@, ... (see 'nameArg').
-}
createFunc
  :: I.TConId -> (I.DConId, I.TypeVariant) -> Maybe (I.Binder I.Type, I.Expr I.Type)
-- case of nullary dcon; guarantees params to be non-empty in the next pattern
createFunc _ (_, I.VariantNamed []) = Nothing
createFunc tcon (dconid, I.VariantNamed params) =
  Just (I.BindVar func_name $ I.extract lambda, lambda)
 where
  -- distinguish func name from fully applied dcon in IR
  func_name = nameFunc dconid
  -- nested lambdas binding each field, wrapping the full application
  lambda = I.foldLambda (uncurry I.BindVar <$> params) body
  body = I.foldApp dcon args
  dcon = I.Data (fromId dconid) t
  -- each argument is paired with the type of the application it completes
  args = zip (uncurry I.Var <$> params) ts
  tconTyp = I.TCon tcon []
  {- Turns a list of types into arrow types representing nested lambdas
     @[Int, Int, Shape]@ becomes
     @[Int -> Int -> Shape, Int -> Shape, Shape]@
     @(t:ts)@ is permitted because params is always non-empty
     @tail@ is permitted because inits on a non-empty list always returns a
     list of at least two elements.
  -}
  (t : ts) =
    reverse $
      foldr1 I.Arrow . reverse
        <$> tail (inits $ reverse ((snd <$> params) ++ [tconTyp]))
createFunc tcon (dcon, I.VariantUnnamed params) =
  createFunc tcon (dcon, I.VariantNamed argNames)
 where
  -- give positional fields generated names __arg0, __arg1, ...
  argNames = zipWith (\t i -> (nameArg i, t)) params [0 ..]


-- | Create a name for the constructor function of a data constructor.
--
-- The name is the constructor's identifier prefixed with @__@ (e.g. @Rect@
-- becomes @__Rect@), distinguishing the function from the data constructor
-- itself in the IR.
nameFunc :: I.DConId -> I.VarId
nameFunc dconid = fromString $ "__" ++ ident dconid


-- | Create a name for a constructor function argument.
--
-- The name is @__arg@ followed by the argument's zero-based position, so
-- @nameArg 2@ is the name @__arg2@.
nameArg :: Int -> I.VarId
nameArg i = fromString ("__arg" ++ show i)