{-# LANGUAGE Rank2Types #-}

module IR.Constraint.Unify where

import qualified IR.Constraint.Error as ET
import IR.Constraint.Monad (TC)
import IR.Constraint.Type as Type
import qualified IR.Constraint.UnionFind as UF


-- UNIFY

-- | Outcome of 'unify'.
data Answer
  = -- | The two variables were unified successfully.
    Ok
  | -- | Unification failed. Carries the two sides, as rendered for error
    -- reporting, that could not be made equal.
    Err ET.Type ET.Type


-- | Unify two type variables and report whether they could be made equal.
--
-- On success the two variables are merged and 'Ok' is returned.
--
-- On failure, the steps are:
--
-- 1. Render both variables with 'Type.toErrorType'. This happens before
--    anything is modified, so the report reflects their state at the point
--    of failure.
-- 2. Union the variables under 'errorDescriptor'. Their content becomes
--    'Error', which later unifications treat as compatible with anything,
--    so the same problem is not reported again.
-- 3. Return @'Err' t1 t2@.
unify :: Variable -> Variable -> TC Answer
unify v1 v2 = case guardedUnify v1 v2 of
  Unify k -> k onSuccess $ \() -> do
    t1 <- Type.toErrorType v1
    t2 <- Type.toErrorType v2
    UF.union v1 v2 errorDescriptor
    return $ Err t1 t2


-- | Success continuation used by 'unify': report 'Ok'.
onSuccess :: () -> TC Answer
onSuccess () = return Ok


-- | Descriptor given to variables after a failed unification.
--
-- It has 'Error' content, 'noRank', 'noMark', and 'Nothing' in the final
-- @Maybe Variable@ slot (presumably the copy slot; confirm in
-- IR.Constraint.Type).
errorDescriptor :: Descriptor
errorDescriptor = Descriptor Error noRank noMark Nothing


-- CPS STYLE UNIFIER

-- | A unifier written in continuation-passing style.
--
-- A @'Unify' a@ is run by supplying two continuations, both producing a 'TC'
-- computation of the same caller-chosen result type @r@:
--
-- * the first is called with the result when unification succeeds;
-- * the second is called with @()@ when it fails.
newtype Unify a = Unify (forall r. (a -> TC r) -> (() -> TC r) -> TC r)


-- | A unifier that always fails: it ignores the success continuation and
-- immediately invokes the failure continuation.
mismatch :: Unify a
mismatch = Unify $ \_ err -> err ()


-- UNIFICATION HELPERS

-- | The two variables being unified, each paired with the descriptor read
-- for it just before dispatching to 'actuallyUnify'.
data Context = Context
  { _first :: Variable
  -- ^ Left-hand variable.
  , _firstDesc :: Descriptor
  -- ^ Descriptor of '_first' at the time the context was built.
  , _second :: Variable
  -- ^ Right-hand variable.
  , _secondDesc :: Descriptor
  -- ^ Descriptor of '_second' at the time the context was built.
  }


-- MERGE

-- | Union the two context variables into one equivalence class. This always
-- succeeds.
--
-- The new descriptor has:
--
-- * the given @content@;
-- * the smaller of the two ranks (presumably so the merged variable is not
--   generalized too early; confirm against the generalization code);
-- * 'noMark';
-- * 'Nothing' in the final slot.
merge :: Context -> Content -> Unify ()
merge
  (Context var1 (Descriptor _ rank1 _ _) var2 (Descriptor _ rank2 _ _))
  content =
    Unify $ \ok _ ->
      ok
        =<< UF.union
          var1
          var2
          (Descriptor content (min rank1 rank2) noMark Nothing)


-- GUARDED UNIFY

-- | Unify two variables.
--
-- If they are already in the same equivalence class, succeed immediately
-- without reading or touching their descriptors. Otherwise read both
-- descriptors and hand off to 'actuallyUnify'.
guardedUnify :: Variable -> Variable -> Unify ()
guardedUnify left right = Unify $ \ok err -> do
  equivalent <- UF.equivalent left right
  if equivalent
    then ok ()
    else do
      leftDesc <- UF.get left
      rightDesc <- UF.get right
      case actuallyUnify (Context left leftDesc right rightDesc) of
        Unify k -> k ok err


-- | Unify a pair of nested variables, such as corresponding type-constructor
-- arguments in 'unifyArgs'. Currently identical to 'guardedUnify'.
subUnify :: Variable -> Variable -> Unify ()
subUnify = guardedUnify


-- | Dispatch on the content of the first variable to the matching
-- specialised unifier. The caller must already have ruled out the case where
-- the two variables are equivalent; 'guardedUnify' does this.
actuallyUnify :: Context -> Unify ()
actuallyUnify
  context@( Context
              _
              (Descriptor firstContent _ _ _)
              _
              (Descriptor secondContent _ _ _)
            ) =
    case firstContent of
      FlexVar _ -> unifyFlex context firstContent secondContent
      RigidVar _ -> unifyRigid context firstContent secondContent
      Structure flatType ->
        unifyStructure context flatType firstContent secondContent
      Error ->
        -- If there was an error, just pretend it is okay. This lets us avoid
        -- "cascading" errors where one problem manifests as multiple messages.
        merge context Error


-- UNIFY FLEXIBLE VARIABLES

-- | Unify a flexible variable with anything. This always succeeds, and the
-- merged variable takes on the other side's content.
--
-- Each constructor is matched explicitly, even though every branch merges
-- with @otherContent@, so that a new 'Content' constructor triggers an
-- incomplete-pattern warning here.
unifyFlex :: Context -> Content -> Content -> Unify ()
unifyFlex context _ otherContent = case otherContent of
  Error -> merge context Error
  FlexVar _ -> merge context otherContent
  RigidVar _ -> merge context otherContent
  Structure _ -> merge context otherContent


-- UNIFY RIGID VARIABLES

-- | Unify a rigid variable, whose own content is @content@, with the other
-- side.
--
-- * A flexible variable is absorbed: the merge keeps the rigid content.
-- * Another rigid variable or a structure is a mismatch. Equivalent
--   variables never reach this point because 'guardedUnify' handles them
--   first.
-- * 'Error' is merged silently to avoid cascading errors.
unifyRigid :: Context -> Content -> Content -> Unify ()
unifyRigid context content otherContent = case otherContent of
  FlexVar _ -> merge context content
  RigidVar _ -> mismatch
  Structure _ -> mismatch
  Error -> merge context Error


-- UNIFY STRUCTURES

-- | Unify a structure with the other side. @flatType@ is the structure and
-- @content@ is its full content.
--
-- * A flexible variable is absorbed: the merge keeps this structure.
-- * A rigid variable is a mismatch.
-- * Another structure unifies only when both are 'TCon1' applications with
--   the same constructor name. Their argument lists are then unified
--   pairwise with 'unifyArgs', and the two variables are merged only after
--   every argument pair has succeeded.
-- * 'Error' is merged silently to avoid cascading errors.
unifyStructure :: Context -> FlatType -> Content -> Content -> Unify ()
unifyStructure context flatType content otherContent = case otherContent of
  FlexVar _ -> merge context content
  RigidVar _ -> mismatch
  Structure otherFlatType -> case (flatType, otherFlatType) of
    (TCon1 name args, TCon1 otherName otherArgs)
      | name == otherName ->
          Unify $ \ok err ->
            -- Success continuation for the arguments: merge the two
            -- variables, then continue with the caller's continuations.
            -- No type signature: @r@ is not in scope without
            -- ScopedTypeVariables.
            let ok1 () = case merge context otherContent of
                  Unify k -> k ok err
             in unifyArgs context args otherArgs ok1 err
    _ -> mismatch
  Error -> merge context Error


-- UNIFY ARGS

-- | Unify two argument lists pairwise, in continuation-passing style.
--
-- * If every pair unifies and the lists have the same length, the result
--   is @ok ()@.
-- * Once a pair fails, the remaining pairs are still unified, but with
--   @err@ as both continuations, so the overall result is @err ()@.
--   Presumably this is done so that as much as possible gets unified; this
--   is not confirmed here.
-- * If the lists have different lengths, the result is @err ()@.
unifyArgs
  :: Context -> [Variable] -> [Variable] -> (() -> TC r) -> (() -> TC r) -> TC r
unifyArgs context (arg1 : others1) (arg2 : others2) ok err =
  case subUnify arg1 arg2 of
    Unify k ->
      k
        (\() -> unifyArgs context others1 others2 ok err)
        (\() -> unifyArgs context others1 others2 err err)
unifyArgs _ [] [] ok _ = ok ()
unifyArgs _ _ _ _ err = err ()