module IR.Constraint.Solve where

import qualified Common.Identifiers as Ident
import Control.Monad (
  foldM,
  forM_,
 )
import Data.Map.Strict ((!))
import qualified Data.Map.Strict as Map
import qualified Data.Vector as Vector
import qualified Data.Vector.Mutable as MVector
import qualified IR.Constraint.Canonical as Can
import qualified IR.Constraint.Error as ET
import IR.Constraint.Monad (
  TC,
  throwError,
 )
import qualified IR.Constraint.Occurs as Occurs
import IR.Constraint.Type as Type
import qualified IR.Constraint.Unify as Unify
import qualified IR.Constraint.UnionFind as UF


-- | RUN SOLVER
--
-- Solve a whole constraint tree, starting from an empty environment at
-- 'outermostRank'. Errors are accumulated while solving rather than raised
-- immediately; if any were collected, each is rendered with
-- 'ET.toErrorString', the results are concatenated, and the whole string is
-- raised through 'throwError'.
run :: Constraint -> TC ()
run constraint = do
  -- Start with 8 rank pools; the CLet case of 'solve' grows them on demand.
  pools <- MVector.replicate 8 []
  finalState <- solve Map.empty outermostRank pools emptyState constraint
  -- NOTE(review): 'addError' conses onto '_errors', so this list is
  -- most-recent-first; confirm whether reporting order should be reversed.
  case _errors finalState of
    [] -> return ()
    errs -> throwError (concatMap ET.toErrorString errs)


-- | Initial solver state: an empty environment, no errors, and the first
-- mark after 'noMark' as the next fresh mark.
emptyState :: State
emptyState =
  State
    { _env = Map.empty
    , _mark = nextMark noMark
    , _errors = []
    }


-- | SOLVER
--
-- Maps each identifier in scope to the type variable standing for it.
type Env = Map.Map Ident.Identifier Variable


-- | One bucket of variables per rank: index @r@ holds the variables that
-- were registered at rank @r@.
type Pools = MVector.IOVector [Variable]


-- | State threaded through 'solve'.
data State = State
  { _env :: Env
  -- ^ environment captured by 'CSaveTheEnvironment'
  , _mark :: Mark
  -- ^ next fresh mark; a CLet consumes marks from here for 'generalize'
  , _errors :: [ET.Error]
  -- ^ accumulated errors, most recent first (see 'addError')
  }


-- | Solve a single constraint, threading the solver 'State' through.
--
-- * @env@ maps in-scope identifiers to their type variables.
-- * @rank@ is the current let-nesting level; variables created here are
--   registered in @pools@ at index @rank@.
-- * @pools@ holds, per rank, the variables introduced at that rank so that
--   'generalize' can revisit them when the rank is popped.
--
-- Unification failures do not abort solving: they are recorded in the
-- returned state with 'addError' and solving continues.
solve :: Env -> Int -> Pools -> State -> Constraint -> TC State
solve env rank pools state constraint =
  case constraint of
    CTrue ->
      return state
    CSaveTheEnvironment ->
      return (state{_env = env})
    CEqual tipe expectation -> do
      actual <- typeToVariable rank pools tipe
      expected <- typeToVariable rank pools expectation
      unifyOrReport ET.BadExpr actual expected
    CLocal name expectation -> do
      -- The binding's variable goes through 'makeCopy' (defined elsewhere;
      -- presumably instantiates a generalized type). '!' fails if @name@
      -- is not in @env@.
      actual <- makeCopy rank pools (env ! name)
      expected <- typeToVariable rank pools expectation
      unifyOrReport ET.BadExpr actual expected
    CForeign (Can.Forall freeVars srcType) expectation -> do
      actual <- schemeToVariable rank pools freeVars srcType
      expected <- typeToVariable rank pools expectation
      unifyOrReport ET.BadExpr actual expected
    CPattern tipe expectation -> do
      actual <- typeToVariable rank pools tipe
      expected <- typeToVariable rank pools expectation
      unifyOrReport ET.BadPattern actual expected
    CAnd constraints ->
      foldM (solve env rank pools) state constraints
    CLet [] flexs _ headerCon CTrue -> do
      -- No rigids and no body: just register the flex variables here.
      introduce rank pools flexs
      solve env rank pools state headerCon
    CLet [] [] header headerCon subCon -> do
      -- Monomorphic let: everything stays at the current rank.
      state1 <- solve env rank pools state headerCon
      locals <- traverse (typeToVariable rank pools) header
      let newEnv = Map.union locals env
      state2 <- solve newEnv rank pools state1 subCon
      foldM occurs state2 $ Map.toList locals
    CLet rigids flexs header headerCon subCon -> do
      -- Work one rank deeper so the header's variables live in their own
      -- pool and can be generalized when that pool is popped.
      let nextRank = rank + 1
          poolsLength = MVector.length pools
      -- NOTE(review): 'MVector.grow' returns a new vector, and callers keep
      -- their old 'pools'. Variables registered into 'nextPools' at ranks
      -- <= rank after a grow therefore appear to be invisible to an
      -- enclosing let's 'generalize'. This only arises once let nesting
      -- exceeds the current pool count.
      nextPools <-
        if nextRank < poolsLength
          then return pools
          else MVector.grow pools poolsLength

      -- Move the let-bound rigid and flex variables to nextRank and make
      -- them the contents of that rank's pool.
      let vars = rigids ++ flexs
      forM_ vars $ \var ->
        UF.modify var $ \(Descriptor content _ varMark copy) ->
          Descriptor content nextRank varMark copy
      MVector.write nextPools nextRank vars

      -- Solve the header in the new pool.
      locals <- traverse (typeToVariable nextRank nextPools) header
      State savedEnv mark errors <-
        solve env nextRank nextPools state headerCon

      let youngMark = mark
          visitMark = nextMark youngMark
          finalMark = nextMark visitMark

      -- Pop the pool: generalize what it holds, then clear it.
      generalize youngMark visitMark nextRank nextPools
      MVector.write nextPools nextRank []

      -- Every rigid variable must now be generic ('isGeneric' aborts
      -- otherwise).
      mapM_ isGeneric rigids

      -- Solve the body with the header's bindings in scope.
      let newEnv = Map.union locals env
          tempState = State savedEnv finalMark errors
      newState <- solve newEnv rank nextPools tempState subCon

      foldM occurs newState (Map.toList locals)
  where
    -- Unify two variables. On failure, build an error with @mkError@ from
    -- the two mismatching types and add it to @state@. On success, return
    -- @state@ unchanged: 'Unify.Ok' carries no fresh variables, so there is
    -- nothing to introduce into the pools.
    unifyOrReport mkError actual expected = do
      answer <- Unify.unify actual expected
      case answer of
        Unify.Ok ->
          return state
        Unify.Err actualType expectedType ->
          return $ addError state $ mkError actualType expectedType


-- | Assert that a variable has been generalized, i.e. that its rank is
-- 'noRank' ('generalize' assigns 'noRank' to the variables it generalizes).
-- Any other rank indicates a compiler bug and aborts via 'error'.
isGeneric :: Variable -> TC ()
isGeneric var = do
  descriptor <- UF.get var
  case descriptor of
    Descriptor _ rank _ _
      | rank /= noRank ->
          error "Compiler bug: unification variable should be generic"
      | otherwise ->
          return ()


-- | ERROR HELPERS
--
-- Record an error in the solver state. The error is prepended, so
-- '_errors' is kept most-recent-first; the environment and mark are left
-- unchanged.
addError :: State -> ET.Error -> State
addError state err = state{_errors = err : _errors state}


-- | OCCURS CHECK
--
-- Check whether the variable bound to @name@ occurs inside its own
-- structure (an infinite type). If it does not, the state is returned
-- unchanged. If it does, the variable's error rendering is captured first.
-- Its content is then overwritten with 'Error', keeping rank, mark and copy
-- (presumably so later traversals stop there — confirm). Finally an
-- 'ET.InfiniteType' error is recorded.
occurs :: State -> (Ident.Identifier, Variable) -> TC State
occurs state (name, variable) = do
  hasOccurred <- Occurs.occurs variable
  if not hasOccurred
    then return state
    else do
      -- Render the type before its content is replaced below.
      errorType <- Type.toErrorType variable
      Descriptor _ rank mark copy <- UF.get variable
      UF.set variable (Descriptor Error rank mark copy)
      return (addError state (ET.InfiniteType name errorType))


-- | GENERALIZE
--
-- Every variable has rank less than or equal to the maxRank of the pool.
-- This sorts variables into the young and old pools accordingly.
--
-- The pool at @youngRank@ is bucketed by rank and its ranks are adjusted.
-- Each non-redundant variable that ends up below @youngRank@ is pushed onto
-- the pool for its rank. Each one that stays at @youngRank@ is generalized
-- by setting its rank to 'noRank'.
generalize :: Mark -> Mark -> Int -> Pools -> TC ()
generalize youngMark visitMark youngRank pools = do
  youngVars <- MVector.read pools youngRank
  rankTable <- poolToRankTable youngMark youngRank youngVars

  -- Get the ranks right for each entry. Start at low ranks so that we only
  -- have to pass over the information once.
  Vector.imapM_
    (\groupRank vars -> mapM_ (adjustRank youngMark visitMark groupRank) vars)
    rankTable

  -- Buckets below youngRank: register each non-redundant variable in the
  -- old pool matching its (possibly adjusted) rank.
  Vector.forM_ (Vector.unsafeInit rankTable) $
    mapM_ $ \var -> unlessRedundant var $ do
      Descriptor _ rank _ _ <- UF.get var
      MVector.modify pools (var :) rank

  -- The youngRank bucket: variables whose rank dropped go to an old pool;
  -- the rest are generalized.
  forM_ (Vector.unsafeLast rankTable) $ \var -> unlessRedundant var $ do
    Descriptor content rank mark copy <- UF.get var
    if rank < youngRank
      then MVector.modify pools (var :) rank
      else UF.set var (Descriptor content noRank mark copy)
  where
    -- Run @action@ only if @var@ is not redundant according to UF.redundant.
    unlessRedundant var action = do
      isRedundant <- UF.redundant var
      if isRedundant then return () else action


-- | Bucket the young pool's variables by their current rank. The result has
-- one entry per rank @0 .. youngRank@. Every variable is also tagged with
-- @youngMark@ (its content, rank and copy are kept).
poolToRankTable :: Mark -> Int -> [Variable] -> TC (Vector.Vector [Variable])
poolToRankTable youngMark youngRank youngInhabitants = do
  buckets <- MVector.replicate (youngRank + 1) []

  mapM_
    ( \var -> do
        Descriptor content rank _ copy <- UF.get var
        UF.set var (Descriptor content rank youngMark copy)
        MVector.modify buckets (var :) rank
    )
    youngInhabitants

  -- 'buckets' is not touched after this point, so freezing in place is safe.
  Vector.unsafeFreeze buckets


-- | ADJUST RANK
--
-- Adjust variable ranks such that ranks never increase as you move deeper.
-- This way the outermost rank is representative of the entire structure.
--
-- * A variable marked @youngMark@ (not yet visited) is marked @visitMark@
--   and gets the rank computed from its content by 'adjustRankContent'.
-- * A variable already marked @visitMark@ keeps its rank.
-- * Any other variable is clamped to @min groupRank rank@ and marked
--   @visitMark@.
--
-- The variable's resulting rank is returned.
adjustRank :: Mark -> Mark -> Int -> Variable -> TC Int
adjustRank youngMark visitMark groupRank var = do
  Descriptor content rank mark copy <- UF.get var
  let visitWith newRank = UF.set var (Descriptor content newRank visitMark copy)
  case () of
    _
      | mark == youngMark -> do
          -- Set the variable as marked first because it may be cyclic.
          visitWith rank
          maxRank <- adjustRankContent youngMark visitMark groupRank content
          visitWith maxRank
          return maxRank
      | mark == visitMark ->
          return rank
      | otherwise -> do
          let minRank = min groupRank rank
          -- TODO how can minRank ever be groupRank?
          visitWith minRank
          return minRank


-- | Compute the rank of a variable's content. Variables and 'Error' take
-- @groupRank@. A type constructor application takes the maximum of
-- 'outermostRank' and its arguments' adjusted ranks.
adjustRankContent :: Mark -> Mark -> Int -> Content -> TC Int
adjustRankContent youngMark visitMark groupRank content =
  case content of
    FlexVar _ -> return groupRank
    RigidVar _ -> return groupRank
    Error -> return groupRank
    Structure flatType ->
      case flatType of
        TCon1 _ args -> foldM raiseToArg outermostRank args
  where
    -- Adjust one argument and keep the larger of it and the running rank.
    raiseToArg acc arg =
      max acc <$> adjustRank youngMark visitMark groupRank arg


-- | REGISTER VARIABLES
--
-- Prepend @variables@ to the pool at @rank@ and set each variable's rank to
-- @rank@, keeping its content, mark and copy.
introduce :: Int -> Pools -> [Variable] -> TC ()
introduce rank pools variables = do
  MVector.modify pools (variables ++) rank
  mapM_ (\var -> UF.modify var withRank) variables
  where
    withRank (Descriptor content _ mark copy) = Descriptor content rank mark copy


-- | TYPE TO VARIABLE
--
-- Convert a constraint 'Type' into a variable at @rank@. Any structure
-- created along the way is registered in @pools@. Conversion starts with an
-- empty alias map.
typeToVariable :: Int -> Pools -> Type -> TC Variable
typeToVariable rank pools tipe = typeToVar rank pools Map.empty tipe


-- | Worker for 'typeToVariable'. A 'TVarN' is returned as-is. A 'TConN'
-- converts its arguments left to right and then registers a fresh
-- 'TCon1' structure at @rank@.
--
-- NOTE(review): @aliasDict@ is only threaded through the recursion; no case
-- here reads it.
typeToVar
  :: Int -> Pools -> Map.Map Ident.TVarId Variable -> Type -> TC Variable
typeToVar rank pools aliasDict tipe =
  case tipe of
    TVarN var ->
      return var
    TConN name args -> do
      argVars <- mapM (typeToVar rank pools aliasDict) args
      register rank pools (Structure (TCon1 name argVars))


-- | Create a fresh variable with the given content at @rank@ (unmarked, no
-- copy). Push it onto the pool for that rank and return it.
register :: Int -> Pools -> Content -> TC Variable
register rank pools content = do
  let descriptor = Descriptor content rank noMark Nothing
  var <- UF.fresh descriptor
  MVector.modify pools (var :) rank
  return var


-- SOURCE TYPE TO VARIABLE

-- | Instantiate a source type scheme at @rank@.
--
-- Steps:
--
-- 1. Create one fresh flexible variable for every free type variable of the
--    scheme.
-- 2. Add all of them to the pool for @rank@.
-- 3. Convert @srcType@ with 'schemeToVar', which resolves each type variable
--    to its fresh counterpart.
schemeToVariable
  :: Int -> Pools -> Map.Map Ident.TVarId () -> Can.Type -> TC Variable
schemeToVariable rank pools freeVars srcType = do
  flexVars <- Map.traverseWithKey makeFlexVar freeVars
  MVector.modify pools (Map.elems flexVars ++) rank
  schemeToVar rank pools flexVars srcType
  where
    makeFlexVar name _ =
      UF.fresh (Descriptor (FlexVar name) rank noMark Nothing)


-- | Convert a canonical source type into a solver variable.
--
-- Type variables are looked up in @flexVars@; 'schemeToVariable' builds that
-- map from the scheme's free variables. Each type constructor node becomes a
-- fresh structure variable created by 'register' at @rank@, after its
-- arguments are converted left to right.
schemeToVar
  :: Int -> Pools -> Map.Map Ident.TVarId Variable -> Can.Type -> TC Variable
schemeToVar rank pools flexVars srcType =
  case srcType of
    Can.TVar name ->
      case Map.lookup name flexVars of
        Just var -> return var
        -- Only reachable if the type mentions a variable the scheme did not
        -- declare as free. Report it through the checker's error channel
        -- rather than crashing with Map.!'s generic message.
        Nothing ->
          throwError
            "schemeToVar: type variable is not among the scheme's free variables"
    Can.TCon name args -> do
      argVars <- traverse go args
      register rank pools (Structure (TCon1 name argVars))
  where
    go = schemeToVar rank pools flexVars


-- COPY

-- | Copy @var@ at rank @rank@ (see 'makeCopyHelp'), then call 'restore' on
-- the original. 'restore' clears the copy links and marks that
-- 'makeCopyHelp' left on the originals it copied.
makeCopy :: Int -> Pools -> Variable -> TC Variable
makeCopy rank pools var = do
  copy <- makeCopyHelp rank pools var
  restore var
  return copy


-- | Worker for 'makeCopy'.
--
-- Cases:
--
-- * If the variable already has a copy link, that copy is returned.
-- * If its rank is not 'noRank', the variable itself is returned and shared
--   rather than copied.
-- * Otherwise a fresh variable is created at @maxRank@ and pushed onto that
--   rank's pool. The original is then linked to it, and the content is copied
--   recursively.
--
-- Rigid variables become flexible variables in the copy.
makeCopyHelp :: Int -> Pools -> Variable -> TC Variable
makeCopyHelp maxRank pools variable = do
  (Descriptor content rank _ maybeCopy) <- UF.get variable
  case maybeCopy of
    Just copy -> return copy
    Nothing
      | rank /= noRank -> return variable
      | otherwise -> do
          let makeDescriptor c = Descriptor c maxRank noMark Nothing
          copy <- UF.fresh (makeDescriptor content)
          MVector.modify pools (copy :) maxRank

          -- Link the original variable to its copy *before* recursing. A
          -- later visit then returns the existing copy, so the variable is
          -- never copied twice and the recursion cannot loop.
          UF.set variable (Descriptor content rank noMark (Just copy))

          -- Copy the content. The link above guarantees this variable is not
          -- crawled again.
          case content of
            Structure term -> do
              newTerm <- traverseFlatType (makeCopyHelp maxRank pools) term
              UF.set copy (makeDescriptor (Structure newTerm))
            FlexVar _ -> return ()
            RigidVar name -> UF.set copy (makeDescriptor (FlexVar name))
            Error -> return ()
          return copy


-- RESTORE

-- | Undo the copy link that 'makeCopyHelp' left on a variable.
--
-- A variable without a link is left untouched, which also ends the
-- recursion. A linked variable gets its copy field cleared, its mark reset to
-- 'noMark' and its rank set to 'noRank'. 'makeCopyHelp' only links variables
-- whose rank was 'noRank', so this restores the original rank. The variable's
-- content is then restored.
restore :: Variable -> TC ()
restore variable = do
  (Descriptor content _ _ maybeCopy) <- UF.get variable
  case maybeCopy of
    Nothing -> return ()
    Just _ -> do
      UF.set variable (Descriptor content noRank noMark Nothing)
      restoreContent content


-- | Call 'restore' on every variable directly inside a structure. Variable
-- and error contents hold no sub-variables, so there is nothing to do for
-- them. Every constructor is listed explicitly so that a new 'Content' case
-- triggers an incomplete-pattern warning.
restoreContent :: Content -> TC ()
restoreContent (Structure term) =
  case term of
    TCon1 _ args -> mapM_ restore args
restoreContent (FlexVar _) = return ()
restoreContent (RigidVar _) = return ()
restoreContent Error = return ()


-- TRAVERSE FLAT TYPE

-- | Run an effectful action on each variable directly inside a 'FlatType',
-- left to right, and rebuild the node from the results.
traverseFlatType :: (Variable -> TC Variable) -> FlatType -> TC FlatType
traverseFlatType f flatType =
  case flatType of
    TCon1 name args -> TCon1 name <$> traverse f args