module IR.Constraint.Monad where

import qualified Common.Compiler as Compiler
import Common.Identifiers (
  DConId (..),
  TConId (..),
  TVarId (..),
  VarId (..),
  fromId,
 )
import qualified Common.Identifiers as Ident
import Control.Monad.Except (ExceptT)
import qualified Control.Monad.Except as Except
import Control.Monad.State (StateT)
import qualified Control.Monad.State as State
import Data.Bifunctor (second)
import qualified Data.Map.Strict as Map
import GHC.IO.Unsafe (unsafePerformIO)
import qualified IR.Constraint.Canonical as Can
import qualified IR.IR as I

import Control.Monad.Writer (WriterT)
import qualified Control.Monad.Writer as Writer
import Prettyprinter (Doc)


-- | The type-checking monad.
--
-- Layers, outermost first: 'TCState' threaded as mutable state, early exit
-- on a 'Compiler.Error', accumulated 'Doc' output, and 'IO' at the bottom.
type TC a =
  StateT
    TCState
    (ExceptT Compiler.Error (WriterT (Doc String) IO))
    a


-- | Everything recorded about one data constructor (built by 'mkDConMap').
type DConInfo =
  ( DConId -- the data constructor itself
  , TConId -- the type constructor whose definition contains it
  , [TVarId] -- type parameters of that type definition
  , [Can.Type] -- types of the constructor's arguments
  )


-- | Data constructor information, keyed by the constructor's own 'DConId'.
type DConMap = Map.Map DConId DConInfo


-- | State threaded through every 'TC' computation.
data TCState = TCState
  { _freshCtr :: Int
  -- ^ Counter used for generating fresh variable names.
  , _dconMap :: DConMap
  -- ^ Information about every data constructor declared in the program.
  , _kindMap :: Map.Map TConId Can.Kind
  -- ^ Kinds of known type constructors (program-defined and built-in).
  , _externMap :: Map.Map VarId Can.Type
  -- ^ Types of the program's extern declarations, keyed by name.
  }


-- | Run a type-checking computation starting from the given state.
--
-- Returns either the computed value or the first error thrown. The result
-- comes paired with all 'Doc' output written so far. Because the 'WriterT'
-- layer sits beneath 'ExceptT', output produced before an error is kept.
-- The final 'TCState' is discarded.
--
-- NOTE(review): the 'IO' base layer is escaped with 'unsafePerformIO' so
-- that callers get a pure result. This is only sound if any IO performed
-- inside 'TC' is benign (e.g. debug output) and callers do not depend on
-- when it happens. Confirm against the uses of 'liftIO' in 'TC' code.
runTC :: TCState -> TC a -> (Either Compiler.Error a, Doc String)
runTC initial action =
  unsafePerformIO . Writer.runWriterT . Except.runExceptT $
    State.evalStateT action initial


-- | Build the initial type-checking state for a program.
--
-- * The fresh-name counter starts at 0.
-- * Data constructors are indexed by 'mkDConMap'.
-- * Each program-defined type constructor gets a kind equal to its number
--   of type parameters. On a name clash these entries take precedence over
--   'Can.builtinKinds', because 'Map.union' is left-biased.
-- * Extern declarations are collected into a map. For a duplicated name,
--   the last declaration wins.
mkTCState :: I.Program Can.Annotations -> TCState
mkTCState prog =
  TCState
    { _freshCtr = 0
    , _dconMap = mkDConMap prog
    , _kindMap = programKinds `Map.union` Can.builtinKinds
    , _externMap = Map.fromList (I.externDecls prog)
    }
 where
  -- The kind of a program-defined type constructor is its arity.
  arity = length . I.targs
  programKinds = Map.fromList (map (second arity) (I.typeDefs prog))


-- | Index every data constructor declared in the program's type definitions.
--
-- Each constructor maps to its 'DConInfo':
--
-- * the constructor itself,
-- * the type constructor that defines it,
-- * that type's parameters,
-- * the constructor's argument types (field names of named variants are
--   dropped).
--
-- If the same constructor name occurs more than once, the last occurrence
-- wins, because 'Map.fromList' retains the last value for a duplicate key.
mkDConMap :: I.Program Can.Annotations -> DConMap
mkDConMap I.Program{I.typeDefs = tdefs} =
  -- 'Map.fromList' replaces a lazy 'foldl' of 'Map.insert'. That fold built
  -- a chain of pending inserts before the map was ever forced.
  Map.fromList
    [ (dcid, info)
    | info@(dcid, _, _, _) <- concatMap tdef2dcon tdefs
    ]
 where
  -- Expand one type definition into the info for each of its constructors.
  tdef2dcon (tcid, tdef) =
    [ (dcid, tcid, I.targs tdef, getVariantArgTypes tv)
    | (dcid, tv) <- I.variants tdef
    ]
  -- Argument types of a variant; named variants drop their field names.
  getVariantArgTypes (I.VariantNamed ns) = map snd ns
  getVariantArgTypes (I.VariantUnnamed ts) = ts


-- | Generate a fresh identifier based on some prefix.
--
-- The result is @prefix ++ show n@, where @n@ is the current value of the
-- '_freshCtr' counter. The counter is then incremented, so successive calls
-- never reuse a number.
freshName :: String -> TC TVarId
freshName prefix = do
  n <- State.gets _freshCtr
  let next = n + 1
  -- Force the incremented counter before storing it. Otherwise each call
  -- stores an unevaluated @n + 1@ on top of the previous one, and the
  -- counter becomes an ever-growing thunk chain (a space leak).
  next `seq` State.modify (\st -> st{_freshCtr = next})
  return $ Ident.fromString $ prefix ++ show n


-- | Generate a fresh program variable name, e.g., for anonymous binders.
--
-- Draws a name with prefix @"__anon_binder"@ from 'freshName' and converts
-- it from 'TVarId' to 'VarId' with 'fromId'.
freshVar :: TC VarId
freshVar = fmap fromId (freshName "__anon_binder")


-- | Generate a fresh identifier to be associated with annotations.
--
-- Draws a name with prefix @"$ann"@ from 'freshName'. The result is
-- returned as a 'VarId', converted from 'TVarId' with 'fromId'.
freshAnnVar :: TC VarId
freshAnnVar = fmap fromId (freshName "$ann")


-- | Look up the recorded 'DConInfo' for a data constructor.
--
-- Returns 'Nothing' if the constructor is not in the state's constructor
-- map. The state is not modified.
getDConInfo :: DConId -> TC (Maybe DConInfo)
getDConInfo dcon = State.gets (Map.lookup dcon . _dconMap)


-- | Look up the kind of a type constructor.
--
-- Returns 'Nothing' if the constructor is not in the state's kind map. The
-- state is not modified.
getKind :: TConId -> TC (Maybe Can.Kind)
getKind tcon = State.gets (Map.lookup tcon . _kindMap)


-- | Look up the declared type of an external variable.
--
-- Returns 'Nothing' if the variable is not in the state's extern map. The
-- state is not modified.
getExtern :: VarId -> TC (Maybe Can.Type)
getExtern var = State.gets (Map.lookup var . _externMap)


-- | Abort the current 'TC' computation with a 'Compiler.TypeError' that
-- carries the given message.
throwError :: String -> TC a
throwError = Except.throwError . Compiler.TypeError . Compiler.fromString