module IR.Constraint.Typechecking (
  typecheckProgram,
) where

import Control.Monad.Writer (
  tell,
  when,
 )
import Data.Foldable (toList)
import Prettyprinter (
  surround,
  vsep,
 )

import qualified Common.Compiler as Compiler
import Common.Pretty (hardline, pretty)
import qualified IR.Constraint.Canonical as Can
import qualified IR.Constraint.Constrain as Constrain
import qualified IR.Constraint.Elaborate as Elaborate
import IR.Constraint.Monad (
  mkTCState,
  runTC,
 )
import IR.Constraint.PrettyPrint (
  Doc,
  printConstraint,
 )
import qualified IR.Constraint.Solve as Solve
import IR.Constraint.Type as Typ (
  toCanType,
 )
import qualified IR.IR as I
import IR.Pretty ()


-- | Type-check an annotated program as a compiler pass.
--
-- The flag selects between two modes:
--
-- * 'False': return the elaborated program, or rethrow the type error with
--   'Compiler.throwError'.
-- * 'True': hand the rendered type-checking log to 'Compiler.dump' instead of
--   returning a program. On failure, the error is appended to the log.
typecheckProgram ::
  I.Program Can.Annotations -> Bool -> Compiler.Pass (I.Program Can.Type)
typecheckProgram pAnn dumpTypes
  -- Note: Provisionally misusing exception/error by showing it and appending it to dump, open to revision.
  | dumpTypes = Compiler.dump $ either (\e -> logText ++ "\n\n" ++ show e) (const logText) result
  | otherwise = either Compiler.throwError return result
 where
  -- Single call shared by both modes; the log is only forced when dumping.
  (result, pp) = unsafeTypecheckProgram pAnn dumpTypes
  logText = show pp


-- | Run constraint generation, constraint solving and elaboration on a program.
--
-- Returns the elaborated program (or the 'Compiler.Error' raised along the
-- way) paired with the log written to the 'TC' writer.
--
-- When @prettyprint@ is 'True', the log contains the following, in order:
--
-- * the generated constraints;
-- * the program with its pre-solve type variables;
-- * a @name = type@ table showing what each type variable resolved to.
--
-- NOTE(review): the "unsafe" prefix presumably refers to 'runTC' escaping the
-- IO underneath 'TC' — confirm against IR.Constraint.Monad.
unsafeTypecheckProgram ::
  I.Program Can.Annotations -> Bool -> (Either Compiler.Error (I.Program Can.Type), Doc String)
unsafeTypecheckProgram pAnn prettyprint = runTC (mkTCState pAnn) $ do
  (constraint, pVar) <- Constrain.run pAnn

  -- TC action reading the current type of every type variable in the program
  -- IR. It is run twice (before and after solving), so it is bound as an
  -- action rather than as a result.
  let currentTypes = mapM toCanType (toList pVar)

      -- Pretty-printing separator between log sections
      hrule = hardline <> hardline <> pretty (replicate 20 '-') <> hardline <> hardline

  -- Log the pretty-printed constraints and program
  when prettyprint $ do
    -- Convert IORef Variables to printable "type variables" embedded in IR
    pIR <- pretty <$> mapM toCanType pVar
    pConstraint <- printConstraint constraint
    tell $ pConstraint <> hrule <> pIR

  -- Depends on being called before solve
  -- Gets the prettified version of the type variable names.
  -- NOTE(review): this and the post-solve traversal run even when prettyprint
  -- is False. They are kept unconditional because toCanType runs in TC and may
  -- affect its state; confirm before gating them behind the flag.
  namesDoc <- map pretty <$> currentTypes

  -- Runs constraint solver
  Solve.run constraint

  -- Depends on being called after solve
  -- Gets the prettified version of the unification results
  resolutionDoc <- map pretty <$> currentTypes

  -- Log the mapping from type variables to their resolved types
  let mappingDoc = vsep $ zipWith (surround $ pretty " = ") namesDoc resolutionDoc
  when prettyprint $ tell $ hrule <> mappingDoc

  Elaborate.run pVar