{-# LANGUAGE OverloadedStrings #-}

module IR.Constraint.Type where

import qualified Common.Identifiers as Ident
import Control.Monad.Trans (liftIO)
import Data.Bifunctor (Bifunctor (first))
import qualified Data.Map.Strict as Map
import qualified IR.Constraint.Canonical as Can
import qualified IR.Constraint.Error as ET
import IR.Constraint.Monad (
  TC,
  freshName,
  freshVar,
 )
import qualified IR.Constraint.UnionFind as UF
import qualified IR.IR as I


-- | CONSTRAINTS
--
-- The constraint language. Constraints relate 'Type's to each other
-- ('CEqual', 'CPattern'), to identifiers ('CLocal'), or to canonical schemes
-- ('CForeign'); 'CAnd' groups a list of constraints; 'CLet' carries rigid and
-- flexible variables, a header mapping identifiers to types, and separate
-- constraints for the header and the body.
--
-- NOTE(review): the precise meaning the solver gives to each constructor
-- (in particular 'CSaveTheEnvironment') is not visible in this module —
-- confirm against the solver.
data Constraint
  = CTrue
  | CSaveTheEnvironment
  | CEqual Type Type
  | CPattern Type Type
  | CLocal Ident.Identifier Type
  | CForeign Can.Scheme Type
  | CAnd [Constraint]
  | CLet
      { _rigidVars :: [Variable]
      , _flexVars :: [Variable]
      , _header :: Map.Map Ident.Identifier Type
      , _headerCon :: Constraint
      , _bodyCon :: Constraint
      }


-- | Build a 'CLet' that introduces @flexVars@ as flexible variables around
-- @constraint@: no rigid variables, an empty header, @constraint@ as the
-- header constraint, and 'CTrue' as the body constraint.
exists :: [Variable] -> Constraint -> Constraint
exists flexVars constraint = CLet [] flexVars Map.empty constraint CTrue


-- | TYPE PRIMITIVES
--
-- A type variable: a union-find point whose payload is a 'Descriptor'.
type Variable = UF.Point Descriptor


-- | A pair of IR annotations and a type 'Variable'.
type Attachment = (I.Annotations, Variable)


-- | A type constructor applied to argument 'Variable's. Unlike 'Type', the
-- arguments are variables rather than nested type terms.
data FlatType = TCon1 Ident.TConId [Variable]


-- | A type term: either a type constructor applied to argument types
-- ('TConN') or a type variable ('TVarN').
data Type
  = TConN Ident.TConId [Type]
  | TVarN Variable
  deriving (Eq)


-- | DESCRIPTORS
--
-- The payload stored at each 'Variable': its 'Content', an integer rank
-- (see 'noRank' and 'outermostRank'), a 'Mark', and an optional copy.
data Descriptor = Descriptor
  { _content :: Content
  , _rank :: Int
  , _mark :: Mark
  , _copy :: Maybe Variable -- for instantiation
  }


-- | What a 'Variable' currently stands for: a named flexible or rigid type
-- variable, a 'FlatType' structure, or an error marker.
--
-- NOTE(review): presumably 'Error' is set by the unifier after a failed
-- unification — confirm against the solver.
data Content
  = FlexVar Ident.TVarId
  | RigidVar Ident.TVarId
  | Structure FlatType
  | Error


-- | Build a 'Descriptor' for the given content with default bookkeeping:
-- rank 'noRank', mark 'noMark', and no copy.
mkDescriptor :: Content -> Descriptor
mkDescriptor content = Descriptor content noRank noMark Nothing


-- | RANKS

-- No rank means that the variable is generic
noRank :: Int
noRank = 0


-- Outermost rank means that we have not entered header of any CLet
outermostRank :: Int
outermostRank = 1


-- | MARKS
--
-- An integer tag stored in the @_mark@ field of a 'Descriptor'. Equality and
-- ordering are those of the wrapped 'Int'.
newtype Mark = Mark Int
  deriving (Eq, Ord)


-- | The default mark; 'mkDescriptor' gives it to every new descriptor.
noMark :: Mark
noMark = Mark 2


-- | Mark that 'toErrorType' sets on a variable while it is converting that
-- variable's content; meeting it again yields an infinite error type.
occursMark :: Mark
occursMark = Mark 1


-- | A reserved mark value.
--
-- NOTE(review): not used within this module; presumably consumed by a
-- variable-name collection pass elsewhere — confirm.
getVarNamesMark :: Mark
getVarNamesMark = Mark 0


-- occursMark :: Mark
-- occursMark = Mark 1

-- getVarNamesMark :: Mark
-- getVarNamesMark = Mark 0

-- | The mark whose wrapped integer is one greater than the given mark's.
{-# INLINE nextMark #-}
nextMark :: Mark -> Mark
nextMark (Mark mark) = Mark (mark + 1)


-- | BUILT-IN TYPES

-- | Fold a list of argument types and a return type into an 'Arrow' 'Type'.
--
-- The arguments are chained with '==>' from left to right, so
-- @foldArrow ([a, b], r)@ is @a ==> (b ==> r)@; with no arguments the return
-- type is returned unchanged.
foldArrow :: ([Type], Type) -> Type
foldArrow (a : as, rt) = a ==> foldArrow (as, rt)
foldArrow ([], t) = t


infixr 0 ==>


-- | Build the type constructor @"->"@ applied to @t1@ and @t2@.
-- Right-associative, so @a ==> b ==> c@ parses as @a ==> (b ==> c)@.
(==>) :: Type -> Type -> Type
(==>) t1 t2 = TConN "->" [t1, t2]


-- | The nullary type constructor @"()"@.
unit :: Type
unit = TConN "()" []


-- | The type constructor @"&"@ applied to @a@.
ref :: Type -> Type
ref a = TConN "&" [a]


-- | The type constructor @"[]"@ applied to @a@.
list :: Type -> Type
list a = TConN "[]" [a]


-- | The nullary type constructor @"Time"@.
time :: Type
time = TConN "Time" []


-- | The nullary type constructor @"Int64"@.
i64 :: Type
i64 = TConN "Int64" []


-- | The nullary type constructor @"UInt64"@.
u64 :: Type
u64 = TConN "UInt64" []


-- | The nullary type constructor @"Int32"@.
i32 :: Type
i32 = TConN "Int32" []


-- | The nullary type constructor @"UInt32"@.
u32 :: Type
u32 = TConN "UInt32" []


-- | The nullary type constructor @"Int16"@.
i16 :: Type
i16 = TConN "Int16" []


-- | The nullary type constructor @"UInt16"@.
u16 :: Type
u16 = TConN "UInt16" []


-- | The nullary type constructor @"Int8"@.
i8 :: Type
i8 = TConN "Int8" []


-- | The nullary type constructor @"UInt8"@.
u8 :: Type
u8 = TConN "UInt8" []


-- | MAKE FLEX VARIABLES
--
-- Allocate a fresh union-find point whose descriptor is a 'FlexVar' named by
-- @freshName "'t"@, with the default rank, mark and copy of 'mkDescriptor'.
mkFlexVar :: TC Variable
mkFlexVar = do
  name <- freshName "'t"
  UF.fresh $ mkDescriptor $ FlexVar name


-- | Same as 'mkFlexVar', except the variable is named by
-- @freshName "'ir_t"@.
mkIRFlexVar :: TC Variable
mkIRFlexVar = do
  name <- freshName "'ir_t"
  UF.fresh $ mkDescriptor $ FlexVar name


-- | Allocate a fresh union-find point whose descriptor is a 'RigidVar' named
-- by @freshName "~t"@, with the default rank, mark and copy of
-- 'mkDescriptor'.
mkRigidVar :: TC Variable
mkRigidVar = do
  name <- freshName "~t"
  UF.fresh $ mkDescriptor $ RigidVar name


-- | TO CANONICAL TYPE
--
-- Convert a 'Variable' to a canonical type by reading its descriptor:
-- structures are converted recursively via 'termToCanType'; both flexible
-- and rigid variables become 'Can.TVar' with their stored name.
--
-- Calls 'error' if the variable's content is 'Error'.
toCanType :: Variable -> TC Can.Type
toCanType variable = do
  (Descriptor content _ _ _) <- liftIO $ UF.get variable
  case content of
    Structure term -> termToCanType term
    FlexVar name -> return (Can.TVar name)
    RigidVar name -> return (Can.TVar name)
    -- The message names this function so the failure is traceable.
    Error -> error "cannot handle Error types in toCanType"


-- | Convert a 'FlatType' to a canonical type: each argument variable is
-- converted with 'toCanType', in order, and the results are wrapped in
-- 'Can.TCon' under the same constructor name.
termToCanType :: FlatType -> TC Can.Type
termToCanType term = case term of
  TCon1 name args -> Can.TCon name <$> traverse toCanType args


-- | TO ERROR TYPE
--
-- Convert a 'Variable' to an error-reporting type.
--
-- If the variable is already marked with 'occursMark', the result is
-- 'ET.Infinite'. Otherwise the variable is marked with 'occursMark', its
-- content is converted with 'contentToErrorType', and the mark it had on
-- entry is written back before the result is returned.
toErrorType :: Variable -> TC ET.Type
toErrorType variable = do
  descriptor <- liftIO $ UF.get variable
  let mark = _mark descriptor
  if mark == occursMark
    then return ET.Infinite
    else do
      -- Flag the variable so a recursive visit takes the branch above.
      liftIO $ UF.modify variable (\desc -> desc{_mark = occursMark})
      errType <- contentToErrorType variable (_content descriptor)
      -- Restore the mark observed on entry.
      liftIO $ UF.modify variable (\desc -> desc{_mark = mark})
      return errType


-- | Convert a 'Content' to an error-reporting type. Structures go through
-- 'termToErrorType'; the other alternatives map to the 'ET' constructor of
-- the same name. The 'Variable' argument is ignored.
contentToErrorType :: Variable -> Content -> TC ET.Type
contentToErrorType _ content = case content of
  Structure term -> termToErrorType term
  FlexVar name -> return (ET.FlexVar name)
  RigidVar name -> return (ET.RigidVar name)
  Error -> return ET.Error


-- | Convert a 'FlatType' to an error-reporting type: each argument variable
-- is converted with 'toErrorType', in order, and the results are wrapped in
-- 'ET.Type' under the same constructor name.
termToErrorType :: FlatType -> TC ET.Type
termToErrorType term = case term of
  TCon1 name args -> ET.Type name <$> traverse toErrorType args


-- | The variable identifier of a binder: the one 'I.binderToVar' yields when
-- it returns 'Just', otherwise a new one obtained from 'freshVar'.
binderToVarId :: I.Binder t -> TC Ident.VarId
binderToVarId = maybe freshVar return . I.binderToVar


-- | Extract the 'Attachment' from a carrier with 'I.extract' and unwrap its
-- annotations into a list with 'Can.unAnnotations'; the variable is returned
-- unchanged.
uncarryAttachment :: (I.Carrier c) => c Attachment -> ([I.Annotation], Variable)
uncarryAttachment carrier = first Can.unAnnotations $ I.extract carrier