diff --git a/crates/wasm_interp/tests/test_basics.rs b/crates/wasm_interp/tests/test_basics.rs index 39eec28571..825f9b3a06 100644 --- a/crates/wasm_interp/tests/test_basics.rs +++ b/crates/wasm_interp/tests/test_basics.rs @@ -2,7 +2,8 @@ use bumpalo::{collections::Vec, Bump}; use roc_wasm_interp::test_utils::{const_value, create_exported_function_no_locals, default_state}; -use roc_wasm_interp::{Action, Instance, ValueStack, DEFAULT_IMPORTS}; +use roc_wasm_interp::{Action, ImportDispatcher, Instance, ValueStack, DEFAULT_IMPORTS}; +use roc_wasm_module::sections::{Import, ImportDesc}; use roc_wasm_module::{ opcodes::OpCode, sections::ElementSegment, Export, ExportType, SerialBuffer, Serialize, Signature, Value, ValueType, WasmModule, @@ -456,6 +457,96 @@ fn test_br_table_help(condition: i32, expected: i32) { assert_eq!(state.value_stack.pop(), Value::I32(expected)) } +struct TestDispatcher { + internal_state: i32, +} + +impl ImportDispatcher for TestDispatcher { + fn dispatch( + &mut self, + module_name: &str, + function_name: &str, + arguments: &[Value], + _memory: &mut [u8], + ) -> Option { + assert_eq!(module_name, "env"); + assert_eq!(function_name, "increment_state"); + assert_eq!(arguments.len(), 1); + let val = arguments[0].unwrap_i32(); + self.internal_state += val; + dbg!(val, self.internal_state); + Some(Value::I32(self.internal_state)) + } +} + +#[test] +fn test_call_import() { + let arena = Bump::new(); + let mut module = WasmModule::new(&arena); + let start_fn_name = "test"; + + // User-provided non-Wasm code, with state + let import_dispatcher = TestDispatcher { + internal_state: 100, + }; + + // Function 0 is the import + module.import.imports.push(Import { + module: "env", + name: "increment_state", + description: ImportDesc::Func { signature_index: 0 }, + }); + module.types.insert(Signature { + param_types: bumpalo::vec![in &arena; ValueType::I32], + ret_type: Some(ValueType::I32), + }); + + // Function 1, which calls the import + module.code.function_count = 1; + let func0_offset = module.code.bytes.len() as u32; + module.code.function_offsets.push(func0_offset); + module.add_function_signature(Signature { + param_types: Vec::new_in(&arena), + ret_type: Some(ValueType::I32), + }); + module.export.append(Export { + name: start_fn_name, + ty: ExportType::Func, + index: 1, + }); + [ + 0, // no locals + OpCode::I32CONST as u8, + 11, // argument to increment_state + OpCode::CALL as u8, + 0, // function 0 + OpCode::I32CONST as u8, + 12, // argument to increment_state + OpCode::CALL as u8, + 0, // function 0 + OpCode::I32ADD as u8, + OpCode::END as u8, + ] + .serialize(&mut module.code.bytes); + + if false { + let mut buf = Vec::new_in(&arena); + module.serialize(&mut buf); + let filename = "/tmp/roc/call-return.wasm"; + std::fs::write(filename, buf).unwrap(); + println!("Wrote to {}", filename); + } + + let mut inst = Instance::for_module(&arena, &module, import_dispatcher, true).unwrap(); + + let return_val = inst + .call_export(&module, start_fn_name, []) + .unwrap() + .unwrap(); + + assert_eq!(return_val, Value::I32(234)); +} + #[test] fn test_call_return_no_args() { let arena = Bump::new();