module IR.Constraint.Constrain where

import Data.Bifunctor (bimap)
import qualified IR.Constraint.Canonical as Can
import qualified IR.Constraint.Constrain.Program as Prog
import IR.Constraint.Monad (TC)
import IR.Constraint.Type
import qualified IR.IR as I


-- | Generate type constraints for a canonicalized program.
--
-- The pipeline has three steps:
--
--   1. 'sprinkleVariables' pairs every annotation with a variable produced
--      by 'mkIRFlexVar'.
--   2. 'Prog.constrain' generates the constraint from that sprinkled program.
--   3. 'discardAnnotations' drops the source annotations. This leaves a
--      program annotated only with the variables that 'Prog.constrain' saw.
--
-- Returns the generated constraint paired with the variable-annotated program.
run :: I.Program Can.Annotations -> TC (Constraint, I.Program Variable)
run pAnn = do
  pSprinkled <- sprinkleVariables pAnn
  constraint <- Prog.constrain pSprinkled
  pure (constraint, discardAnnotations pSprinkled)


-- | Pair every annotation in the program's definitions with a new variable
-- obtained from 'mkIRFlexVar'.
--
-- Both the binder and the expression of each definition are traversed, in
-- that order. Every annotated node therefore receives its own variable.
--
-- The symbol table is replaced by 'I.uninitializedSymTable'. Its element
-- type mentions the annotation type, so the existing table cannot be carried
-- across the type-changing record update.
-- NOTE(review): presumably a later pass repopulates it; confirm.
sprinkleVariables
  :: I.Program Can.Annotations -> TC (I.Program (Can.Annotations, Variable))
sprinkleVariables prog = do
  sprinkledDefs <- traverse sprinkleDef (I.programDefs prog)
  pure prog{I.programDefs = sprinkledDefs, I.symTable = I.uninitializedSymTable}
 where
  -- Annotate the binder first, then the expression, of a single definition.
  sprinkleDef
    :: (Traversable f, Traversable g)
    => (f a, g a)
    -> TC (f (a, Variable), g (a, Variable))
  sprinkleDef (name, expr) = do
    name' <- traverse sprinkle name
    expr' <- traverse sprinkle expr
    pure (name', expr')

  -- Pair one annotation with a newly created variable.
  sprinkle :: a -> TC (a, Variable)
  sprinkle ann = (,) ann <$> mkIRFlexVar


-- | Strip the source annotations from a sprinkled program. Only the
-- variable attached to each node in the definitions is kept.
--
-- The symbol table is replaced by 'I.uninitializedSymTable'. Its element
-- type depends on the annotation type, so it cannot be reused across this
-- type-changing record update.
discardAnnotations
  :: I.Program (Can.Annotations, Variable) -> I.Program Variable
discardAnnotations sprinkledProg =
  sprinkledProg
    { I.programDefs = map (bimap dropAnn dropAnn) (I.programDefs sprinkledProg)
    , I.symTable = I.uninitializedSymTable
    }
 where
  -- Keep only the variable half of each (annotation, variable) pair.
  -- Used at two functors (binder and expression), hence the signature.
  dropAnn :: Functor f => f (a, Variable) -> f Variable
  dropAnn = fmap snd