{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE QuasiQuotes #-}

-- | Extract and encapsulate the type information needed during codegen.
module Codegen.Typegen (
  DConInfo (..),
  TConInfo (..),
  TypegenInfo (..),
  genTypes,
) where

import Codegen.LibSSM
import qualified IR.IR as I

import qualified Common.Compiler as Compiler
import Common.Identifiers (
  DConId (..),
  TConId (..),
 )

import qualified Language.C.Quote.GCC as C
import qualified Language.C.Syntax as C

import Control.Monad (forM)
import qualified Data.Map as M


-- | Everything codegen needs to know about sslang types, exposed as
-- partial lookup functions keyed by constructor identifier.
data TypegenInfo = TypegenInfo
  { -- | Codegen information for a data constructor, if it is known.
    dconInfo :: DConId -> Maybe DConInfo
  , -- | Codegen information for a type constructor, if it is known.
    tconInfo :: TConId -> Maybe TConInfo
  }


-- | Codegen information and handlers for a single data constructor.
data DConInfo = DConInfo
  { -- | Type constructor of the data type this data constructor inhabits.
    dconType :: TConId
  , -- | How many fields the data constructor carries.
    dconSize :: Int
  , -- | Whether instances of this data constructor are heap-allocated.
    dconOnHeap :: Bool
  , -- | The tag to match on, i.e., the label used in @case tag@.
    dconCase :: C.Exp
  , -- | Expression that constructs an instance of this data constructor.
    dconConstruct :: C.Exp
  , -- | @dconDestruct i e@ retrieves the @i@th field of the instance @e@.
    dconDestruct :: Int -> C.Exp -> C.Exp
  }


-- | Codegen information and handlers for a single type constructor.
data TConInfo = TConInfo
  { -- | How inhabitants of this data type are represented.
    typeEncoding :: TypeEncoding
  , -- | Given an instance of this data type, produce its tag.
    typeScrut :: C.Exp -> C.Exp
  }


-- | How the inhabitants of a data type are represented: heap-allocated,
-- by value, or a mixture of both.
data TypeEncoding
  = -- | Every inhabitant is represented by value.
    TypePacked
  | -- | Inhabitants may be heap-allocated or represented by value.
    TypeMixed
  -- TODO: TypeHeap, for types whose inhabitants are always heap-allocated.


-- | Create codegen definitions and helpers for sslang type definitions.
--
-- Yields one C declaration per type definition (see 'genTypeDef'), in input
-- order, paired with the lookup tables computed by 'genTypeInfo'. The
-- declarations are generated before the lookup tables.
genTypes
  :: [(TConId, I.TypeDef)] -> Compiler.Pass ([C.Definition], TypegenInfo)
genTypes tdefs = (,) <$> traverse genTypeDef tdefs <*> genTypeInfo tdefs


-- | Declare a C enum for one sslang type definition.
--
-- The enum is named after the type constructor and has one enumerator per
-- data constructor, in the order the data constructors appear in the type
-- definition. The first enumerator is explicitly assigned the value 0.
genTypeDef :: (TConId, I.TypeDef) -> Compiler.Pass C.Definition
genTypeDef (tcon, I.TypeDef variants _) =
  pure [C.cedecl|enum $id:tcon { $enums:enumerators };|]
 where
  enumerators :: [C.CEnum]
  enumerators = case map fst variants of
    [] -> []
    firstDCon : otherDCons ->
      [C.cenum|$id:firstDCon = 0|]
        : [[C.cenum|$id:dcon|] | dcon <- otherDCons]


-- | Compute codegen helpers for each sslang type definition.
--
-- Every data constructor gets a 'DConInfo':
--
-- * it is heap-allocated iff it has at least one field;
-- * heap-allocated ones are built with 'new_adt', the others by marshaling
--   the constructor's name;
-- * 'dconCase' is the constructor's name as a C identifier.
--
-- Every type constructor gets a 'TConInfo'. Its encoding is 'TypePacked' (tag
-- read with 'unmarshal') when none of its data constructors is heap-allocated,
-- and 'TypeMixed' (tag read with 'adt_tag') otherwise.
--
-- A type's encoding is derived directly from the 'DConInfo's computed for that
-- same type definition. They are not looked up again by name, so the result
-- stays correct even if two definitions reuse a data constructor name. The
-- data-constructor lookup table itself keeps the last such definition, as
-- 'M.fromList' does.
genTypeInfo :: [(TConId, I.TypeDef)] -> Compiler.Pass TypegenInfo
genTypeInfo tdefs =
  pure
    TypegenInfo
      { dconInfo = (`M.lookup` dconTable)
      , tconInfo = (`M.lookup` tconTable)
      }
 where
  -- Data constructor info, grouped per type definition and kept in the same
  -- order as @tdefs@, so it can be zipped back with the type constructors.
  dInfos :: [[(DConId, DConInfo)]]
  dInfos =
    [ [ (dcon, mkDConInfo tcon dcon (I.variantFields variant))
      | (dcon, variant) <- variants
      ]
    | (tcon, I.TypeDef variants _) <- tdefs
    ]

  dconTable :: M.Map DConId DConInfo
  dconTable = M.fromList (concat dInfos)

  tconTable :: M.Map TConId TConInfo
  tconTable = M.fromList (zipWith mkTConInfo (map fst tdefs) dInfos)

  -- Codegen handlers for one data constructor with the given field count.
  mkDConInfo :: TConId -> DConId -> Int -> DConInfo
  mkDConInfo tcon dcon fields =
    DConInfo
      { dconType = tcon
      , dconSize = fields
      , dconOnHeap = onHeap
      , dconCase = [C.cexp|$id:dcon|]
      , dconConstruct = construct
      , -- TODO: does not handle packed ADTs
        dconDestruct = flip adt_field
      }
   where
    onHeap = fields > 0
    construct
      | onHeap = new_adt fields dcon
      | otherwise = marshal [C.cexp|$id:dcon|]

  -- Encoding and tag-extraction handler for one type constructor, given the
  -- info of its own data constructors.
  mkTConInfo :: TConId -> [(DConId, DConInfo)] -> (TConId, TConInfo)
  mkTConInfo tcon dcons =
    (tcon, TConInfo{typeEncoding = encoding, typeScrut = scrut})
   where
    (encoding, scrut)
      | any (dconOnHeap . snd) dcons = (TypeMixed, adt_tag)
      | otherwise = (TypePacked, unmarshal)