diff --git a/crates/cli-support/src/descriptor.rs b/crates/cli-support/src/descriptor.rs index bd6bc25f7..b98a2ec5e 100644 --- a/crates/cli-support/src/descriptor.rs +++ b/crates/cli-support/src/descriptor.rs @@ -187,6 +187,12 @@ impl Descriptor { _ => return None, } } + Descriptor::RefMut(ref d) => { + match **d { + Descriptor::Slice(ref d) => &**d, + _ => return None, + } + } _ => return None, }; match *inner { @@ -234,6 +240,13 @@ impl Descriptor { _ => false, } } + + pub fn is_mut_ref(&self) -> bool { + match *self { + Descriptor::RefMut(_) => true, + _ => false, + } + } } fn get(a: &mut &[u32]) -> u32 { diff --git a/crates/cli-support/src/js/js2rust.rs b/crates/cli-support/src/js/js2rust.rs index d4f9d7391..d300e79e3 100644 --- a/crates/cli-support/src/js/js2rust.rs +++ b/crates/cli-support/src/js/js2rust.rs @@ -118,6 +118,15 @@ impl<'a, 'b> Js2Rust<'a, 'b> { setGlobalArgument(len{i}, {global_idx});\n\ ", i = i, func = func, arg = name, global_idx = global_idx)); if arg.is_by_ref() { + if arg.is_mut_ref() { + let get = self.cx.memview_function(kind); + self.finally(&format!("\ + {arg}.set({get}().subarray(\ + ptr{i} / {size}, \ + ptr{i} / {size} + len{i}\ + ));\n\ + ", i = i, arg = name, get = get, size = kind.size())); + } self.finally(&format!("\ wasm.__wbindgen_free(ptr{i}, len{i} * {size});\n\ ", i = i, size = kind.size())); diff --git a/crates/cli-support/src/js/mod.rs b/crates/cli-support/src/js/mod.rs index b6dba1c6b..f9c84db68 100644 --- a/crates/cli-support/src/js/mod.rs +++ b/crates/cli-support/src/js/mod.rs @@ -926,7 +926,7 @@ impl<'a> Context<'a> { self.expose_uint8_memory(); self.global(&format!(" function getStringFromWasm(ptr, len) {{ - return cachedDecoder.decode(getUint8Memory().slice(ptr, ptr + len)); + return cachedDecoder.decode(getUint8Memory().subarray(ptr, ptr + len)); }} ")); } @@ -940,7 +940,7 @@ impl<'a> Context<'a> { self.global(&format!(" function getArrayJsValueFromWasm(ptr, len) {{ const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); + const slice = mem.subarray(ptr / 4, ptr / 4 + len); const result = []; for (let i = 0; i < slice.length; i++) {{ result.push(takeObject(slice[i])) @@ -951,158 +951,154 @@ impl<'a> Context<'a> { } fn expose_get_array_i8_from_wasm(&mut self) { - self.expose_uint8_memory(); - if !self.exposed_globals.insert("get_array_i8_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI8FromWasm(ptr, len) {{ - const mem = getUint8Memory(); - const slice = mem.slice(ptr, ptr + len); - return new Int8Array(slice); - }} - ")); + self.expose_int8_memory(); + self.arrayget("getArrayI8FromWasm", "getInt8Memory", 1); } fn expose_get_array_u8_from_wasm(&mut self) { self.expose_uint8_memory(); - if !self.exposed_globals.insert("get_array_u8_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU8FromWasm(ptr, len) {{ - const mem = getUint8Memory(); - const slice = mem.slice(ptr, ptr + len); - return new Uint8Array(slice); - }} - ")); + self.arrayget("getArrayU8FromWasm", "getUint8Memory", 1); } fn expose_get_array_i16_from_wasm(&mut self) { - self.expose_uint16_memory(); - if !self.exposed_globals.insert("get_array_i16_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI16FromWasm(ptr, len) {{ - const mem = getUint16Memory(); - const slice = mem.slice(ptr / 2, ptr / 2 + len); - return new Int16Array(slice); - }} - ")); + self.expose_int16_memory(); + self.arrayget("getArrayI16FromWasm", "getInt16Memory", 2); } fn expose_get_array_u16_from_wasm(&mut self) { self.expose_uint16_memory(); - if !self.exposed_globals.insert("get_array_u16_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU16FromWasm(ptr, len) {{ - const mem = getUint16Memory(); - const slice = mem.slice(ptr / 2, ptr / 2 + len); - return new Uint16Array(slice); - }} - ")); + self.arrayget("getArrayU16FromWasm", "getUint16Memory", 2); } fn expose_get_array_i32_from_wasm(&mut self) { - self.expose_uint32_memory(); - if !self.exposed_globals.insert("get_array_i32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI32FromWasm(ptr, len) {{ - const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Int32Array(slice); - }} - ")); + self.expose_int32_memory(); + self.arrayget("getArrayI32FromWasm", "getInt32Memory", 4); } fn expose_get_array_u32_from_wasm(&mut self) { self.expose_uint32_memory(); - if !self.exposed_globals.insert("get_array_u32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU32FromWasm(ptr, len) {{ - const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Uint32Array(slice); - }} - ")); + self.arrayget("getArrayU32FromWasm", "getUint32Memory", 4); } fn expose_get_array_f32_from_wasm(&mut self) { - if !self.exposed_globals.insert("get_array_f32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayF32FromWasm(ptr, len) {{ - const mem = new Float32Array(wasm.memory.buffer); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Float32Array(slice); - }} - ")); + self.expose_f32_memory(); + self.arrayget("getArrayF32FromWasm", "getFloat32Memory", 4); } fn expose_get_array_f64_from_wasm(&mut self) { - if !self.exposed_globals.insert("get_array_f64_from_wasm") { + self.expose_f64_memory(); + self.arrayget("getArrayF64FromWasm", "getFloat64Memory", 8); + } + + fn arrayget(&mut self, name: &'static str, mem: &'static str, size: usize) { + if !self.exposed_globals.insert(name) { return; } self.global(&format!(" - function getArrayF64FromWasm(ptr, len) {{ - const mem = new Float64Array(wasm.memory.buffer); - const slice = mem.slice(ptr / 8, ptr / 8 + len); - return new Float64Array(slice); + function {name}(ptr, len) {{ + return {mem}().subarray(ptr / {size}, ptr / {size} + len); }} - ")); + ", + name = name, + mem = mem, + size = size, + )); + } + + fn expose_int8_memory(&mut self) { + self.memview("getInt8Memory", "Int8Array"); } fn expose_uint8_memory(&mut self) { - if !self.exposed_globals.insert("uint8_memory") { - return; - } - self.global(&format!(" - let cachedUint8Memory = null; - function getUint8Memory() {{ - if (cachedUint8Memory === null || - cachedUint8Memory.buffer !== wasm.memory.buffer) - cachedUint8Memory = new Uint8Array(wasm.memory.buffer); - return cachedUint8Memory; - }} - ")); + self.memview("getUint8Memory", "Uint8Array"); + } + + fn expose_int16_memory(&mut self) { + self.memview("getInt16Memory", "Int16Array"); } fn expose_uint16_memory(&mut self) { - if !self.exposed_globals.insert("uint16_memory") { - return; - } - self.global(&format!(" - let cachedUint16Memory = null; - function getUint16Memory() {{ - if (cachedUint16Memory === null || - cachedUint16Memory.buffer !== wasm.memory.buffer) - cachedUint16Memory = new Uint16Array(wasm.memory.buffer); - return cachedUint16Memory; - }} - ")); + self.memview("getUint16Memory", "Uint16Array"); + } + + fn expose_int32_memory(&mut self) { + self.memview("getInt32Memory", "Int32Array"); } fn expose_uint32_memory(&mut self) { - if !self.exposed_globals.insert("uint32_memory") { + self.memview("getUint32Memory", "Uint32Array"); + } + + fn expose_f32_memory(&mut self) { + self.memview("getFloat32Memory", "Float32Array"); + } + + fn expose_f64_memory(&mut self) { + self.memview("getFloat64Memory", "Float64Array"); + } + + fn memview_function(&mut self, t: VectorKind) -> &'static str { + match t { + VectorKind::String => { + self.expose_uint8_memory(); + "getUint8Memory" + } + VectorKind::I8 => { + self.expose_int8_memory(); + "getInt8Memory" + } + VectorKind::U8 => { + self.expose_uint8_memory(); + "getUint8Memory" + } + VectorKind::I16 => { + self.expose_int16_memory(); + "getInt16Memory" + } + VectorKind::U16 => { + self.expose_uint16_memory(); + "getUint16Memory" + } + VectorKind::I32 => { + self.expose_int32_memory(); + "getInt32Memory" + } + VectorKind::U32 => { + self.expose_uint32_memory(); + "getUint32Memory" + } + VectorKind::F32 => { + self.expose_f32_memory(); + "getFloat32Memory" + } + VectorKind::F64 => { + self.expose_f64_memory(); + "getFloat64Memory" + } + VectorKind::Anyref => { + self.expose_uint32_memory(); + "getUint32Memory" + } + } + } + + + fn memview(&mut self, name: &'static str, js: &str) { + if !self.exposed_globals.insert(name) { return; } self.global(&format!(" - let cachedUint32Memory = null; - function getUint32Memory() {{ - if (cachedUint32Memory === null || - cachedUint32Memory.buffer !== wasm.memory.buffer) - cachedUint32Memory = new Uint32Array(wasm.memory.buffer); - return cachedUint32Memory; + let cache{name} = null; + function {name}() {{ + if (cache{name} === null || + cache{name}.buffer !== wasm.memory.buffer) + cache{name} = new {js}(wasm.memory.buffer); + return cache{name}; }} - ")); + ", + name = name, + js = js, + )); } fn expose_assert_class(&mut self) { @@ -1268,6 +1264,60 @@ impl<'a> Context<'a> { } } + fn expose_commit_slice_to_wasm(&mut self, ty: VectorKind) + -> Result<&'static str, Error> + { + let gen = |me: &mut Context, name: &'static str, size: usize, get: &str| { + me.global(&format!(" + function {name}(ptr, view) {{ + if (view.buffer !== wasm.memory.buffer) + {get}().set(view, ptr / {size}); + }} + ", + name = name, + size = size, + get = get, + )); + name + }; + match ty { + VectorKind::String => bail!("strings cannot be used with mutable slices"), + VectorKind::Anyref => bail!("js values cannot be used with mutable slices"), + VectorKind::I8 => { + self.expose_int8_memory(); + Ok(gen(self, "commitI8ToWasm", 1, "getInt8Memory")) + } + VectorKind::U8 => { + self.expose_uint8_memory(); + Ok(gen(self, "commitU8ToWasm", 1, "getUint8Memory")) + } + VectorKind::I16 => { + self.expose_int16_memory(); + Ok(gen(self, "commitI16ToWasm", 2, "getInt16Memory")) + } + VectorKind::U16 => { + self.expose_uint16_memory(); + Ok(gen(self, "commitU16ToWasm", 2, "getUint16Memory")) + } + VectorKind::I32 => { + self.expose_int32_memory(); + Ok(gen(self, "commitI32ToWasm", 4, "getInt32Memory")) + } + VectorKind::U32 => { + self.expose_uint32_memory(); + Ok(gen(self, "commitU32ToWasm", 4, "getUint32Memory")) + } + VectorKind::F32 => { + self.expose_f32_memory(); + Ok(gen(self, "commitF32ToWasm", 4, "getFloat32Memory")) + } + VectorKind::F64 => { + self.expose_f64_memory(); + Ok(gen(self, "commitF64ToWasm", 8, "getFloat64Memory")) + } + } + } + fn expose_set_global_argument(&mut self) -> Result<(), Error> { if !self.exposed_globals.insert("set_global_argument") { return Ok(()); diff --git a/crates/cli-support/src/js/rust2js.rs b/crates/cli-support/src/js/rust2js.rs index bd46778de..d47369e48 100644 --- a/crates/cli-support/src/js/rust2js.rs +++ b/crates/cli-support/src/js/rust2js.rs @@ -92,6 +92,9 @@ impl<'a, 'b> Rust2Js<'a, 'b> { wasm.__wbindgen_free(arg{0}, len{0} * {size});\ ", i, size = ty.size())); self.cx.require_internal_export("__wbindgen_free")?; + } else if arg.is_mut_ref() { + let f = self.cx.expose_commit_slice_to_wasm(ty)?; + self.finally(&format!("{}(arg{i}, v{i});", f, i = i)); } self.js_arguments.push(format!("v{}", i)); return Ok(()) diff --git a/src/convert.rs b/src/convert.rs index 80d49f249..82fb908af 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -170,6 +170,14 @@ macro_rules! vectors { } } + impl<'a> IntoWasmAbi for &'a mut [$t] { + type Abi = u32; + + fn into_abi(self, extra: &mut Stack) -> u32 { + (&*self).into_abi(extra) + } + } + impl RefFromWasmAbi for [$t] { type Abi = u32; type Anchor = &'static [$t]; @@ -181,6 +189,18 @@ macro_rules! vectors { ) } } + + impl RefMutFromWasmAbi for [$t] { + type Abi = u32; + type Anchor = &'static mut [$t]; + + unsafe fn ref_mut_from_abi(js: u32, extra: &mut Stack) -> &'static mut [$t] { + slice::from_raw_parts_mut( + <*mut $t>::from_abi(js, extra), + extra.pop() as usize, + ) + } + } )*) } diff --git a/tests/all/slice.rs b/tests/all/slice.rs index f3bfab675..13a59bb25 100644 --- a/tests/all/slice.rs +++ b/tests/all/slice.rs @@ -278,3 +278,225 @@ fn pass_array_works() { "#) .test(); } + +#[test] +fn import_mut() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($(($rust:ident, $js:ident, $i:ident))*) => ( + $( + #[wasm_bindgen(module = "./test")] + extern { + fn $js(a: &mut [$i]); + } + + fn $rust() { + let mut buf = [ + 1 as $i, + 2 as $i, + 3 as $i, + ]; + $js(&mut buf); + assert_eq!(buf[0], 4 as $i); + assert_eq!(buf[1], 5 as $i); + assert_eq!(buf[2], 3 as $i); + } + )* + + #[wasm_bindgen] + pub fn run() { + $($rust();)* + } + ) + } + + + doit! { + (rust_i8, js_i8, i8) + (rust_u8, js_u8, u8) + (rust_i16, js_i16, i16) + (rust_u16, js_u16, u16) + (rust_i32, js_i32, i32) + (rust_u32, js_u32, u32) + (rust_f32, js_f32, f32) + (rust_f64, js_f64, f64) + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function foo(a: any) { + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 1); + assert.strictEqual(a[1], 2); + a[0] = 4; + a[1] = 5; + } + + export const js_i8 = foo; + export const js_u8 = foo; + export const js_i16 = foo; + export const js_u16 = foo; + export const js_i32 = foo; + export const js_u32 = foo; + export const js_f32 = foo; + export const js_f64 = foo; + + export function test() { + wasm.run(); + } + "#) + .test(); +} + +#[test] +fn import_mut_realloc_middle() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($(($rust:ident, $js:ident, $i:ident))*) => ( + $( + #[wasm_bindgen(module = "./test")] + extern { + fn $js(a: &mut [$i]); + } + + fn $rust() { + let mut buf = [ + 1 as $i, + 2 as $i, + 3 as $i, + ]; + $js(&mut buf); + assert_eq!(buf[0], 4 as $i); + assert_eq!(buf[1], 5 as $i); + assert_eq!(buf[2], 3 as $i); + } + )* + + #[wasm_bindgen] + pub fn run() { + $($rust();)* + } + + #[wasm_bindgen] + pub fn allocate() { + std::mem::forget(Vec::::with_capacity(128 * 1024)); + } + ) + } + + + doit! { + (rust_i8, js_i8, i8) + (rust_u8, js_u8, u8) + (rust_i16, js_i16, i16) + (rust_u16, js_u16, u16) + (rust_i32, js_i32, i32) + (rust_u32, js_u32, u32) + (rust_f32, js_f32, f32) + (rust_f64, js_f64, f64) + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function foo(a: any) { + wasm.allocate(); + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 1); + assert.strictEqual(a[1], 2); + a[0] = 4; + a[1] = 5; + } + + export const js_i8 = foo; + export const js_u8 = foo; + export const js_i16 = foo; + export const js_u16 = foo; + export const js_i32 = foo; + export const js_u32 = foo; + export const js_f32 = foo; + export const js_f64 = foo; + + export function test() { + wasm.run(); + } + "#) + .test(); +} + +#[test] +fn export_mut() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($($i:ident)*) => ($( + #[wasm_bindgen] + pub fn $i(a: &mut [$i]) { + assert_eq!(a.len(), 3); + assert_eq!(a[0], 1 as $i); + assert_eq!(a[1], 2 as $i); + assert_eq!(a[2], 3 as $i); + a[0] = 4 as $i; + a[1] = 5 as $i; + } + )*) + } + + + doit! { i8 u8 i16 u16 i32 u32 f32 f64 } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function run(a: any, rust: any) { + assert.strictEqual(a.length, 3); + a[0] = 1; + a[1] = 2; + a[2] = 3; + console.log(a); + rust(a); + console.log(a); + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 4); + assert.strictEqual(a[1], 5); + assert.strictEqual(a[2], 3); + } + + export function test() { + run(new Int8Array(3), wasm.i8); + run(new Uint8Array(3), wasm.u8); + run(new Int16Array(3), wasm.i16); + run(new Uint16Array(3), wasm.u16); + run(new Int32Array(3), wasm.i32); + run(new Uint32Array(3), wasm.u32); + run(new Float32Array(3), wasm.f32); + run(new Float64Array(3), wasm.f64); + } + "#) + .test(); +} +