1
1
mirror of https://github.com/tweag/asterius.git synced 2024-10-05 21:30:49 +03:00

Support exporting one-shot Haskell function closures dynamically (#527)

This commit is contained in:
Shao Cheng 2020-03-30 20:38:57 +02:00 committed by GitHub
parent 031eccd2bb
commit 547d8d426c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 59 additions and 29 deletions

View File

@ -68,7 +68,7 @@ export class Exports {
return this.context.scheduler.submitCmdCreateThread("createIOThread", p);
}
newHaskellCallback(sp, arg_tag, ret_tag, io) {
newHaskellCallback(sp, arg_tag, ret_tag, io, finalizer) {
const arg_mk_funcs = decodeTys(this.context.rtsMkFuncs, arg_tag),
ret_get_funcs = decodeTys(this.context.rtsGetFuncs, ret_tag),
run_func = this.context.symbolTable[
@ -83,19 +83,23 @@ export class Exports {
throw new WebAssembly.RuntimeError(`Multiple returns not supported`);
}
const cb = async (...args) => {
if (args.length !== arg_mk_funcs.length) {
throw new WebAssembly.RuntimeError(
`Expected ${arg_mk_funcs.length} arguments, got ${args.length}`
);
}
let p = this.context.stablePtrManager.deRefStablePtr(sp);
for (let i = 0; i < args.length; ++i) {
p = this.rts_apply(p, arg_mk_funcs[i](args[i]));
}
p = this.rts_apply(run_func, p);
const tid = await eval_func(p);
if (ret_get_funcs.length) {
return ret_get_funcs[0](this.context.scheduler.getTSOret(tid));
try {
if (args.length !== arg_mk_funcs.length) {
throw new WebAssembly.RuntimeError(
`Expected ${arg_mk_funcs.length} arguments, got ${args.length}`
);
}
let p = this.context.stablePtrManager.deRefStablePtr(sp);
for (let i = 0; i < args.length; ++i) {
p = this.rts_apply(p, arg_mk_funcs[i](args[i]));
}
p = this.rts_apply(run_func, p);
const tid = await eval_func(p);
if (ret_get_funcs.length) {
return ret_get_funcs[0](this.context.scheduler.getTSOret(tid));
}
} finally {
finalizer();
}
};
this.context.callbackStablePtrs.set(cb, sp);

View File

@ -175,7 +175,20 @@ export async function newAsteriusInstance(req) {
Unicode: modulify(__asterius_unicode),
Tracing: modulify(__asterius_tracer),
Exports: {
newHaskellCallback: (sp, arg_tag, ret_tag, io) => __asterius_stableptr_manager.newJSVal(__asterius_exports.newHaskellCallback(sp, arg_tag, ret_tag, io)),
newHaskellCallback: (sp, arg_tag, ret_tag, io, oneshot) => {
let sn = [];
let cb = __asterius_exports.newHaskellCallback(
sp,
arg_tag,
ret_tag,
io,
oneshot
? () => __asterius_exports.freeHaskellCallback(sn[0])
: () => {}
);
sn[0] = __asterius_stableptr_manager.newJSVal(cb);
return sn[0];
},
freeHaskellCallback: sn => __asterius_exports.freeHaskellCallback(sn)
},
Scheduler: modulify(__asterius_scheduler)
@ -190,11 +203,14 @@ export async function newAsteriusInstance(req) {
__asterius_scheduler.setGC(__asterius_gc);
for (const [f, p, a, r, i] of req.exportsStatic) {
__asterius_exports[f] = __asterius_exports.newHaskellCallback(
__asterius_exports[
f
] = __asterius_exports.newHaskellCallback(
__asterius_stableptr_manager.newStablePtr(p),
a,
r,
i
i,
() => {}
);
}

View File

@ -18,7 +18,7 @@ exportsImports =
externalBaseName = "newHaskellCallback",
functionType =
FunctionType
{ paramTypes = [F64, F64, F64, F64],
{ paramTypes = [F64, F64, F64, F64, F64],
returnTypes = [F64]
}
},
@ -36,7 +36,7 @@ exportsCBits = newHaskellCallback <> freeHaskellCallback
newHaskellCallback :: AsteriusModule
newHaskellCallback = runEDSL "newHaskellCallback" $ do
setReturnTypes [I64]
args <- params [I64, I64, I64, I64]
args <- params [I64, I64, I64, I64, I64]
truncUFloat64ToInt64
<$> callImport'
"__asterius_newHaskellCallback"

View File

@ -1,6 +1,7 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
module Asterius.Foreign.DsForeign
@ -49,8 +50,8 @@ asteriusDsForeigns fos = do
do_decl (XForeignDecl _) = panic "asteriusDsForeigns"
asteriusDsFImport :: Id -> Coercion -> ForeignImport -> DsM [Binding]
asteriusDsFImport id co (CImport cconv safety mHeader spec _) =
asteriusDsCImport id co spec (unLoc cconv) (unLoc safety) mHeader
asteriusDsFImport id co (CImport cconv safety mHeader spec (unLoc -> src)) =
asteriusDsCImport id co spec (unLoc cconv) (unLoc safety) mHeader src
asteriusDsCImport ::
Id ->
@ -59,12 +60,13 @@ asteriusDsCImport ::
CCallConv ->
Safety ->
Maybe Header ->
SourceText ->
DsM [Binding]
asteriusDsCImport id co (CFunction target) cconv safety _ =
asteriusDsCImport id co (CFunction target) cconv safety _ _ =
asteriusDsFCall id co (CCall (CCallSpec target cconv safety))
asteriusDsCImport id co CWrapper JavaScriptCallConv _ _ =
asteriusDsFExportDynamic id co
asteriusDsCImport id co spec cconv safety mHeader = do
asteriusDsCImport id co CWrapper JavaScriptCallConv _ _ src =
asteriusDsFExportDynamic id co src
asteriusDsCImport id co spec cconv safety mHeader _ = do
(r, _, _) <- dsCImport id co spec cconv safety mHeader
pure r
@ -106,8 +108,8 @@ asteriusDsFCall fn_id co fcall = do
fn_id `setIdUnfolding` mkInlineUnfoldingWithArity (length args) wrap_rhs'
return [(work_id, work_rhs), (fn_id_w_inl, wrap_rhs')]
asteriusDsFExportDynamic :: Id -> Coercion -> DsM [Binding]
asteriusDsFExportDynamic id co0 = do
asteriusDsFExportDynamic :: Id -> Coercion -> SourceText -> DsM [Binding]
asteriusDsFExportDynamic id co0 src = do
dflags <- getDynFlags
cback <- newSysLocalDs arg_ty
newStablePtrId <- dsLookupGlobalId newStablePtrName
@ -119,7 +121,8 @@ asteriusDsFExportDynamic id co0 = do
[ Var stbl_value,
mkIntLitInt dflags (fromIntegral ffi_params_tag),
mkIntLitInt dflags (fromIntegral ffi_ret_tag),
mkIntLitInt dflags (if ffiInIO then 1 else 0)
mkIntLitInt dflags (if ffiInIO then 1 else 0),
mkIntLitInt dflags (if oneshot then 1 else 0)
]
new_hs_callback = fsLit "newHaskellCallback"
ccall_adj <-
@ -140,6 +143,10 @@ asteriusDsFExportDynamic id co0 = do
fed = (id `setInlineActivation` NeverActive, Cast io_app co0)
return [fed]
where
oneshot
| src == SourceText "\"wrapper\"" = False
| src == SourceText "\"wrapper oneshot\"" = True
| otherwise = error "asteriusDsFExportDynamic"
ty = pFst (coercionKind co0)
(tvs, sans_foralls) = tcSplitForAllTys ty
([arg_ty], fn_res_ty) = tcSplitFunTys sans_foralls

View File

@ -58,7 +58,10 @@ asteriusTcFImport d = pprPanic "asteriusTcFImport" (ppr d)
asteriusTcCheckFIType :: [Type] -> Type -> ForeignImport -> TcM ForeignImport
asteriusTcCheckFIType arg_tys res_ty (CImport (L lc cconv) (L ls safety) mh (CFunction target) src)
| cconv == JavaScriptCallConv && unLoc src == SourceText "\"wrapper\"" =
| cconv == JavaScriptCallConv && unLoc src
`elem` map
SourceText
["\"wrapper\"", "\"wrapper oneshot\""] =
do
case arg_tys of
[arg1_ty] -> do