{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_HADDOCK prune #-}

-- | Data types and helpers used to compose the compiler pipeline.
module Common.Compiler (
  ErrorMsg,
  Error (..),
  Warning (..),
  Pass (..),
  MonadError (..),
  MonadWriter (..),
  fromString,
  runPass,
  dump,
  unexpected,
  warn,
  todo,
  passIO,
  liftEither,
  typeError,
) where

import Common.Pretty (Pretty (pretty))
import Control.Monad.Except (
  Except,
  MonadError (..),
  liftEither,
  runExcept,
  throwError,
 )
import Control.Monad.Writer.Strict (
  MonadWriter (..),
  WriterT (..),
 )
import Data.String (IsString (..))
import System.Exit (
  exitFailure,
  exitSuccess,
 )
import System.IO (
  hPrint,
  stderr,
 )


-- | Human-readable payload carried by 'Error' and 'Warning' constructors.
--
-- 'Semigroup' and 'Monoid' are borrowed from 'String', so messages
-- concatenate with '<>' and 'mempty' is the empty message. 'Eq' is
-- defined by hand below rather than derived.
newtype ErrorMsg = ErrorMsg String
  deriving stock (Show)
  deriving (Semigroup, Monoid) via String


-- | Lets 'fromString' (and string literals, wherever @OverloadedStrings@ is
-- enabled) construct an 'ErrorMsg' directly by wrapping the 'String'.
instance IsString ErrorMsg where
  fromString = ErrorMsg


-- | Every 'ErrorMsg' is equal to every other 'ErrorMsg'; the text is ignored.
--
-- As a consequence, the derived 'Eq' instances of 'Error' and 'Warning'
-- compare only by constructor, except for 'Dump', whose payload is a plain
-- 'String' and is compared normally.
--
-- NOTE(review): presumably intentional so that tests can match on the kind
-- of error without pinning exact message text. Confirm before relying on
-- 'ErrorMsg' equality anywhere else.
instance Eq ErrorMsg where
  _ == _ = True


-- | Types of compiler errors that can be thrown during compilation.
data Error
  = -- | Halt the compiler to dump output (not an error)
    Dump String
  | -- | Internal error; should be unreachable
    UnexpectedError ErrorMsg
  | -- | Feature not yet implemented ("it's a research artifact")
    UnimplementedError ErrorMsg
  | -- | Type mismatch (round peg in square hole)
    TypeError ErrorMsg
  | -- | Identifier is out of scope
    ScopeError ErrorMsg
  | -- | Invalid naming convention at a binding
    NameError ErrorMsg
  | -- | Malformed pattern
    PatternError ErrorMsg
  | -- | Error encountered by the scanner
    LexError ErrorMsg
  | -- | Error encountered by the parser
    ParseError ErrorMsg
  deriving (Show, Eq)


-- | Types of compiler warnings that can be logged during compilation.
--
-- Warnings are accumulated by 'Pass' (see 'warn') and do not halt the
-- pipeline.
data Warning
  = -- | Warning about types
    TypeWarning ErrorMsg
  | -- | Warning related to identifier names
    NameWarning ErrorMsg
  | -- | Warning related to patterns
    PatternWarning ErrorMsg
  deriving (Show, Eq)


-- | Representation behind 'Pass': a strict writer that collects 'Warning's,
-- layered over an exception monad that aborts with an 'Error'.
--
-- Left unsaturated so that it can serve as the @via@ type for 'Pass'.
type PassMonad =
  WriterT [Warning] (Except Error)


-- | The compiler pipeline monad. It supports throwing 'Error's
-- ('MonadError') and logging 'Warning's ('MonadWriter').
--
-- Every instance is borrowed from the underlying 'PassMonad' representation
-- via @DerivingVia@, so 'Pass' adds no behavior of its own.
newtype Pass a = Pass (PassMonad a)
  deriving (Functor, Applicative, Monad) via PassMonad
  deriving (MonadError Error) via PassMonad
  deriving (MonadWriter [Warning]) via PassMonad


-- | A failed pattern match (or explicit 'fail') inside a 'Pass' aborts the
-- pipeline with an 'UnexpectedError' carrying the failure message.
instance MonadFail Pass where
  fail = throwError . UnexpectedError . fromString


-- | Invoke a compiler 'Pass' purely.
--
-- Returns either the 'Error' that aborted the pass, or the result together
-- with every 'Warning' logged along the way. Because the writer layer sits
-- on top of the exception layer, warnings logged before an error are not
-- part of the 'Left' result.
runPass :: Pass a -> Either Error (a, [Warning])
runPass (Pass p) = runExcept (runWriterT p)


-- | Dump pretty-printable output from within a compiler pass.
--
-- Halts the pipeline by throwing 'Dump' with the 'show'n rendering of the
-- value's 'pretty' document; 'passIO' prints it and exits successfully.
dump :: Pretty a => a -> Pass x
dump = throwError . Dump . show . pretty


-- | Report an unexpected (internal, supposedly unreachable) compiler error
-- and halt the pipeline by throwing 'UnexpectedError'.
unexpected :: (MonadError Error m) => String -> m a
unexpected = throwError . UnexpectedError . fromString


-- | Log a compiler warning, appending it to the pass's warning log.
--
-- Equivalent to the former @pass (return ((), (++ [w])))@, which applied
-- @(++ [w])@ to an empty output and so emitted exactly @[w]@. This version
-- says the same thing directly with 'tell'.
warn :: MonadWriter [Warning] m => Warning -> m ()
warn w = tell [w]


-- | Report an unimplemented compiler feature and halt the pipeline by
-- throwing 'UnimplementedError'.
todo :: (MonadError Error m) => String -> m a
todo = throwError . UnimplementedError . fromString


-- | Execute a compiler 'Pass' in 'IO', terminating the process unless it
-- succeeds.
--
-- * A 'Dump' result is written to stdout, after which the process exits
--   via 'exitSuccess'.
-- * Any other 'Error' is written to stderr using its 'Show' instance, after
--   which the process exits via 'exitFailure'.
-- * On success, the result is returned together with the accumulated
--   warnings. The warnings are not printed here.
passIO :: Pass a -> IO (a, [Warning])
passIO p = case runPass p of
  Left (Dump s) -> putStrLn s >> exitSuccess
  Left e -> hPrint stderr e >> exitFailure
  Right result -> return result


-- | Throw a 'TypeError' with the given message, halting the pipeline.
typeError :: (MonadError Error m) => String -> m a
typeError = throwError . TypeError . fromString