module Front.DesugarPatTup (desugarPatTup) where

import qualified Common.Compiler as Compiler
import Common.Identifiers
import Data.Bifunctor (first)
import Data.Generics (Data (..), everywhere, mkT)
import qualified Front.Ast as A


-- | Compiler pass entry point: rewrites the program with 'desugarSubst'.
--
-- The rewrite itself is pure; its result is just lifted into the
-- 'Compiler.Pass' monad with 'return', so this pass never fails.
desugarPatTup :: A.Program -> Compiler.Pass A.Program
desugarPatTup program = return (desugarSubst program)


-- | Run both generic rewrites over a value.
--
-- Definitions are rewritten first ('desugarDef'), then expressions
-- ('desugarExpr') are rewritten on the result.
desugarSubst :: (Data a) => a -> a
desugarSubst node = desugarExpr (desugarDef node)


-- | Apply 'substDef' to every 'A.Definition' reachable inside the given value.
--
-- The traversal is done by 'everywhere', which rewrites children before
-- their parents; values of any other type are left untouched by 'mkT'.
desugarDef :: (Data a) => a -> a
desugarDef = everywhere (mkT substDef)


-- | Apply 'substExpr' to every 'A.Expr' reachable inside the given value.
--
-- The traversal is done by 'everywhere', which rewrites children before
-- their parents; values of any other type are left untouched by 'mkT'.
desugarExpr :: (Data a) => a -> a
desugarExpr = everywhere (mkT substExpr)


-- | Desugar top-level tuple patterns in a function definition's parameters.
--
-- Every parameter that is an 'A.PatTup' is replaced by a fresh identifier
-- pattern @_temp_id_<n>@, where @n@ is the parameter's 0-based position. The
-- function body is then wrapped in an 'A.Match' on that identifier, whose
-- single arm is the original tuple pattern. Non-tuple parameters are kept
-- unchanged. Any definition other than 'A.DefFn' is returned as is.
--
-- With several tuple parameters the matches nest, and the match for the
-- last (right-most) tuple parameter is the outermost.
substDef :: A.Definition -> A.Definition
substDef (A.DefFn i ps t e) = A.DefFn i pats t rese
 where
  (pats, rese) = desugarPat ps e (0 :: Int)

  -- Walk the parameters left to right. Two values are threaded through:
  -- @ex@ is the body built so far, possibly already wrapped in matches for
  -- earlier tuple parameters, and @n@ is the position used to name temps.
  desugarPat [] ex _ = ([], ex)
  desugarPat (p : rps) ex n =
    case p of
      A.PatTup _ ->
        let tmp = "_temp_id_" ++ show n
            -- Destructure the fresh parameter around everything built so far.
            wrapped = A.Match (A.Id (Identifier tmp)) [(p, ex)]
         in first (A.PatId (fromString tmp) :) $ desugarPat rps wrapped (n + 1)
      -- Keep threading @ex@ rather than restarting from the original body
      -- @e@. Otherwise matches added for earlier tuple parameters are lost.
      _ -> first (p :) $ desugarPat rps ex (n + 1)
substDef d = d


-- | Desugar tuple-pattern bindings inside a @let@ expression.
--
-- Every binding of the form @A.DefPat (A.PatTup ..) rhs@ is turned into
-- @A.DefPat (A.PatId _temp_id_<n>) rhs@, where @n@ is the binding's 0-based
-- position. The @let@ body is then wrapped in an 'A.Match' on
-- @_temp_id_<n>@, whose single arm is the original tuple pattern. All other
-- bindings are kept unchanged. Any expression other than 'A.Let' is returned
-- as is.
--
-- Only the @let@ body is wrapped. Names bound by a tuple pattern are
-- therefore in scope in the body, but not in sibling bindings of the same
-- @let@.
substExpr :: A.Expr -> A.Expr
substExpr (A.Let defs e) = A.Let ndefs rese
 where
  (ndefs, rese) = desugarPat defs e (0 :: Int)

  -- Walk the bindings left to right. Two values are threaded through:
  -- @ex@ is the body built so far, possibly already wrapped in matches for
  -- earlier tuple bindings, and @n@ is the position used to name temps.
  desugarPat [] ex _ = ([], ex)
  desugarPat (binding : rdefs) ex n =
    case binding of
      A.DefPat p@(A.PatTup _) defe ->
        let tmp = "_temp_id_" ++ show n
            -- Destructure the fresh temp around everything built so far.
            wrapped = A.Match (A.Id (fromString tmp)) [(p, ex)]
            renamed = A.DefPat (A.PatId (fromString tmp)) defe
         in first (renamed :) $ desugarPat rdefs wrapped (n + 1)
      -- Keep threading @ex@ rather than restarting from the original body
      -- @e@. Otherwise matches added for earlier tuple bindings are lost.
      _ -> first (binding :) $ desugarPat rdefs ex (n + 1)
substExpr ex = ex