mirror of
https://github.com/tweag/asterius.git
synced 2024-10-05 21:30:49 +03:00
Support exporting one-shot Haskell function closures dynamically (#527)
This commit is contained in:
parent
031eccd2bb
commit
547d8d426c
@ -68,7 +68,7 @@ export class Exports {
|
||||
return this.context.scheduler.submitCmdCreateThread("createIOThread", p);
|
||||
}
|
||||
|
||||
newHaskellCallback(sp, arg_tag, ret_tag, io) {
|
||||
newHaskellCallback(sp, arg_tag, ret_tag, io, finalizer) {
|
||||
const arg_mk_funcs = decodeTys(this.context.rtsMkFuncs, arg_tag),
|
||||
ret_get_funcs = decodeTys(this.context.rtsGetFuncs, ret_tag),
|
||||
run_func = this.context.symbolTable[
|
||||
@ -83,19 +83,23 @@ export class Exports {
|
||||
throw new WebAssembly.RuntimeError(`Multiple returns not supported`);
|
||||
}
|
||||
const cb = async (...args) => {
|
||||
if (args.length !== arg_mk_funcs.length) {
|
||||
throw new WebAssembly.RuntimeError(
|
||||
`Expected ${arg_mk_funcs.length} arguments, got ${args.length}`
|
||||
);
|
||||
}
|
||||
let p = this.context.stablePtrManager.deRefStablePtr(sp);
|
||||
for (let i = 0; i < args.length; ++i) {
|
||||
p = this.rts_apply(p, arg_mk_funcs[i](args[i]));
|
||||
}
|
||||
p = this.rts_apply(run_func, p);
|
||||
const tid = await eval_func(p);
|
||||
if (ret_get_funcs.length) {
|
||||
return ret_get_funcs[0](this.context.scheduler.getTSOret(tid));
|
||||
try {
|
||||
if (args.length !== arg_mk_funcs.length) {
|
||||
throw new WebAssembly.RuntimeError(
|
||||
`Expected ${arg_mk_funcs.length} arguments, got ${args.length}`
|
||||
);
|
||||
}
|
||||
let p = this.context.stablePtrManager.deRefStablePtr(sp);
|
||||
for (let i = 0; i < args.length; ++i) {
|
||||
p = this.rts_apply(p, arg_mk_funcs[i](args[i]));
|
||||
}
|
||||
p = this.rts_apply(run_func, p);
|
||||
const tid = await eval_func(p);
|
||||
if (ret_get_funcs.length) {
|
||||
return ret_get_funcs[0](this.context.scheduler.getTSOret(tid));
|
||||
}
|
||||
} finally {
|
||||
finalizer();
|
||||
}
|
||||
};
|
||||
this.context.callbackStablePtrs.set(cb, sp);
|
||||
|
@ -175,7 +175,20 @@ export async function newAsteriusInstance(req) {
|
||||
Unicode: modulify(__asterius_unicode),
|
||||
Tracing: modulify(__asterius_tracer),
|
||||
Exports: {
|
||||
newHaskellCallback: (sp, arg_tag, ret_tag, io) => __asterius_stableptr_manager.newJSVal(__asterius_exports.newHaskellCallback(sp, arg_tag, ret_tag, io)),
|
||||
newHaskellCallback: (sp, arg_tag, ret_tag, io, oneshot) => {
|
||||
let sn = [];
|
||||
let cb = __asterius_exports.newHaskellCallback(
|
||||
sp,
|
||||
arg_tag,
|
||||
ret_tag,
|
||||
io,
|
||||
oneshot
|
||||
? () => __asterius_exports.freeHaskellCallback(sn[0])
|
||||
: () => {}
|
||||
);
|
||||
sn[0] = __asterius_stableptr_manager.newJSVal(cb);
|
||||
return sn[0];
|
||||
},
|
||||
freeHaskellCallback: sn => __asterius_exports.freeHaskellCallback(sn)
|
||||
},
|
||||
Scheduler: modulify(__asterius_scheduler)
|
||||
@ -190,11 +203,14 @@ export async function newAsteriusInstance(req) {
|
||||
__asterius_scheduler.setGC(__asterius_gc);
|
||||
|
||||
for (const [f, p, a, r, i] of req.exportsStatic) {
|
||||
__asterius_exports[f] = __asterius_exports.newHaskellCallback(
|
||||
__asterius_exports[
|
||||
f
|
||||
] = __asterius_exports.newHaskellCallback(
|
||||
__asterius_stableptr_manager.newStablePtr(p),
|
||||
a,
|
||||
r,
|
||||
i
|
||||
i,
|
||||
() => {}
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -18,7 +18,7 @@ exportsImports =
|
||||
externalBaseName = "newHaskellCallback",
|
||||
functionType =
|
||||
FunctionType
|
||||
{ paramTypes = [F64, F64, F64, F64],
|
||||
{ paramTypes = [F64, F64, F64, F64, F64],
|
||||
returnTypes = [F64]
|
||||
}
|
||||
},
|
||||
@ -36,7 +36,7 @@ exportsCBits = newHaskellCallback <> freeHaskellCallback
|
||||
newHaskellCallback :: AsteriusModule
|
||||
newHaskellCallback = runEDSL "newHaskellCallback" $ do
|
||||
setReturnTypes [I64]
|
||||
args <- params [I64, I64, I64, I64]
|
||||
args <- params [I64, I64, I64, I64, I64]
|
||||
truncUFloat64ToInt64
|
||||
<$> callImport'
|
||||
"__asterius_newHaskellCallback"
|
||||
|
@ -1,6 +1,7 @@
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
{-# LANGUAGE RecordWildCards #-}
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE ViewPatterns #-}
|
||||
{-# OPTIONS_GHC -Wno-name-shadowing #-}
|
||||
|
||||
module Asterius.Foreign.DsForeign
|
||||
@ -49,8 +50,8 @@ asteriusDsForeigns fos = do
|
||||
do_decl (XForeignDecl _) = panic "asteriusDsForeigns"
|
||||
|
||||
asteriusDsFImport :: Id -> Coercion -> ForeignImport -> DsM [Binding]
|
||||
asteriusDsFImport id co (CImport cconv safety mHeader spec _) =
|
||||
asteriusDsCImport id co spec (unLoc cconv) (unLoc safety) mHeader
|
||||
asteriusDsFImport id co (CImport cconv safety mHeader spec (unLoc -> src)) =
|
||||
asteriusDsCImport id co spec (unLoc cconv) (unLoc safety) mHeader src
|
||||
|
||||
asteriusDsCImport ::
|
||||
Id ->
|
||||
@ -59,12 +60,13 @@ asteriusDsCImport ::
|
||||
CCallConv ->
|
||||
Safety ->
|
||||
Maybe Header ->
|
||||
SourceText ->
|
||||
DsM [Binding]
|
||||
asteriusDsCImport id co (CFunction target) cconv safety _ =
|
||||
asteriusDsCImport id co (CFunction target) cconv safety _ _ =
|
||||
asteriusDsFCall id co (CCall (CCallSpec target cconv safety))
|
||||
asteriusDsCImport id co CWrapper JavaScriptCallConv _ _ =
|
||||
asteriusDsFExportDynamic id co
|
||||
asteriusDsCImport id co spec cconv safety mHeader = do
|
||||
asteriusDsCImport id co CWrapper JavaScriptCallConv _ _ src =
|
||||
asteriusDsFExportDynamic id co src
|
||||
asteriusDsCImport id co spec cconv safety mHeader _ = do
|
||||
(r, _, _) <- dsCImport id co spec cconv safety mHeader
|
||||
pure r
|
||||
|
||||
@ -106,8 +108,8 @@ asteriusDsFCall fn_id co fcall = do
|
||||
fn_id `setIdUnfolding` mkInlineUnfoldingWithArity (length args) wrap_rhs'
|
||||
return [(work_id, work_rhs), (fn_id_w_inl, wrap_rhs')]
|
||||
|
||||
asteriusDsFExportDynamic :: Id -> Coercion -> DsM [Binding]
|
||||
asteriusDsFExportDynamic id co0 = do
|
||||
asteriusDsFExportDynamic :: Id -> Coercion -> SourceText -> DsM [Binding]
|
||||
asteriusDsFExportDynamic id co0 src = do
|
||||
dflags <- getDynFlags
|
||||
cback <- newSysLocalDs arg_ty
|
||||
newStablePtrId <- dsLookupGlobalId newStablePtrName
|
||||
@ -119,7 +121,8 @@ asteriusDsFExportDynamic id co0 = do
|
||||
[ Var stbl_value,
|
||||
mkIntLitInt dflags (fromIntegral ffi_params_tag),
|
||||
mkIntLitInt dflags (fromIntegral ffi_ret_tag),
|
||||
mkIntLitInt dflags (if ffiInIO then 1 else 0)
|
||||
mkIntLitInt dflags (if ffiInIO then 1 else 0),
|
||||
mkIntLitInt dflags (if oneshot then 1 else 0)
|
||||
]
|
||||
new_hs_callback = fsLit "newHaskellCallback"
|
||||
ccall_adj <-
|
||||
@ -140,6 +143,10 @@ asteriusDsFExportDynamic id co0 = do
|
||||
fed = (id `setInlineActivation` NeverActive, Cast io_app co0)
|
||||
return [fed]
|
||||
where
|
||||
oneshot
|
||||
| src == SourceText "\"wrapper\"" = False
|
||||
| src == SourceText "\"wrapper oneshot\"" = True
|
||||
| otherwise = error "asteriusDsFExportDynamic"
|
||||
ty = pFst (coercionKind co0)
|
||||
(tvs, sans_foralls) = tcSplitForAllTys ty
|
||||
([arg_ty], fn_res_ty) = tcSplitFunTys sans_foralls
|
||||
|
@ -58,7 +58,10 @@ asteriusTcFImport d = pprPanic "asteriusTcFImport" (ppr d)
|
||||
|
||||
asteriusTcCheckFIType :: [Type] -> Type -> ForeignImport -> TcM ForeignImport
|
||||
asteriusTcCheckFIType arg_tys res_ty (CImport (L lc cconv) (L ls safety) mh (CFunction target) src)
|
||||
| cconv == JavaScriptCallConv && unLoc src == SourceText "\"wrapper\"" =
|
||||
| cconv == JavaScriptCallConv && unLoc src
|
||||
`elem` map
|
||||
SourceText
|
||||
["\"wrapper\"", "\"wrapper oneshot\""] =
|
||||
do
|
||||
case arg_tys of
|
||||
[arg1_ty] -> do
|
||||
|
Loading…
Reference in New Issue
Block a user