From 6b9b03e366ec5fdb0df3377af19dc48d7d6da80f Mon Sep 17 00:00:00 2001 From: Shao Cheng Date: Tue, 18 Jun 2019 12:10:30 +0300 Subject: [PATCH] Asyncify the rts api (#190) --- asterius/rts/rts.exports.mjs | 38 ++++++ asterius/rts/rts.mjs | 39 +++--- asterius/rts/rts.tso.mjs | 15 ++- asterius/src/Asterius/Builtins.hs | 69 +++-------- asterius/src/Asterius/JSFFI.hs | 156 ++++++++---------------- asterius/src/Asterius/JSRun/Main.hs | 4 +- asterius/src/Asterius/Ld.hs | 4 +- asterius/src/Asterius/Main.hs | 16 ++- asterius/src/Asterius/Resolve.hs | 38 +++--- asterius/test/cloudflare/cloudflare.mjs | 4 +- asterius/test/jsffi/jsffi.mjs | 12 +- asterius/test/nomain.hs | 12 +- asterius/test/rtsapi/rtsapi.mjs | 24 ++-- docs/rts-api.md | 10 +- 14 files changed, 202 insertions(+), 239 deletions(-) create mode 100644 asterius/rts/rts.exports.mjs diff --git a/asterius/rts/rts.exports.mjs b/asterius/rts/rts.exports.mjs new file mode 100644 index 00000000..8318072c --- /dev/null +++ b/asterius/rts/rts.exports.mjs @@ -0,0 +1,38 @@ +export class Exports { + constructor(reentrancy_guard, symbol_table, tso_manager, exports) { + this.context = Object.freeze({ + reentrancyGuard: reentrancy_guard, + symbolTable: symbol_table, + tsoManager: tso_manager + }); + Object.assign(this, exports); + } + + async rts_eval(p) { + this.context.reentrancyGuard.enter(0); + const tso = this.createGenThread(p); + this.scheduleWaitThread(tso, false); + this.context.reentrancyGuard.exit(0); + return this.context.tsoManager.getTSOid(tso); + } + + async rts_evalIO(p) { + this.context.reentrancyGuard.enter(0); + const tso = this.createStrictIOThread(p); + this.scheduleWaitThread(tso, false); + this.context.reentrancyGuard.exit(0); + return this.context.tsoManager.getTSOid(tso); + } + + async rts_evalLazyIO(p) { + this.context.reentrancyGuard.enter(0); + const tso = this.createIOThread(p); + this.scheduleWaitThread(tso, false); + this.context.reentrancyGuard.exit(0); + return this.context.tsoManager.getTSOid(tso); + } + + main() { + return this.rts_evalLazyIO(this.context.symbolTable.Main_main_closure); + } +} diff --git a/asterius/rts/rts.mjs b/asterius/rts/rts.mjs index da164613..306bb853 100644 --- a/asterius/rts/rts.mjs +++ b/asterius/rts/rts.mjs @@ -23,6 +23,7 @@ import { ThreadPaused } from "./rts.threadpaused.mjs"; import { MD5 } from "./rts.md5.mjs" import { FloatCBits } from "./rts.float.mjs"; import { Unicode } from "./rts.unicode.mjs"; +import { Exports } from "./rts.exports.mjs"; import * as rtsConstants from "./rts.constants.mjs"; export function newAsteriusInstance(req) { @@ -37,7 +38,7 @@ export function newAsteriusInstance(req) { __asterius_mblockalloc = new MBlockAlloc(), __asterius_heapalloc = new HeapAlloc(__asterius_memory, __asterius_mblockalloc), __asterius_stableptr_manager = new StablePtrManager(), - __asterius_tso_manager = new TSOManager(), + __asterius_tso_manager = new TSOManager(__asterius_memory, req.symbolTable), __asterius_heap_builder = new HeapBuilder(req.symbolTable, __asterius_heapalloc, __asterius_memory, __asterius_stableptr_manager), __asterius_integer_manager = new IntegerManager(__asterius_stableptr_manager, __asterius_heap_builder), __asterius_fs = new MemoryFileSystem(__asterius_logger), @@ -46,9 +47,10 @@ export function newAsteriusInstance(req) { __asterius_exception_helper = new ExceptionHelper(__asterius_memory, __asterius_heapalloc, req.infoTables, req.symbolTable), __asterius_threadpaused = new ThreadPaused(__asterius_memory, req.infoTables, req.symbolTable), __asterius_float_cbits = new FloatCBits(__asterius_memory), + __asterius_messages = new Messages(__asterius_memory, __asterius_fs), __asterius_unicode = new Unicode(), - __asterius_md5 = new MD5(__asterius_memory), - __asterius_messages = new Messages(__asterius_memory, __asterius_fs); + __asterius_exports = new Exports(__asterius_reentrancy_guard, req.symbolTable, __asterius_tso_manager, req.exports), + __asterius_md5 = new MD5(__asterius_memory); function __asterius_show_I64(x) { return "0x" + x.toString(16).padStart(8, "0"); @@ -67,30 +69,30 @@ export function newAsteriusInstance(req) { newTmpJSVal: v => __asterius_stableptr_manager.newTmpJSVal(v), mutTmpJSVal: (i, f) => __asterius_stableptr_manager.mutTmpJSVal(i, f), freezeTmpJSVal: i => __asterius_stableptr_manager.freezeTmpJSVal(i), - makeHaskellCallback: sp => () => { - const tid = __asterius_wasm_instance.exports.rts_evalLazyIO(__asterius_stableptr_manager.deRefStablePtr(sp)); - __asterius_wasm_instance.exports.rts_checkSchedStatus(tid); + makeHaskellCallback: sp => async () => { + const tid = await __asterius_exports.rts_evalLazyIO(__asterius_stableptr_manager.deRefStablePtr(sp)); + __asterius_exports.rts_checkSchedStatus(tid); }, - makeHaskellCallback1: sp => ev => { - const tid = __asterius_wasm_instance.exports.rts_evalLazyIO( - __asterius_wasm_instance.exports.rts_apply( + makeHaskellCallback1: sp => async ev => { + const tid = await __asterius_exports.rts_evalLazyIO( + __asterius_exports.rts_apply( __asterius_stableptr_manager.deRefStablePtr(sp), - __asterius_wasm_instance.exports.rts_mkInt( + __asterius_exports.rts_mkInt( __asterius_stableptr_manager.newJSVal(ev) ) ) ); - __asterius_wasm_instance.exports.rts_checkSchedStatus(tid); + __asterius_exports.rts_checkSchedStatus(tid); }, - makeHaskellCallback2: sp => (x, y) => { - const tid = __asterius_wasm_instance.exports.rts_evalLazyIO( - __asterius_wasm_instance.exports.rts_apply( - __asterius_wasm_instance.exports.rts_apply( - __asterius_stableptr_manager.deRefStablePtr(sp), __asterius_wasm_instance.exports.rts_mkInt( + makeHaskellCallback2: sp => async (x, y) => { + const tid = await __asterius_exports.rts_evalLazyIO( + __asterius_exports.rts_apply( + __asterius_exports.rts_apply( + __asterius_stableptr_manager.deRefStablePtr(sp), __asterius_exports.rts_mkInt( __asterius_stableptr_manager.newJSVal(x))), - __asterius_wasm_instance.exports.rts_mkInt( + __asterius_exports.rts_mkInt( __asterius_stableptr_manager.newJSVal(y)))); - __asterius_wasm_instance.exports.rts_checkSchedStatus(tid); + __asterius_exports.rts_checkSchedStatus(tid); }, Integer: __asterius_integer_manager, FloatCBits: __asterius_float_cbits, @@ -162,6 +164,7 @@ export function newAsteriusInstance(req) { return Object.assign(__asterius_jsffi_instance, { wasmModule: req.module, wasmInstance: __asterius_wasm_instance, + exports: Object.freeze(Object.assign(__asterius_exports, __asterius_wasm_instance.exports)), symbolTable: req.symbolTable, logger: __asterius_logger }); diff --git a/asterius/rts/rts.tso.mjs b/asterius/rts/rts.tso.mjs index 79f4258d..4d3bf013 100644 --- a/asterius/rts/rts.tso.mjs +++ b/asterius/rts/rts.tso.mjs @@ -1,16 +1,21 @@ +import { Memory } from "./rts.memory.mjs"; +import * as rtsConstants from "./rts.constants.mjs"; + class TSO { constructor() { this.addr = undefined; this.ret = undefined; this.rstat = undefined; - Object.seal(this); } } export class TSOManager { - constructor() { + constructor(memory, symbol_table) { + this.memory = memory; + this.symbolTable = symbol_table; + this.last = 0; this.tsos = []; - Object.freeze(this); + Object.seal(this); } newTSO() { return this.tsos.push(new TSO()) - 1; } @@ -26,4 +31,8 @@ export class TSOManager { setTSOret(i, ret) { this.tsos[i].ret = ret; } setTSOrstat(i, rstat) { this.tsos[i].rstat = rstat; } + + getTSOid(tso) { + return this.memory.i32Load(tso + rtsConstants.offset_StgTSO_id); + } } diff --git a/asterius/src/Asterius/Builtins.hs b/asterius/src/Asterius/Builtins.hs index c0bc59f8..d1559ddb 100644 --- a/asterius/src/Asterius/Builtins.hs +++ b/asterius/src/Asterius/Builtins.hs @@ -119,13 +119,8 @@ rtsAsteriusModule opts = Map.fromList $ map (\(func_sym, (_, func)) -> (func_sym, func)) (byteStringCBits <> floatCBits <> unicodeCBits <> md5CBits) - } <> mainFunction opts - <> hsInitFunction opts - <> scheduleWaitThreadFunction opts + } <> hsInitFunction opts <> createThreadFunction opts - <> createGenThreadFunction opts - <> createIOThreadFunction opts - <> createStrictIOThreadFunction opts <> genAllocateFunction opts "allocate" <> genAllocateFunction opts "allocateMightFail" <> allocatePinnedFunction opts @@ -168,9 +163,10 @@ rtsAsteriusModule opts = generateRtsExternalInterfaceModule :: BuiltinsOptions -> AsteriusModule generateRtsExternalInterfaceModule opts = mempty <> rtsApplyFunction opts - <> rtsEvalFunction opts - <> rtsEvalIOFunction opts - <> rtsEvalLazyIOFunction opts + <> createGenThreadFunction opts + <> createIOThreadFunction opts + <> createStrictIOThreadFunction opts + <> scheduleWaitThreadFunction opts <> rtsGetSchedStatusFunction opts <> rtsCheckSchedStatusFunction opts <> getStablePtrWrapperFunction opts @@ -505,8 +501,8 @@ rtsFunctionImports debug = else []) <> map (fst . snd) (byteStringCBits <> floatCBits <> unicodeCBits <> md5CBits) -rtsFunctionExports :: Bool -> Bool -> [FunctionExport] -rtsFunctionExports debug has_main = +rtsFunctionExports :: Bool -> [FunctionExport] +rtsFunctionExports debug = [ FunctionExport {internalName = f <> "_wrapper", externalName = f} | f <- [ "loadI64" @@ -525,9 +521,10 @@ rtsFunctionExports debug has_main = , "rts_getPtr" , "rts_getStablePtr" , "rts_apply" - , "rts_eval" - , "rts_evalIO" - , "rts_evalLazyIO" + , "createGenThread" + , "createStrictIOThread" + , "createIOThread" + , "scheduleWaitThread" , "rts_getSchedStatus" , "rts_checkSchedStatus" , "getStablePtr" @@ -547,8 +544,7 @@ rtsFunctionExports debug has_main = , "__asterius_Load_HpLim" ] else []) <> - ["hs_init"] <> - ["main" | has_main] + ["hs_init"] ] emitErrorMessage :: [ValueType] -> SBS.ShortByteString -> Expression @@ -697,12 +693,8 @@ generateWrapperModule mod = mod { -mainFunction, hsInitFunction, rtsApplyFunction, rtsEvalFunction, rtsEvalIOFunction, rtsEvalLazyIOFunction, rtsGetSchedStatusFunction, rtsCheckSchedStatusFunction, scheduleWaitThreadFunction, createThreadFunction, createGenThreadFunction, createIOThreadFunction, createStrictIOThreadFunction, allocatePinnedFunction, newCAFFunction, stgReturnFunction, getStablePtrWrapperFunction, deRefStablePtrWrapperFunction, freeStablePtrWrapperFunction, rtsMkBoolFunction, rtsMkDoubleFunction, rtsMkCharFunction, rtsMkIntFunction, rtsMkWordFunction, rtsMkPtrFunction, rtsMkStablePtrFunction, rtsGetBoolFunction, rtsGetDoubleFunction, loadI64Function, printI64Function, assertEqI64Function, printF32Function, printF64Function, strlenFunction, memchrFunction, memcpyFunction, memsetFunction, memcmpFunction, fromJSArrayBufferFunction, toJSArrayBufferFunction, fromJSStringFunction, fromJSArrayFunction, threadPausedFunction, dirtyMutVarFunction, raiseExceptionHelperFunction, barfFunction, getProgArgvFunction, suspendThreadFunction, resumeThreadFunction, performMajorGCFunction, performGCFunction, localeEncodingFunction :: +hsInitFunction, rtsApplyFunction, rtsGetSchedStatusFunction, rtsCheckSchedStatusFunction, scheduleWaitThreadFunction, createThreadFunction, createGenThreadFunction, createIOThreadFunction, createStrictIOThreadFunction, allocatePinnedFunction, newCAFFunction, stgReturnFunction, getStablePtrWrapperFunction, deRefStablePtrWrapperFunction, freeStablePtrWrapperFunction, rtsMkBoolFunction, rtsMkDoubleFunction, rtsMkCharFunction, rtsMkIntFunction, rtsMkWordFunction, rtsMkPtrFunction, rtsMkStablePtrFunction, rtsGetBoolFunction, rtsGetDoubleFunction, loadI64Function, printI64Function, assertEqI64Function, printF32Function, printF64Function, strlenFunction, memchrFunction, memcpyFunction, memsetFunction, memcmpFunction, fromJSArrayBufferFunction, toJSArrayBufferFunction, fromJSStringFunction, fromJSArrayFunction, threadPausedFunction, dirtyMutVarFunction, raiseExceptionHelperFunction, barfFunction, getProgArgvFunction, suspendThreadFunction, resumeThreadFunction, performMajorGCFunction, performGCFunction, localeEncodingFunction :: BuiltinsOptions -> AsteriusModule -mainFunction BuiltinsOptions {} = - runEDSL "main" $ do - tid <- call' "rts_evalLazyIO" [symbol "Main_main_closure"] I32 - call "rts_checkSchedStatus" [tid] initCapability :: EDSL () initCapability = do @@ -739,16 +731,6 @@ enter i = callImport "__asterius_enter" [constI32 i] exit i = callImport "__asterius_exit" [constI32 i] -rtsEvalHelper :: BuiltinsOptions -> AsteriusEntitySymbol -> EDSL () -rtsEvalHelper BuiltinsOptions {..} create_thread_func_sym = do - setReturnTypes [I32] - p <- param I64 - enter 0 - tso <- call' create_thread_func_sym [p] I64 - call "scheduleWaitThread" [tso] - exit 0 - emit $ loadI32 tso offset_StgTSO_id - rtsApplyFunction _ = runEDSL "rts_apply" $ do setReturnTypes [I64] @@ -763,13 +745,6 @@ rtsApplyFunction _ = storeI64 ap (offset_StgThunk_payload + 8) arg emit ap -rtsEvalFunction opts = runEDSL "rts_eval" $ rtsEvalHelper opts "createGenThread" - -rtsEvalIOFunction opts = - runEDSL "rts_evalIO" $ rtsEvalHelper opts "createStrictIOThread" - -rtsEvalLazyIOFunction opts = runEDSL "rts_evalLazyIO" $ rtsEvalHelper opts "createIOThread" - rtsGetSchedStatusFunction _ = runEDSL "rts_getSchedStatus" $ do setReturnTypes [I32] @@ -801,7 +776,8 @@ dirtySTACK _ stack = scheduleWaitThreadFunction BuiltinsOptions {} = runEDSL "scheduleWaitThread" $ do - t <- param I64 + [t, load_regs] <- params [I64, I32] + tid <- i32Local $ loadI32 t offset_StgTSO_id block' [] $ \sched_block_lbl -> loop' [] $ \sched_loop_lbl -> do storeI64 @@ -844,7 +820,7 @@ scheduleWaitThreadFunction BuiltinsOptions {} = constI32 next_ThreadComplete) (do callImport "__asterius_setTSOret" - [ loadI32 t offset_StgTSO_id + [ tid , convertUInt64ToFloat64 $ loadI64 (loadI64 @@ -854,22 +830,15 @@ scheduleWaitThreadFunction BuiltinsOptions {} = ] callImport "__asterius_setTSOrstat" - [ loadI32 t offset_StgTSO_id - , constI32 scheduler_Success - ] + [tid, constI32 scheduler_Success] break' sched_block_lbl Nothing) - (do callImport - "__asterius_setTSOret" - [loadI32 t offset_StgTSO_id, ConstF64 0] + (do callImport "__asterius_setTSOret" [tid, ConstF64 0] callImport "__asterius_setTSOrstat" - [ loadI32 t offset_StgTSO_id - , constI32 scheduler_Killed - ] + [tid, constI32 scheduler_Killed] break' sched_block_lbl Nothing)) ] , emit $ emitErrorMessage [] "IllegalThreadReturnCode") - callImport "__asterius_gcRootTSO" [convertUInt64ToFloat64 t] createThreadFunction BuiltinsOptions {..} = runEDSL "createThread" $ do diff --git a/asterius/src/Asterius/JSFFI.hs b/asterius/src/Asterius/JSFFI.hs index 68dba8dc..105c1e42 100644 --- a/asterius/src/Asterius/JSFFI.hs +++ b/asterius/src/Asterius/JSFFI.hs @@ -11,6 +11,7 @@ module Asterius.JSFFI ( addFFIProcessor , generateFFIFunctionImports , generateFFIImportObjectFactory + , generateFFIExportObject ) where import Asterius.Builtins @@ -467,106 +468,14 @@ generateFFIImportWrapperFunction k FFIImportDecl {..} = import_func_type = recoverWasmImportFunctionType ffiFunctionType wrapper_func_type = recoverWasmWrapperFunctionType ffiFunctionType -generateFFIExportFunction :: FFIExportDecl -> Function -generateFFIExportFunction FFIExportDecl {..} = - adjustLocalRegs - Function - { functionType = recoverWasmWrapperFunctionType ffiFunctionType - , varTypes = [] - , body = - Block - { name = "" - , bodys = - [ UnresolvedSetLocal - { unresolvedLocalReg = tid - , value = - Call - { target = - if ffiInIO ffiFunctionType - then "rts_evalIO" - else "rts_eval" - , operands = - [ foldl' - (\tot_expr (ffi_param_i, ffi_param_t) -> - Call - { target = "rts_apply" - , operands = - [ tot_expr - , Call - { target = - AsteriusEntitySymbol - { entityName = - "rts_mk" <> getHsTyCon ffi_param_t - } - , operands = - [ GetLocal - { index = ffi_param_i - , valueType = - recoverWasmWrapperValueType - ffi_param_t - } - ] - , callReturnTypes = [I64] - } - ] - , callReturnTypes = [I64] - }) - Symbol - { unresolvedSymbol = ffiExportClosure - , symbolOffset = 0 - } - (zip [0 ..] $ ffiParamTypes ffiFunctionType) - ] - , callReturnTypes = [I32] - } - } - , Call - { target = "rts_checkSchedStatus" - , operands = [UnresolvedGetLocal {unresolvedLocalReg = tid}] - , callReturnTypes = [] - } - ] <> - case ffiResultTypes ffiFunctionType of - [ffi_result_t] -> - [ Call - { target = - AsteriusEntitySymbol - {entityName = "rts_get" <> getHsTyCon ffi_result_t} - , operands = - [ Unary TruncUFloat64ToInt64 $ - CallImport - { target' = "__asterius_getTSOret" - , operands = - [ UnresolvedGetLocal - {unresolvedLocalReg = tid} - ] - , callImportReturnTypes = [F64] - } - ] - , callReturnTypes = - [recoverWasmWrapperValueType ffi_result_t] - } - ] - _ -> [] - , blockReturnTypes = - map recoverWasmWrapperValueType $ ffiResultTypes ffiFunctionType - } - } - where - tid = UniqueLocalReg 0 I32 - getHsTyCon FFI_VAL {..} = hsTyCon - getHsTyCon FFI_JSVAL = "StablePtr" - generateFFIWrapperModule :: FFIMarshalState -> AsteriusModule generateFFIWrapperModule mod_ffi_state@FFIMarshalState {..} = mempty { functionMap = - M.fromList $ - [ (k <> "_wrapper", wrapper_func) - | (k, wrapper_func) <- import_wrapper_funcs - ] <> - export_funcs <> - export_wrapper_funcs + M.fromList + [ (k <> "_wrapper", wrapper_func) + | (k, wrapper_func) <- import_wrapper_funcs + ] , ffiMarshalState = mod_ffi_state } where @@ -574,16 +483,6 @@ generateFFIWrapperModule mod_ffi_state@FFIMarshalState {..} = [ (k, generateFFIImportWrapperFunction k ffi_decl) | (k, ffi_decl) <- M.toList ffiImportDecls ] - export_funcs = - [ (k, generateFFIExportFunction ffi_decl) - | (k, ffi_decl) <- M.toList ffiExportDecls - ] - export_wrapper_funcs = - [ ( AsteriusEntitySymbol - {entityName = "__asterius_jsffi_export_" <> entityName k} - , generateWrapperFunction k f) - | (k, f) <- export_funcs - ] generateFFIFunctionImports :: FFIMarshalState -> [FunctionImport] generateFFIFunctionImports FFIMarshalState {..} = @@ -625,3 +524,48 @@ generateFFIImportObjectFactory FFIMarshalState {..} = | (k, ffi_decl) <- M.toList ffiImportDecls ]) <> "}})" + +generateFFIExportObject :: FFIMarshalState -> Builder +generateFFIExportObject FFIMarshalState {..} = + "Object.freeze({" <> + mconcat + (intersperse + "," + [ shortByteString (coerce k) <> ":" <> + generateFFIExportLambda export_decl + | (k, export_decl) <- M.toList ffiExportDecls + ]) <> + "})" + +generateFFIExportLambda :: FFIExportDecl -> Builder +generateFFIExportLambda FFIExportDecl { ffiFunctionType = FFIFunctionType {..} + , .. + } = + "function(" <> + mconcat (intersperse "," ["_" <> intDec i | i <- [1 .. length ffiParamTypes]]) <> + "){" <> + (if null ffiResultTypes + then tid + else "return " <> ret) <> + "}" + where + ret = + case ffiResultTypes of + [t] -> "this.rts_get" <> getHsTyCon t <> "(" <> ret_closure <> ")" + _ -> error "Asterius.JSFFI.generateFFIExportLambda" + ret_closure = "this.context.tsoManager.getTSOret(" <> tid <> ")" + tid = "this." <> eval_func <> "(" <> eval_closure <> ")" + eval_func + | ffiInIO = "rts_evalIO" + | otherwise = "rts_eval" + eval_closure = + foldl' + (\acc (i, t) -> + "this.rts_apply(" <> acc <> ",this.rts_mk" <> getHsTyCon t <> "(_" <> + intDec i <> + "))") + ("this.context.symbolTable." <> + shortByteString (coerce ffiExportClosure)) + (zip [1 ..] ffiParamTypes) + getHsTyCon FFI_VAL {..} = shortByteString hsTyCon + getHsTyCon FFI_JSVAL = "StablePtr" diff --git a/asterius/src/Asterius/JSRun/Main.hs b/asterius/src/Asterius/JSRun/Main.hs index d8be5824..ff1825ac 100644 --- a/asterius/src/Asterius/JSRun/Main.hs +++ b/asterius/src/Asterius/JSRun/Main.hs @@ -20,10 +20,10 @@ newAsteriusInstance s lib_path mod_buf = do eval s $ takeJSVal f_val <> "(" <> takeJSVal mod_val <> ")" hsInit :: JSSession -> JSVal -> IO () -hsInit s i = eval s $ deRefJSVal i <> ".wasmInstance.exports.hs_init()" +hsInit s i = eval s $ deRefJSVal i <> ".exports.hs_init()" hsMain :: JSSession -> JSVal -> IO () -hsMain s i = eval s $ deRefJSVal i <> ".wasmInstance.exports.main()" +hsMain s i = eval s $ deRefJSVal i <> ".exports.main()" hsStdOut :: JSSession -> JSVal -> IO LBS.ByteString hsStdOut s i = eval s $ deRefJSVal i <> ".stdio.stdout()" diff --git a/asterius/src/Asterius/Ld.hs b/asterius/src/Asterius/Ld.hs index 4d0b1a66..a41919d1 100644 --- a/asterius/src/Asterius/Ld.hs +++ b/asterius/src/Asterius/Ld.hs @@ -52,6 +52,7 @@ rtsUsedSymbols = , "ghczmprim_GHCziTypes_ZC_con_info" , "ghczmprim_GHCziTypes_ZMZN_closure" , "integerzmwiredzmin_GHCziIntegerziType_Integer_con_info" + , "Main_main_closure" , "stg_ARR_WORDS_info" , "stg_BLACKHOLE_info" , "stg_DEAD_WEAK_info" @@ -66,7 +67,6 @@ linkModules :: linkModules LinkTask {..} m = linkStart debug - True gcSections binaryen verboseErr @@ -79,7 +79,7 @@ linkModules LinkTask {..} m = , rtsUsedSymbols , Set.fromList [ AsteriusEntitySymbol {entityName = internalName} - | FunctionExport {..} <- rtsFunctionExports debug True + | FunctionExport {..} <- rtsFunctionExports debug ] ]) exportFunctions diff --git a/asterius/src/Asterius/Main.hs b/asterius/src/Asterius/Main.hs index ae54f52e..beffc88f 100644 --- a/asterius/src/Asterius/Main.hs +++ b/asterius/src/Asterius/Main.hs @@ -201,6 +201,8 @@ genLib Task {..} LinkReport {..} = , "export default module => \n" , "rts.newAsteriusInstance({module: module, jsffiFactory: " , generateFFIImportObjectFactory bundledFFIMarshalState + , ", exports: " + , generateFFIExportObject bundledFFIMarshalState , ", symbolTable: " , genSymbolDict symbol_table , ", infoTables: " @@ -223,7 +225,13 @@ genLib Task {..} LinkReport {..} = | fullSymTable = raw_symbol_table | otherwise = M.restrictKeys raw_symbol_table $ - S.fromList extraRootSymbols <> rtsUsedSymbols + S.fromList + [ ffiExportClosure + | FFIExportDecl {..} <- + M.elems $ ffiExportDecls bundledFFIMarshalState + ] <> + S.fromList extraRootSymbols <> + rtsUsedSymbols genDefEntry :: Task -> Builder genDefEntry Task {..} = @@ -242,13 +250,13 @@ genDefEntry Task {..} = , mconcat [ "module.then(m => " , out_base - , "(m)).then(i => {\n" + , "(m)).then(async i => {\n" , if debug then "i.logger.onEvent = ev => console.log(`[${ev.level}] ${ev.event}`);\n" else mempty , "try {\n" - , "i.wasmInstance.exports.hs_init();\n" - , "i.wasmInstance.exports.main();\n" + , "i.exports.hs_init();\n" + , "await i.exports.main();\n" , "} catch (err) {\n" , "console.log(i.stdio.stdout());\n" , "throw err;\n" diff --git a/asterius/src/Asterius/Resolve.hs b/asterius/src/Asterius/Resolve.hs index 8f5b140c..0505dce0 100644 --- a/asterius/src/Asterius/Resolve.hs +++ b/asterius/src/Asterius/Resolve.hs @@ -22,7 +22,6 @@ import Data.Binary import qualified Data.ByteString as BS import qualified Data.ByteString.Short as SBS import Data.Data (Data, gmapQl) -import Data.List import qualified Data.Map.Lazy as LM import qualified Data.Set as S import Data.String @@ -85,8 +84,9 @@ mergeSymbols :: -> Bool -> AsteriusModule -> S.Set AsteriusEntitySymbol + -> [AsteriusEntitySymbol] -> (AsteriusModule, LinkReport) -mergeSymbols _ gc_sections verbose_err store_mod root_syms +mergeSymbols _ gc_sections verbose_err store_mod root_syms export_funcs | not gc_sections = (store_mod, mempty {bundledFFIMarshalState = ffi_all}) | otherwise = (final_m, mempty {bundledFFIMarshalState = ffi_this}) where @@ -96,8 +96,17 @@ mergeSymbols _ gc_sections verbose_err store_mod root_syms { ffiImportDecls = flip LM.filterWithKey (ffiImportDecls ffi_all) $ \k _ -> (k <> "_wrapper") `LM.member` functionMap final_m + , ffiExportDecls = ffi_exports } - (_, _, final_m) = go (root_syms, S.empty, mempty) + ffi_exports + | not gc_sections = ffiExportDecls (ffiMarshalState store_mod) + | otherwise = + ffiExportDecls (ffiMarshalState store_mod) `LM.restrictKeys` + S.fromList export_funcs + root_syms' = + S.fromList [ffiExportClosure | FFIExportDecl {..} <- LM.elems ffi_exports] <> + root_syms + (_, _, final_m) = go (root_syms', S.empty, mempty) go i@(i_staging_syms, _, _) | S.null i_staging_syms = i | otherwise = go $ iter i @@ -163,9 +172,7 @@ makeInfoTableSet AsteriusModule {..} sym_map = resolveAsteriusModule :: Bool -> Bool - -> Bool -> FFIMarshalState - -> [AsteriusEntitySymbol] -> AsteriusModule -> Int64 -> Int64 @@ -174,7 +181,7 @@ resolveAsteriusModule :: , LM.Map AsteriusEntitySymbol Int64 , Int , Int) -resolveAsteriusModule debug has_main _ bundled_ffi_state export_funcs m_globals_resolved func_start_addr data_start_addr = +resolveAsteriusModule debug _ bundled_ffi_state m_globals_resolved func_start_addr data_start_addr = (new_mod, ss_sym_map, func_sym_map, table_slots, initial_mblocks) where (func_sym_map, last_func_addr) = @@ -195,12 +202,7 @@ resolveAsteriusModule debug has_main _ bundled_ffi_state export_funcs m_globals_ Module { functionMap' = new_function_map , functionImports = func_imports - , functionExports = - rtsFunctionExports debug has_main <> - [ FunctionExport - {internalName = "__asterius_jsffi_export_" <> k, externalName = k} - | k <- map entityName export_funcs - ] + , functionExports = rtsFunctionExports debug , functionTable = func_table , tableImport = TableImport @@ -220,12 +222,11 @@ linkStart :: -> Bool -> Bool -> Bool - -> Bool -> AsteriusModule -> S.Set AsteriusEntitySymbol -> [AsteriusEntitySymbol] -> (AsteriusModule, Module, LinkReport) -linkStart debug has_main gc_sections binaryen verbose_err store root_syms export_funcs = +linkStart debug gc_sections binaryen verbose_err store root_syms export_funcs = ( merged_m , result_m , report @@ -242,12 +243,7 @@ linkStart debug has_main gc_sections binaryen verbose_err store root_syms export gc_sections verbose_err store - (root_syms <> - S.fromList - [ AsteriusEntitySymbol - {entityName = "__asterius_jsffi_export_" <> entityName k} - | k <- export_funcs - ]) + root_syms export_funcs merged_m1 | debug = addMemoryTrap merged_m0 | otherwise = merged_m0 @@ -266,10 +262,8 @@ linkStart debug has_main gc_sections binaryen verbose_err store root_syms export (result_m, ss_sym_map, func_sym_map, tbl_slots, static_mbs) = resolveAsteriusModule debug - has_main binaryen (bundledFFIMarshalState report) - export_funcs merged_m (1 .|. functionTag `shiftL` 32) (dataTag `shiftL` 32) diff --git a/asterius/test/cloudflare/cloudflare.mjs b/asterius/test/cloudflare/cloudflare.mjs index 8c913174..540d6a93 100644 --- a/asterius/test/cloudflare/cloudflare.mjs +++ b/asterius/test/cloudflare/cloudflare.mjs @@ -1,5 +1,5 @@ import cloudflare from "./cloudflare.lib.mjs"; let i = cloudflare(m); -i.wasmInstance.exports.hs_init(); -i.wasmInstance.exports.main(); +i.exports.hs_init(); +i.exports.main(); diff --git a/asterius/test/jsffi/jsffi.mjs b/asterius/test/jsffi/jsffi.mjs index e480df40..841e9a59 100644 --- a/asterius/test/jsffi/jsffi.mjs +++ b/asterius/test/jsffi/jsffi.mjs @@ -3,10 +3,10 @@ import jsffi from "./jsffi.lib.mjs"; process.on("unhandledRejection", err => { throw err; }); -module.then(m => jsffi(m)).then(i => { - i.wasmInstance.exports.hs_init(); - i.wasmInstance.exports.main(); - console.log(i.wasmInstance.exports.mult_hs_int(9, 9)); - console.log(i.wasmInstance.exports.mult_hs_double(9, 9)); - i.wasmInstance.exports.putchar("H".codePointAt(0)); +module.then(m => jsffi(m)).then(async i => { + i.exports.hs_init(); + await i.exports.main(); + console.log(i.exports.mult_hs_int(9, 9)); + console.log(i.exports.mult_hs_double(9, 9)); + i.exports.putchar("H".codePointAt(0)); }); diff --git a/asterius/test/nomain.hs b/asterius/test/nomain.hs index 056f1ab6..315caa1d 100644 --- a/asterius/test/nomain.hs +++ b/asterius/test/nomain.hs @@ -27,12 +27,10 @@ main = do hsInit s i let x_closure = deRefJSVal i <> ".symbolTable.NoMain_x_closure" x_tid = - deRefJSVal i <> ".wasmInstance.exports.rts_eval(" <> x_closure <> ")" - x_ret = - deRefJSVal i <> ".wasmInstance.exports.getTSOret(" <> x_tid <> ")" - x_sp = - deRefJSVal i <> ".wasmInstance.exports.rts_getStablePtr(" <> x_ret <> - ")" - x_val = deRefJSVal i <> ".getJSVal(" <> x_sp <> ")" + "await " <> deRefJSVal i <> ".exports.rts_eval(" <> x_closure <> ")" + x_ret = deRefJSVal i <> ".exports.getTSOret(" <> x_tid <> ")" + x_sp = deRefJSVal i <> ".exports.rts_getStablePtr(" <> x_ret <> ")" + x_val' = deRefJSVal i <> ".getJSVal(" <> x_sp <> ")" + x_val = "(async () => " <> x_val' <> ")()" x <- eval s x_val LBS.putStr x diff --git a/asterius/test/rtsapi/rtsapi.mjs b/asterius/test/rtsapi/rtsapi.mjs index 7fd83d18..7ea5bf8f 100644 --- a/asterius/test/rtsapi/rtsapi.mjs +++ b/asterius/test/rtsapi/rtsapi.mjs @@ -3,18 +3,18 @@ import rtsapi from "./rtsapi.lib.mjs"; process.on("unhandledRejection", err => { throw err; }); -module.then(m => rtsapi(m)).then(i => { - i.wasmInstance.exports.hs_init(); - i.wasmInstance.exports.main(); - i.wasmInstance.exports.rts_evalLazyIO(i.wasmInstance.exports.rts_apply(i.symbolTable.Main_printInt_closure, i.wasmInstance.exports.rts_apply(i.symbolTable.Main_fact_closure, i.wasmInstance.exports.rts_mkInt(5)))); - const tid_p1 = i.wasmInstance.exports.rts_eval(i.wasmInstance.exports.rts_apply(i.symbolTable.Main_fact_closure, i.wasmInstance.exports.rts_mkInt(5))); - console.log(i.wasmInstance.exports.rts_getInt(i.wasmInstance.exports.getTSOret(tid_p1))); - console.log(i.wasmInstance.exports.rts_getBool(i.symbolTable.ghczmprim_GHCziTypes_False_closure)); - console.log(i.wasmInstance.exports.rts_getBool(i.symbolTable.ghczmprim_GHCziTypes_True_closure)); - console.log(i.wasmInstance.exports.rts_getBool(i.wasmInstance.exports.rts_mkBool(0))); - console.log(i.wasmInstance.exports.rts_getBool(i.wasmInstance.exports.rts_mkBool(42))); +module.then(m => rtsapi(m)).then(async i => { + i.exports.hs_init(); + await i.exports.main(); + await i.exports.rts_evalLazyIO(i.exports.rts_apply(i.symbolTable.Main_printInt_closure, i.exports.rts_apply(i.symbolTable.Main_fact_closure, i.exports.rts_mkInt(5)))); + const tid_p1 = await i.exports.rts_eval(i.exports.rts_apply(i.symbolTable.Main_fact_closure, i.exports.rts_mkInt(5))); + console.log(i.exports.rts_getInt(i.exports.getTSOret(tid_p1))); + console.log(i.exports.rts_getBool(i.symbolTable.ghczmprim_GHCziTypes_False_closure)); + console.log(i.exports.rts_getBool(i.symbolTable.ghczmprim_GHCziTypes_True_closure)); + console.log(i.exports.rts_getBool(i.exports.rts_mkBool(0))); + console.log(i.exports.rts_getBool(i.exports.rts_mkBool(42))); const x0 = Math.random(); - const tid_p3 = i.wasmInstance.exports.rts_eval(i.wasmInstance.exports.rts_apply(i.symbolTable.base_GHCziBase_id_closure, i.wasmInstance.exports.rts_mkDouble(x0))); - const x1 = i.wasmInstance.exports.rts_getDouble(i.wasmInstance.exports.getTSOret(tid_p3)); + const tid_p3 = await i.exports.rts_eval(i.exports.rts_apply(i.symbolTable.base_GHCziBase_id_closure, i.exports.rts_mkDouble(x0))); + const x1 = i.exports.rts_getDouble(i.exports.getTSOret(tid_p3)); console.log([x0, x1, x0 === x1]); }); diff --git a/docs/rts-api.md b/docs/rts-api.md index 031e5f0f..4836c99b 100644 --- a/docs/rts-api.md +++ b/docs/rts-api.md @@ -21,11 +21,11 @@ The next step is locating the pointer of `fact`. The "asterius instance" type we Since we'd like to call `fact`, we need to apply it to an argument, build a thunk representing the result, then evaluate the thunk to WHNF and retrieve the result. Assuming we're passing `--asterius-instance-callback=i=>{ ... }` to `ahc-link`, in the callback body, we can use RTS API like this: ```JavaScript -i.wasmInstance.exports.hs_init(); -const argument = i.wasmInstance.exports.rts_mkInt(5); -const thunk = i.wasmInstance.exports.rts_apply(i.staticsSymbolMap.Main_fact_closure, argument); -const tid = i.wasmInstance.exports.rts_eval(thunk); -console.log(i.wasmInstance.exports.rts_getInt(i.wasmInstance.exports.getTSOret(tid))); +i.exports.hs_init(); +const argument = i.exports.rts_mkInt(5); +const thunk = i.exports.rts_apply(i.staticsSymbolMap.Main_fact_closure, argument); +const tid = i.exports.rts_eval(thunk); +console.log(i.exports.rts_getInt(i.exports.getTSOret(tid))); ``` A line-by-line explanation follows: