From 0566a97485f8168da5fa80a79d928fa069e8d251 Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Tue, 1 May 2018 10:06:35 -0700 Subject: [PATCH] Add support for mutable slices This commit adds support for mutable slices to pass the boundary between JS and Rust. While mutable slices cannot be used as return values they can be listed as arguments to both exported functions as well as imported functions. When passing a mutable slice into a Rust function (aka having it as an argument to an exported Rust function) then like before with a normal slice it's copied into the wasm memory. Afterwards, however, the updates in the wasm memory will be reflected back into the original slice. This does require a lot of copying and probably isn't the most efficient, but it should at least work for the time being. The real nifty part happens when Rust passes a mutable slice out to JS. When doing this it's a very cheap operation that just gets a subarray of the main wasm memory. Now the wasm memory's buffer can change over time which can produce surprising results where memory is modified in JS but it may not be reflected back into Rust. To accomodate this when a JS imported function returns any updates to the buffer are copied back to Rust if Rust's memory buffer has changed in the meantime. Along the way this fixes usage of `slice` to instead use `subarray` as that's what we really want, no copying. All methods have been updated to use `subarray` accessors instead of `slice` or constructing new arrays. Closes #53 --- crates/cli-support/src/descriptor.rs | 13 ++ crates/cli-support/src/js/js2rust.rs | 9 + crates/cli-support/src/js/mod.rs | 276 ++++++++++++++++----------- crates/cli-support/src/js/rust2js.rs | 3 + src/convert.rs | 20 ++ tests/all/slice.rs | 222 +++++++++++++++++++++ 6 files changed, 430 insertions(+), 113 deletions(-) diff --git a/crates/cli-support/src/descriptor.rs b/crates/cli-support/src/descriptor.rs index bd6bc25f7..b98a2ec5e 100644 --- a/crates/cli-support/src/descriptor.rs +++ b/crates/cli-support/src/descriptor.rs @@ -187,6 +187,12 @@ impl Descriptor { _ => return None, } } + Descriptor::RefMut(ref d) => { + match **d { + Descriptor::Slice(ref d) => &**d, + _ => return None, + } + } _ => return None, }; match *inner { @@ -234,6 +240,13 @@ impl Descriptor { _ => false, } } + + pub fn is_mut_ref(&self) -> bool { + match *self { + Descriptor::RefMut(_) => true, + _ => false, + } + } } fn get(a: &mut &[u32]) -> u32 { diff --git a/crates/cli-support/src/js/js2rust.rs b/crates/cli-support/src/js/js2rust.rs index d4f9d7391..d300e79e3 100644 --- a/crates/cli-support/src/js/js2rust.rs +++ b/crates/cli-support/src/js/js2rust.rs @@ -118,6 +118,15 @@ impl<'a, 'b> Js2Rust<'a, 'b> { setGlobalArgument(len{i}, {global_idx});\n\ ", i = i, func = func, arg = name, global_idx = global_idx)); if arg.is_by_ref() { + if arg.is_mut_ref() { + let get = self.cx.memview_function(kind); + self.finally(&format!("\ + {arg}.set({get}().subarray(\ + ptr{i} / {size}, \ + ptr{i} / {size} + len{i}\ + ));\n\ + ", i = i, arg = name, get = get, size = kind.size())); + } self.finally(&format!("\ wasm.__wbindgen_free(ptr{i}, len{i} * {size});\n\ ", i = i, size = kind.size())); diff --git a/crates/cli-support/src/js/mod.rs b/crates/cli-support/src/js/mod.rs index b6dba1c6b..f9c84db68 100644 --- a/crates/cli-support/src/js/mod.rs +++ b/crates/cli-support/src/js/mod.rs @@ -926,7 +926,7 @@ impl<'a> Context<'a> { self.expose_uint8_memory(); self.global(&format!(" function getStringFromWasm(ptr, len) {{ - return cachedDecoder.decode(getUint8Memory().slice(ptr, ptr + len)); + return cachedDecoder.decode(getUint8Memory().subarray(ptr, ptr + len)); }} ")); } @@ -940,7 +940,7 @@ impl<'a> Context<'a> { self.global(&format!(" function getArrayJsValueFromWasm(ptr, len) {{ const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); + const slice = mem.subarray(ptr / 4, ptr / 4 + len); const result = []; for (let i = 0; i < slice.length; i++) {{ result.push(takeObject(slice[i])) @@ -951,158 +951,154 @@ impl<'a> Context<'a> { } fn expose_get_array_i8_from_wasm(&mut self) { - self.expose_uint8_memory(); - if !self.exposed_globals.insert("get_array_i8_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI8FromWasm(ptr, len) {{ - const mem = getUint8Memory(); - const slice = mem.slice(ptr, ptr + len); - return new Int8Array(slice); - }} - ")); + self.expose_int8_memory(); + self.arrayget("getArrayI8FromWasm", "getInt8Memory", 1); } fn expose_get_array_u8_from_wasm(&mut self) { self.expose_uint8_memory(); - if !self.exposed_globals.insert("get_array_u8_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU8FromWasm(ptr, len) {{ - const mem = getUint8Memory(); - const slice = mem.slice(ptr, ptr + len); - return new Uint8Array(slice); - }} - ")); + self.arrayget("getArrayU8FromWasm", "getUint8Memory", 1); } fn expose_get_array_i16_from_wasm(&mut self) { - self.expose_uint16_memory(); - if !self.exposed_globals.insert("get_array_i16_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI16FromWasm(ptr, len) {{ - const mem = getUint16Memory(); - const slice = mem.slice(ptr / 2, ptr / 2 + len); - return new Int16Array(slice); - }} - ")); + self.expose_int16_memory(); + self.arrayget("getArrayI16FromWasm", "getInt16Memory", 2); } fn expose_get_array_u16_from_wasm(&mut self) { self.expose_uint16_memory(); - if !self.exposed_globals.insert("get_array_u16_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU16FromWasm(ptr, len) {{ - const mem = getUint16Memory(); - const slice = mem.slice(ptr / 2, ptr / 2 + len); - return new Uint16Array(slice); - }} - ")); + self.arrayget("getArrayU16FromWasm", "getUint16Memory", 2); } fn expose_get_array_i32_from_wasm(&mut self) { - self.expose_uint32_memory(); - if !self.exposed_globals.insert("get_array_i32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayI32FromWasm(ptr, len) {{ - const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Int32Array(slice); - }} - ")); + self.expose_int32_memory(); + self.arrayget("getArrayI32FromWasm", "getInt32Memory", 4); } fn expose_get_array_u32_from_wasm(&mut self) { self.expose_uint32_memory(); - if !self.exposed_globals.insert("get_array_u32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayU32FromWasm(ptr, len) {{ - const mem = getUint32Memory(); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Uint32Array(slice); - }} - ")); + self.arrayget("getArrayU32FromWasm", "getUint32Memory", 4); } fn expose_get_array_f32_from_wasm(&mut self) { - if !self.exposed_globals.insert("get_array_f32_from_wasm") { - return; - } - self.global(&format!(" - function getArrayF32FromWasm(ptr, len) {{ - const mem = new Float32Array(wasm.memory.buffer); - const slice = mem.slice(ptr / 4, ptr / 4 + len); - return new Float32Array(slice); - }} - ")); + self.expose_f32_memory(); + self.arrayget("getArrayF32FromWasm", "getFloat32Memory", 4); } fn expose_get_array_f64_from_wasm(&mut self) { - if !self.exposed_globals.insert("get_array_f64_from_wasm") { + self.expose_f64_memory(); + self.arrayget("getArrayF64FromWasm", "getFloat64Memory", 8); + } + + fn arrayget(&mut self, name: &'static str, mem: &'static str, size: usize) { + if !self.exposed_globals.insert(name) { return; } self.global(&format!(" - function getArrayF64FromWasm(ptr, len) {{ - const mem = new Float64Array(wasm.memory.buffer); - const slice = mem.slice(ptr / 8, ptr / 8 + len); - return new Float64Array(slice); + function {name}(ptr, len) {{ + return {mem}().subarray(ptr / {size}, ptr / {size} + len); }} - ")); + ", + name = name, + mem = mem, + size = size, + )); + } + + fn expose_int8_memory(&mut self) { + self.memview("getInt8Memory", "Int8Array"); } fn expose_uint8_memory(&mut self) { - if !self.exposed_globals.insert("uint8_memory") { - return; - } - self.global(&format!(" - let cachedUint8Memory = null; - function getUint8Memory() {{ - if (cachedUint8Memory === null || - cachedUint8Memory.buffer !== wasm.memory.buffer) - cachedUint8Memory = new Uint8Array(wasm.memory.buffer); - return cachedUint8Memory; - }} - ")); + self.memview("getUint8Memory", "Uint8Array"); + } + + fn expose_int16_memory(&mut self) { + self.memview("getInt16Memory", "Int16Array"); } fn expose_uint16_memory(&mut self) { - if !self.exposed_globals.insert("uint16_memory") { - return; - } - self.global(&format!(" - let cachedUint16Memory = null; - function getUint16Memory() {{ - if (cachedUint16Memory === null || - cachedUint16Memory.buffer !== wasm.memory.buffer) - cachedUint16Memory = new Uint16Array(wasm.memory.buffer); - return cachedUint16Memory; - }} - ")); + self.memview("getUint16Memory", "Uint16Array"); + } + + fn expose_int32_memory(&mut self) { + self.memview("getInt32Memory", "Int32Array"); } fn expose_uint32_memory(&mut self) { - if !self.exposed_globals.insert("uint32_memory") { + self.memview("getUint32Memory", "Uint32Array"); + } + + fn expose_f32_memory(&mut self) { + self.memview("getFloat32Memory", "Float32Array"); + } + + fn expose_f64_memory(&mut self) { + self.memview("getFloat64Memory", "Float64Array"); + } + + fn memview_function(&mut self, t: VectorKind) -> &'static str { + match t { + VectorKind::String => { + self.expose_uint8_memory(); + "getUint8Memory" + } + VectorKind::I8 => { + self.expose_int8_memory(); + "getInt8Memory" + } + VectorKind::U8 => { + self.expose_uint8_memory(); + "getUint8Memory" + } + VectorKind::I16 => { + self.expose_int16_memory(); + "getInt16Memory" + } + VectorKind::U16 => { + self.expose_uint16_memory(); + "getUint16Memory" + } + VectorKind::I32 => { + self.expose_int32_memory(); + "getInt32Memory" + } + VectorKind::U32 => { + self.expose_uint32_memory(); + "getUint32Memory" + } + VectorKind::F32 => { + self.expose_f32_memory(); + "getFloat32Memory" + } + VectorKind::F64 => { + self.expose_f64_memory(); + "getFloat64Memory" + } + VectorKind::Anyref => { + self.expose_uint32_memory(); + "getUint32Memory" + } + } + } + + + fn memview(&mut self, name: &'static str, js: &str) { + if !self.exposed_globals.insert(name) { return; } self.global(&format!(" - let cachedUint32Memory = null; - function getUint32Memory() {{ - if (cachedUint32Memory === null || - cachedUint32Memory.buffer !== wasm.memory.buffer) - cachedUint32Memory = new Uint32Array(wasm.memory.buffer); - return cachedUint32Memory; + let cache{name} = null; + function {name}() {{ + if (cache{name} === null || + cache{name}.buffer !== wasm.memory.buffer) + cache{name} = new {js}(wasm.memory.buffer); + return cache{name}; }} - ")); + ", + name = name, + js = js, + )); } fn expose_assert_class(&mut self) { @@ -1268,6 +1264,60 @@ impl<'a> Context<'a> { } } + fn expose_commit_slice_to_wasm(&mut self, ty: VectorKind) + -> Result<&'static str, Error> + { + let gen = |me: &mut Context, name: &'static str, size: usize, get: &str| { + me.global(&format!(" + function {name}(ptr, view) {{ + if (view.buffer !== wasm.memory.buffer) + {get}().set(view, ptr / {size}); + }} + ", + name = name, + size = size, + get = get, + )); + name + }; + match ty { + VectorKind::String => bail!("strings cannot be used with mutable slices"), + VectorKind::Anyref => bail!("js values cannot be used with mutable slices"), + VectorKind::I8 => { + self.expose_int8_memory(); + Ok(gen(self, "commitI8ToWasm", 1, "getInt8Memory")) + } + VectorKind::U8 => { + self.expose_uint8_memory(); + Ok(gen(self, "commitU8ToWasm", 1, "getUint8Memory")) + } + VectorKind::I16 => { + self.expose_int16_memory(); + Ok(gen(self, "commitI16ToWasm", 2, "getInt16Memory")) + } + VectorKind::U16 => { + self.expose_uint16_memory(); + Ok(gen(self, "commitU16ToWasm", 2, "getUint16Memory")) + } + VectorKind::I32 => { + self.expose_int32_memory(); + Ok(gen(self, "commitI32ToWasm", 4, "getInt32Memory")) + } + VectorKind::U32 => { + self.expose_uint32_memory(); + Ok(gen(self, "commitU32ToWasm", 4, "getUint32Memory")) + } + VectorKind::F32 => { + self.expose_f32_memory(); + Ok(gen(self, "commitF32ToWasm", 4, "getFloat32Memory")) + } + VectorKind::F64 => { + self.expose_f64_memory(); + Ok(gen(self, "commitF64ToWasm", 8, "getFloat64Memory")) + } + } + } + fn expose_set_global_argument(&mut self) -> Result<(), Error> { if !self.exposed_globals.insert("set_global_argument") { return Ok(()); diff --git a/crates/cli-support/src/js/rust2js.rs b/crates/cli-support/src/js/rust2js.rs index bd46778de..d47369e48 100644 --- a/crates/cli-support/src/js/rust2js.rs +++ b/crates/cli-support/src/js/rust2js.rs @@ -92,6 +92,9 @@ impl<'a, 'b> Rust2Js<'a, 'b> { wasm.__wbindgen_free(arg{0}, len{0} * {size});\ ", i, size = ty.size())); self.cx.require_internal_export("__wbindgen_free")?; + } else if arg.is_mut_ref() { + let f = self.cx.expose_commit_slice_to_wasm(ty)?; + self.finally(&format!("{}(arg{i}, v{i});", f, i = i)); } self.js_arguments.push(format!("v{}", i)); return Ok(()) diff --git a/src/convert.rs b/src/convert.rs index 80d49f249..82fb908af 100644 --- a/src/convert.rs +++ b/src/convert.rs @@ -170,6 +170,14 @@ macro_rules! vectors { } } + impl<'a> IntoWasmAbi for &'a mut [$t] { + type Abi = u32; + + fn into_abi(self, extra: &mut Stack) -> u32 { + (&*self).into_abi(extra) + } + } + impl RefFromWasmAbi for [$t] { type Abi = u32; type Anchor = &'static [$t]; @@ -181,6 +189,18 @@ macro_rules! vectors { ) } } + + impl RefMutFromWasmAbi for [$t] { + type Abi = u32; + type Anchor = &'static mut [$t]; + + unsafe fn ref_mut_from_abi(js: u32, extra: &mut Stack) -> &'static mut [$t] { + slice::from_raw_parts_mut( + <*mut $t>::from_abi(js, extra), + extra.pop() as usize, + ) + } + } )*) } diff --git a/tests/all/slice.rs b/tests/all/slice.rs index f3bfab675..13a59bb25 100644 --- a/tests/all/slice.rs +++ b/tests/all/slice.rs @@ -278,3 +278,225 @@ fn pass_array_works() { "#) .test(); } + +#[test] +fn import_mut() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($(($rust:ident, $js:ident, $i:ident))*) => ( + $( + #[wasm_bindgen(module = "./test")] + extern { + fn $js(a: &mut [$i]); + } + + fn $rust() { + let mut buf = [ + 1 as $i, + 2 as $i, + 3 as $i, + ]; + $js(&mut buf); + assert_eq!(buf[0], 4 as $i); + assert_eq!(buf[1], 5 as $i); + assert_eq!(buf[2], 3 as $i); + } + )* + + #[wasm_bindgen] + pub fn run() { + $($rust();)* + } + ) + } + + + doit! { + (rust_i8, js_i8, i8) + (rust_u8, js_u8, u8) + (rust_i16, js_i16, i16) + (rust_u16, js_u16, u16) + (rust_i32, js_i32, i32) + (rust_u32, js_u32, u32) + (rust_f32, js_f32, f32) + (rust_f64, js_f64, f64) + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function foo(a: any) { + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 1); + assert.strictEqual(a[1], 2); + a[0] = 4; + a[1] = 5; + } + + export const js_i8 = foo; + export const js_u8 = foo; + export const js_i16 = foo; + export const js_u16 = foo; + export const js_i32 = foo; + export const js_u32 = foo; + export const js_f32 = foo; + export const js_f64 = foo; + + export function test() { + wasm.run(); + } + "#) + .test(); +} + +#[test] +fn import_mut_realloc_middle() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section, wasm_import_module)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($(($rust:ident, $js:ident, $i:ident))*) => ( + $( + #[wasm_bindgen(module = "./test")] + extern { + fn $js(a: &mut [$i]); + } + + fn $rust() { + let mut buf = [ + 1 as $i, + 2 as $i, + 3 as $i, + ]; + $js(&mut buf); + assert_eq!(buf[0], 4 as $i); + assert_eq!(buf[1], 5 as $i); + assert_eq!(buf[2], 3 as $i); + } + )* + + #[wasm_bindgen] + pub fn run() { + $($rust();)* + } + + #[wasm_bindgen] + pub fn allocate() { + std::mem::forget(Vec::::with_capacity(128 * 1024)); + } + ) + } + + + doit! { + (rust_i8, js_i8, i8) + (rust_u8, js_u8, u8) + (rust_i16, js_i16, i16) + (rust_u16, js_u16, u16) + (rust_i32, js_i32, i32) + (rust_u32, js_u32, u32) + (rust_f32, js_f32, f32) + (rust_f64, js_f64, f64) + } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function foo(a: any) { + wasm.allocate(); + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 1); + assert.strictEqual(a[1], 2); + a[0] = 4; + a[1] = 5; + } + + export const js_i8 = foo; + export const js_u8 = foo; + export const js_i16 = foo; + export const js_u16 = foo; + export const js_i32 = foo; + export const js_u32 = foo; + export const js_f32 = foo; + export const js_f64 = foo; + + export function test() { + wasm.run(); + } + "#) + .test(); +} + +#[test] +fn export_mut() { + project() + .file("src/lib.rs", r#" + #![feature(proc_macro, wasm_custom_section)] + + extern crate wasm_bindgen; + + use wasm_bindgen::prelude::*; + + macro_rules! doit { + ($($i:ident)*) => ($( + #[wasm_bindgen] + pub fn $i(a: &mut [$i]) { + assert_eq!(a.len(), 3); + assert_eq!(a[0], 1 as $i); + assert_eq!(a[1], 2 as $i); + assert_eq!(a[2], 3 as $i); + a[0] = 4 as $i; + a[1] = 5 as $i; + } + )*) + } + + + doit! { i8 u8 i16 u16 i32 u32 f32 f64 } + "#) + .file("test.ts", r#" + import * as assert from "assert"; + import * as wasm from "./out"; + + function run(a: any, rust: any) { + assert.strictEqual(a.length, 3); + a[0] = 1; + a[1] = 2; + a[2] = 3; + console.log(a); + rust(a); + console.log(a); + assert.strictEqual(a.length, 3); + assert.strictEqual(a[0], 4); + assert.strictEqual(a[1], 5); + assert.strictEqual(a[2], 3); + } + + export function test() { + run(new Int8Array(3), wasm.i8); + run(new Uint8Array(3), wasm.u8); + run(new Int16Array(3), wasm.i16); + run(new Uint16Array(3), wasm.u16); + run(new Int32Array(3), wasm.i32); + run(new Uint32Array(3), wasm.u32); + run(new Float32Array(3), wasm.f32); + run(new Float64Array(3), wasm.f64); + } + "#) + .test(); +} +