{-# LANGUAGE DerivingVia #-}
module IR.ExternToCall (
externToCall,
) where
import qualified Common.Compiler as Compiler
import Common.Compiler (MonadError)
import Common.Identifiers (
fromId,
fromString,
ident,
)
import qualified IR.IR as I
import qualified IR.Types as I
import Control.Monad.Reader (
MonadReader (..),
ReaderT (..),
asks,
)
import Data.Generics.Aliases (mkM)
import Data.Generics.Schemes (everywhereM)
import qualified Data.Set as S
-- | The set of variable names that the program declares as externs
-- (built by 'externToCall' from the first components of 'I.externDecls').
type ExternEnv = S.Set I.VarId

-- | Monad for this pass: read-only access to an 'ExternEnv' on top of
-- 'Compiler.Pass'.
--
-- Every instance is borrowed wholesale from the underlying
-- @ReaderT ExternEnv Compiler.Pass@ through @DerivingVia@.
newtype ExternFn a = ExternFn (ReaderT ExternEnv Compiler.Pass a)
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadFail
    , MonadError Compiler.Error
    , MonadReader ExternEnv
    )
    via (ReaderT ExternEnv Compiler.Pass)
-- | Execute an 'ExternFn' computation, supplying the given extern-name set
-- as its reader environment.
runExternFn :: ExternEnv -> ExternFn a -> Compiler.Pass a
runExternFn externs (ExternFn action) = action `runReaderT` externs
-- | Turn extern declarations into callable top-level wrapper functions.
--
-- For every @(x, t)@ in 'I.externDecls' a wrapper definition is produced
-- by 'makeExternFunc'; a non-function @t@ makes the pass fail. Every
-- 'I.Var' inside 'I.programDefs' whose name is one of the externs is then
-- renamed to the wrapper's name by 'dataToApp'. The wrappers are placed in
-- front of the (rewritten) original definitions; all other program fields,
-- including 'I.externDecls', are left as they were.
externToCall :: I.Program I.Type -> Compiler.Pass (I.Program I.Type)
externToCall p@I.Program{I.programDefs = defs, I.externDecls = xds} =
  runExternFn externNames $ do
    wrappers <- traverse makeExternFunc xds
    rewritten <- everywhereM (mkM dataToApp) defs
    pure p{I.programDefs = wrappers ++ rewritten}
 where
  externNames :: ExternEnv
  externNames = S.fromList (fst <$> xds)
-- | Rewrite one expression node.
--
-- An 'I.Var' whose name is in the extern set is renamed to
-- @'liftExtern' name@ (its type annotation is kept); every other node is
-- returned unchanged. 'externToCall' applies this to every node of the
-- program's definitions via 'everywhereM'.
--
-- NOTE(review): membership is checked by name only, so a locally bound
-- variable that happens to shadow an extern would be renamed as well —
-- confirm that such shadowing cannot reach this pass.
dataToApp :: I.Expr I.Type -> ExternFn (I.Expr I.Type)
dataToApp (I.Var n t) = asks $ \externs ->
  let target
        | n `S.member` externs = liftExtern n
        | otherwise = n
   in I.Var target t
dataToApp other = pure other
-- | Build the top-level wrapper definition for a single extern @(x, t)@.
--
-- @t@ is split with 'I.unfoldArrow' into argument types and a final type.
-- The wrapper is a lambda ('I.foldLambda') over one parameter per argument
-- type, named @__arg0@, @__arg1@, … ('argName'). Its body is an
-- @'I.FfiCall' x@ primitive applied to those parameters and annotated with
-- the final type. The returned binder is named @'liftExtern' x@ and carries
-- the annotation extracted ('I.extract') from the lambda.
--
-- Fails with 'Compiler.typeError' when 'I.unfoldArrow' yields no argument
-- types, i.e. the extern is not of function type.
makeExternFunc :: (I.VarId, I.Type) -> ExternFn (I.Binder I.Type, I.Expr I.Type)
makeExternFunc (x, t) =
  case I.unfoldArrow t of
    ([], _) -> Compiler.typeError errMsg
    (ats, rt) ->
      let params = zip [argName i | i <- [0 ..]] ats
          call = I.Prim (I.FfiCall (fromId x)) [I.Var v ty | (v, ty) <- params] rt
          wrapper = I.foldLambda [I.BindVar v ty | (v, ty) <- params] call
       in pure (I.BindVar (liftExtern x) (I.extract wrapper), wrapper)
 where
  errMsg :: String
  errMsg = "Extern symbol does not have function type: " ++ show x
-- | Name of the @i@-th parameter of a generated extern wrapper:
-- @"__arg"@ followed by the decimal rendering of @i@.
argName :: Int -> I.VarId
argName = fromString . ("__arg" ++) . show
-- | Name of the wrapper generated for an extern: the extern's identifier
-- string ('ident') prefixed with @"__"@.
liftExtern :: I.VarId -> I.VarId
liftExtern = fromString . ("__" ++) . ident