{-# LANGUAGE OverloadedStrings #-}

-- | Parse OpRegion nodes inside of an AST 'Program'.
module Front.ParseOperators (
  parseOperators,
  A.Fixity (..),
) where

import qualified Common.Compiler as Compiler
import Common.Identifiers (Identifier)

import qualified Front.Ast as A

import Data.Bifunctor (Bifunctor (..))
import qualified Data.Map.Strict as Map


-- | Shift-reduce parse stack used while resolving an operator region.
--
-- Each frame records a left operand together with the operator that follows
-- it; the operator is still waiting for its right-hand operand.
data Stack
  = -- | Bottom of the stack: no pending operators.
    BOS
  | -- | Remaining stack, pending left operand, and its pending operator.
    Stack Stack A.Expr Identifier


-- FIXME: These should be defined and included in the standard library

-- | Default fixity of operators.
--
-- The precedence levels are:
--
-- * comparisons at level 4 (loosest);
-- * additive operators at level 6;
-- * multiplicative operators and exponentiation at level 8.
--
-- Every operator is left-associative except @^@, which is right-associative.
defaultOps :: [A.Fixity]
defaultOps =
  [ A.Infixl 4 "=="
  , A.Infixl 4 "!="
  , A.Infixl 4 "<="
  , A.Infixl 4 ">="
  , A.Infixl 4 "<"
  , A.Infixl 4 ">"
  , A.Infixl 6 "+"
  , A.Infixl 6 "-"
  , A.Infixl 8 "*"
  , A.Infixl 8 "/"
  , A.Infixl 8 "%"
  , A.Infixr 8 "^"
  ]


-- | Parse OpRegion nodes inside of an AST 'Program'.
--
-- The body of every 'A.TopDef' declaration has its operator regions resolved
-- by 'parseExprOps', using the 'defaultOps' fixity table. All other top-level
-- declarations are passed through unchanged. This pass never reports an error.
parseOperators :: A.Program -> Compiler.Pass A.Program
parseOperators (A.Program decls) = return $ A.Program $ map (parseTop ops) decls
 where
  -- Fixity table used for the whole program.
  ops :: [A.Fixity]
  ops = defaultOps

  -- Resolve operators inside a definition; leave other declarations alone.
  parseTop :: [A.Fixity] -> A.TopDef -> A.TopDef
  parseTop fs (A.TopDef d) = A.TopDef $ parseDef fs d
  parseTop _ t = t

  -- Resolve operators in the right-hand side of a definition.
  parseDef :: [A.Fixity] -> A.Definition -> A.Definition
  parseDef fs (A.DefFn v bs t e) = A.DefFn v bs t $ parseExprOps fs e
  parseDef fs (A.DefPat b e) = A.DefPat b $ parseExprOps fs e


{- | Remove the OpRegion constructs in the AST by parsing the operators
   according to the given Fixity specifications.

   The operands of a region are resolved first, so nested regions are parsed
   before the region that contains them. Each region is then parsed with a
   shift-reduce loop. Every binary use @a op b@ becomes
   @A.Apply (A.Apply (A.Id op) a) b@.

   Fixity rules:

   * Operators missing from the fixity list behave like @A.Infixl 9@.
   * If an operator is listed more than once, the last entry wins.
-}
parseExprOps :: [A.Fixity] -> A.Expr -> A.Expr
parseExprOps fixity = rw
 where
  -- Resolve every operator region in an expression. The operands of a region
  -- are rewritten before the region itself is parsed.
  rw :: A.Expr -> A.Expr
  rw (A.OpRegion e r) = step BOS (rw e) (rewriteRegion rw r)
  rw e = rewrite rw e

  -- Binding powers for operators with no fixity entry. These are the values
  -- that fixToPair produces for @A.Infixl 9@.
  defaultPrec :: (Int, Int)
  defaultPrec = (18, 17)

  -- Maps an operator to its (left binding power, right binding power).
  opMap :: Map.Map Identifier (Int, Int)
  opMap = Map.fromList $ map fixToPair fixity

  -- Left binding powers are always even and right binding powers always odd,
  -- so shouldShift never sees a tie.
  --
  -- For two operators at the same level p:
  --   * infixl compares 2p < 2p - 1, which is false, so it reduces
  --     (left-associative);
  --   * infixr compares 2p < 2p + 1, which is true, so it shifts
  --     (right-associative).
  fixToPair :: A.Fixity -> (Identifier, (Int, Int))
  fixToPair (A.Infixl prec op) = (op, (prec * 2, prec * 2 - 1))
  fixToPair (A.Infixr prec op) = (op, (prec * 2, prec * 2 + 1))

  -- Shift-reduce driver.
  -- Arguments: the stack of pending @operand operator@ frames, the current
  -- operand, and the rest of the operator region.
  step :: Stack -> A.Expr -> A.OpRegion -> A.Expr
  step BOS e A.EOR = e
  step BOS e1 (A.NextOp op e2 ts) = shift BOS e1 op e2 ts
  step (Stack s e1 op) e2 A.EOR = reduce s e1 op e2 A.EOR
  step s0@(Stack s1 e1 op1) e2 t0@(A.NextOp op2 e3 t1)
    | op1 `shouldShift` op2 = shift s0 e2 op2 e3 t1
    | otherwise = reduce s1 e1 op1 e2 t0

  shift, reduce :: Stack -> A.Expr -> Identifier -> A.Expr -> A.OpRegion -> A.Expr
  -- Push @e1 op@ onto the stack and continue with e2 as the current operand.
  shift s e1 op e2 ts = step (Stack s e1 op) e2 ts
  -- Combine @e1 op e2@ into a curried application and continue with the rest
  -- of the stack.
  reduce s e1 op e2 ts = step s (A.Apply (A.Apply (A.Id op) e1) e2) ts

  -- True when the pending operator opl (top of the stack) must wait for the
  -- incoming operator opr to be applied first.
  shouldShift :: Identifier -> Identifier -> Bool
  shouldShift opl opr = pl < pr
   where
    pl :: Int
    pl = fst $ Map.findWithDefault defaultPrec opl opMap
    pr :: Int
    pr = snd $ Map.findWithDefault defaultPrec opr opMap

  -- Apply f to the immediate sub-expressions of an expression. Patterns,
  -- types and identifiers are left as they are. Constructors not matched
  -- below are returned unchanged.
  rewrite :: (A.Expr -> A.Expr) -> A.Expr -> A.Expr
  rewrite f (A.Apply e1 e2) = A.Apply (f e1) (f e2)
  rewrite f (A.OpRegion e r) = A.OpRegion (f e) (rewriteRegion f r)
  rewrite f (A.Let d b) = A.Let (map g d) $ f b
   where
    g :: A.Definition -> A.Definition
    g (A.DefFn v bs t e) = A.DefFn v bs t $ f e
    g (A.DefPat p e) = A.DefPat p $ f e
  rewrite f (A.While e1 e2) = A.While (f e1) (f e2)
  rewrite f (A.Loop e) = A.Loop (f e)
  rewrite f (A.Par e) = A.Par $ map f e
  rewrite f (A.Wait e) = A.Wait $ map f e
  rewrite f (A.IfElse e1 e2 e3) = A.IfElse (f e1) (f e2) (f e3)
  -- NOTE(review): the `p` field of After/Assign is passed through without
  -- being rewritten, so an OpRegion inside it would survive this pass.
  -- Confirm that this is intended.
  rewrite f (A.After e1 p e2) = A.After (f e1) p (f e2)
  rewrite f (A.Assign p e) = A.Assign p (f e)
  rewrite f (A.Constraint e t) = A.Constraint (f e) t
  rewrite f (A.Seq e1 e2) = A.Seq (f e1) (f e2)
  rewrite f (A.Lambda ps b) = A.Lambda ps (f b)
  rewrite f (A.Match s as) = A.Match (f s) $ map (second f) as
  rewrite _ e = e

  -- Apply f to every operand of an operator region, keeping the operators.
  rewriteRegion :: (A.Expr -> A.Expr) -> A.OpRegion -> A.OpRegion
  rewriteRegion _ A.EOR = A.EOR
  rewriteRegion f (A.NextOp op e r) = A.NextOp op (f e) (rewriteRegion f r)