wasm_interp: remove CallStack, create Frame, & share value storage for stack and locals

This allows us to do calls without moving arguments from one place to another
This commit is contained in:
Brian Carroll 2022-12-16 21:38:37 +00:00
parent 8b8e385cde
commit caedb9060b
No known key found for this signature in database
GPG Key ID: 5C7B2EC4101703C0
7 changed files with 487 additions and 570 deletions

View File

@ -1,344 +0,0 @@
use bumpalo::{collections::Vec, Bump};
use roc_wasm_module::opcodes::OpCode;
use roc_wasm_module::sections::ImportDesc;
use roc_wasm_module::{parse::Parse, Value, ValueType, WasmModule};
use std::fmt::{self, Write};
use std::iter::repeat;
use crate::{pc_to_fn_index, Error, ValueStack};
/// Struct-of-Arrays storage for the call stack.
/// Type info is packed to avoid wasting space on padding.
/// However we store 64 bits for every local, even 32-bit values, for easy random access.
#[derive(Debug)]
pub struct CallStack<'a> {
/// return addresses and nested block depths (one entry per frame)
return_addrs_and_block_depths: Vec<'a, (u32, u32)>,
/// frame offsets into the `locals`, `is_float`, and `is_64` vectors (one entry per frame)
frame_offsets: Vec<'a, u32>,
/// base size of the value stack before executing (one entry per frame)
value_stack_bases: Vec<'a, u32>,
/// local variables (one entry per local)
locals: Vec<'a, Value>,
}
impl<'a> CallStack<'a> {
pub fn new(arena: &'a Bump) -> Self {
CallStack {
return_addrs_and_block_depths: Vec::with_capacity_in(256, arena),
frame_offsets: Vec::with_capacity_in(256, arena),
value_stack_bases: Vec::with_capacity_in(256, arena),
locals: Vec::with_capacity_in(16 * 256, arena),
}
}
/// On entering a Wasm call, save the return address, and make space for locals
pub(crate) fn push_frame(
&mut self,
return_addr: u32,
return_block_depth: u32,
arg_type_bytes: &[u8],
value_stack: &mut ValueStack<'a>,
code_bytes: &[u8],
pc: &mut usize,
) -> Result<(), crate::Error> {
self.return_addrs_and_block_depths
.push((return_addr, return_block_depth));
let frame_offset = self.locals.len();
self.frame_offsets.push(frame_offset as u32);
// Make space for arguments
let n_args = arg_type_bytes.len();
self.locals.extend(repeat(Value::I64(0)).take(n_args));
// Pop arguments off the value stack and into locals
for (i, type_byte) in arg_type_bytes.iter().copied().enumerate().rev() {
let arg = value_stack.pop();
let ty = ValueType::from(arg);
let expected_type = ValueType::from(type_byte);
if ty != expected_type {
return Err(Error::ValueStackType(expected_type, ty));
}
self.set_local_help(i as u32, arg);
}
self.value_stack_bases.push(value_stack.depth() as u32);
// Parse local variable declarations in the function header. They're grouped by type.
let local_group_count = u32::parse((), code_bytes, pc).unwrap();
for _ in 0..local_group_count {
let (group_size, ty) = <(u32, ValueType)>::parse((), code_bytes, pc).unwrap();
let n = group_size as usize;
let zero = match ty {
ValueType::I32 => Value::I32(0),
ValueType::I64 => Value::I64(0),
ValueType::F32 => Value::F32(0.0),
ValueType::F64 => Value::F64(0.0),
};
self.locals.extend(repeat(zero).take(n));
}
Ok(())
}
/// On returning from a Wasm call, drop its locals and retrieve the return address
pub fn pop_frame(&mut self) -> Option<(u32, u32)> {
let frame_offset = self.frame_offsets.pop()? as usize;
self.value_stack_bases.pop()?;
self.locals.truncate(frame_offset);
self.return_addrs_and_block_depths.pop()
}
pub fn get_local(&self, local_index: u32) -> Value {
self.get_local_help(self.frame_offsets.len() - 1, local_index)
}
fn get_local_help(&self, frame_index: usize, local_index: u32) -> Value {
let frame_offset = self.frame_offsets[frame_index];
let index = (frame_offset + local_index) as usize;
self.locals[index]
}
pub(crate) fn set_local(&mut self, local_index: u32, value: Value) -> Result<(), Error> {
let expected_type = self.set_local_help(local_index, value);
let actual_type = ValueType::from(value);
if actual_type == expected_type {
Ok(())
} else {
Err(Error::ValueStackType(expected_type, actual_type))
}
}
fn set_local_help(&mut self, local_index: u32, value: Value) -> ValueType {
let frame_offset = *self.frame_offsets.last().unwrap();
let index = (frame_offset + local_index) as usize;
let old_value = self.locals[index];
self.locals[index] = value;
ValueType::from(old_value)
}
pub fn value_stack_base(&self) -> u32 {
*self.value_stack_bases.last().unwrap_or(&0)
}
pub fn is_empty(&self) -> bool {
self.frame_offsets.is_empty()
}
/// Dump a stack trace of the WebAssembly program
///
/// --------------
/// function 123
/// address 0x12345
/// args 0: I64(234), 1: F64(7.15)
/// locals 2: I32(412), 3: F64(3.14)
/// stack [I64(111), F64(3.14)]
/// --------------
pub fn dump_trace(
&self,
module: &WasmModule<'a>,
value_stack: &ValueStack<'a>,
pc: usize,
buffer: &mut String,
) -> fmt::Result {
let divider = "-------------------";
writeln!(buffer, "{}", divider)?;
let mut value_stack_iter = value_stack.iter();
for frame in 0..self.frame_offsets.len() {
let next_frame = frame + 1;
let op_offset = if next_frame < self.frame_offsets.len() {
// return address of next frame = next op in this frame
let next_op = self.return_addrs_and_block_depths[next_frame].0 as usize;
// Call address is more intuitive than the return address when debugging. Search backward for it.
// Skip last byte of function index to avoid a false match with CALL/CALLINDIRECT.
// The more significant bytes won't match because of LEB-128 encoding.
let mut call_op = next_op - 2;
loop {
let byte = module.code.bytes[call_op];
if byte == OpCode::CALL as u8 || byte == OpCode::CALLINDIRECT as u8 {
break;
} else {
call_op -= 1;
}
}
call_op
} else {
pc
};
let fn_index = pc_to_fn_index(op_offset, module);
let address = op_offset + module.code.section_offset as usize;
writeln!(buffer, "function {}", fn_index)?;
writeln!(buffer, " address {:06x}", address)?; // format matches wasm-objdump, for easy search
write!(buffer, " args ")?;
let arg_count = {
let n_import_fns = module.import.imports.len();
let signature_index = if fn_index < n_import_fns {
match module.import.imports[fn_index].description {
ImportDesc::Func { signature_index } => signature_index,
_ => unreachable!(),
}
} else {
module.function.signatures[fn_index - n_import_fns]
};
module.types.look_up_arg_type_bytes(signature_index).len()
};
let args_and_locals_count = {
let frame_offset = self.frame_offsets[frame] as usize;
let next_frame_offset = if frame == self.frame_offsets.len() - 1 {
self.locals.len()
} else {
self.frame_offsets[frame + 1] as usize
};
next_frame_offset - frame_offset
};
for index in 0..args_and_locals_count {
let value = self.get_local_help(frame, index as u32);
if index != 0 {
write!(buffer, ", ")?;
}
if index == arg_count {
write!(buffer, "\n locals ")?;
}
write!(buffer, "{}: {:?}", index, value)?;
}
write!(buffer, "\n stack [")?;
let frame_value_count = {
let value_stack_base = self.value_stack_bases[frame];
let next_value_stack_base = if frame == self.frame_offsets.len() - 1 {
value_stack.depth() as u32
} else {
self.value_stack_bases[frame + 1]
};
next_value_stack_base - value_stack_base
};
for i in 0..frame_value_count {
if i != 0 {
write!(buffer, ", ")?;
}
if let Some(value) = value_stack_iter.next() {
write!(buffer, "{:?}", value)?;
}
}
writeln!(buffer, "]")?;
writeln!(buffer, "{}", divider)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use roc_wasm_module::Serialize;
use super::*;
const RETURN_ADDR: u32 = 0x12345;
fn test_get_set(call_stack: &mut CallStack<'_>, index: u32, value: Value) {
call_stack.set_local(index, value).unwrap();
assert_eq!(call_stack.get_local(index), value);
}
fn setup<'a>(arena: &'a Bump, call_stack: &mut CallStack<'a>) {
let mut buffer = vec![];
let mut cursor = 0;
let mut vs = ValueStack::new(arena);
// Push a other few frames before the test frame, just to make the scenario more typical.
[(1u32, ValueType::I32)].serialize(&mut buffer);
call_stack
.push_frame(0x11111, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
[(2u32, ValueType::I32)].serialize(&mut buffer);
call_stack
.push_frame(0x22222, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
[(3u32, ValueType::I32)].serialize(&mut buffer);
call_stack
.push_frame(0x33333, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
// Create a test call frame with local variables of every type
[
(8u32, ValueType::I32),
(4u32, ValueType::I64),
(2u32, ValueType::F32),
(1u32, ValueType::F64),
]
.serialize(&mut buffer);
call_stack
.push_frame(RETURN_ADDR, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
}
#[test]
fn test_all() {
let arena = Bump::new();
let mut call_stack = CallStack::new(&arena);
setup(&arena, &mut call_stack);
test_get_set(&mut call_stack, 0, Value::I32(123));
test_get_set(&mut call_stack, 8, Value::I64(123456));
test_get_set(&mut call_stack, 12, Value::F32(1.01));
test_get_set(&mut call_stack, 14, Value::F64(-1.1));
test_get_set(&mut call_stack, 0, Value::I32(i32::MIN));
test_get_set(&mut call_stack, 0, Value::I32(i32::MAX));
test_get_set(&mut call_stack, 8, Value::I64(i64::MIN));
test_get_set(&mut call_stack, 8, Value::I64(i64::MAX));
test_get_set(&mut call_stack, 12, Value::F32(f32::MIN));
test_get_set(&mut call_stack, 12, Value::F32(f32::MAX));
test_get_set(&mut call_stack, 14, Value::F64(f64::MIN));
test_get_set(&mut call_stack, 14, Value::F64(f64::MAX));
assert_eq!(call_stack.pop_frame(), Some((RETURN_ADDR, 0)));
}
#[test]
#[should_panic]
fn test_type_error_i32() {
let arena = Bump::new();
let mut call_stack = CallStack::new(&arena);
setup(&arena, &mut call_stack);
test_get_set(&mut call_stack, 0, Value::F32(1.01));
}
#[test]
#[should_panic]
fn test_type_error_i64() {
let arena = Bump::new();
let mut call_stack = CallStack::new(&arena);
setup(&arena, &mut call_stack);
test_get_set(&mut call_stack, 8, Value::F32(1.01));
}
#[test]
#[should_panic]
fn test_type_error_f32() {
let arena = Bump::new();
let mut call_stack = CallStack::new(&arena);
setup(&arena, &mut call_stack);
test_get_set(&mut call_stack, 12, Value::I32(123));
}
#[test]
#[should_panic]
fn test_type_error_f64() {
let arena = Bump::new();
let mut call_stack = CallStack::new(&arena);
setup(&arena, &mut call_stack);
test_get_set(&mut call_stack, 14, Value::I32(123));
}
}

View File

@ -0,0 +1,181 @@
use roc_wasm_module::{parse::Parse, Value, ValueType, WasmModule};
use std::fmt;
use std::iter::repeat;
use crate::value_stack::ValueStack;
#[derive(Debug)]
pub struct Frame {
/// The function this frame belongs to
pub fn_index: usize,
/// Address in the code section where this frame returns to
pub return_addr: usize,
/// Number of block scopes when this frame returns
pub return_block_depth: usize,
/// Offset in the ValueStack where the locals begin
pub locals_start: usize,
/// Number of locals in the frame
pub locals_count: usize,
}
impl Frame {
pub fn new() -> Self {
Frame {
fn_index: 0,
return_addr: 0,
return_block_depth: 0,
locals_start: 0,
locals_count: 0,
}
}
pub fn enter(
fn_index: usize,
return_addr: usize,
return_block_depth: usize,
arg_type_bytes: &[u8],
code_bytes: &[u8],
value_stack: &mut ValueStack<'_>,
pc: &mut usize,
) -> Self {
let n_args = arg_type_bytes.len();
let locals_start = value_stack.depth() - n_args;
// Parse local variable declarations in the function header. They're grouped by type.
let local_group_count = u32::parse((), code_bytes, pc).unwrap();
for _ in 0..local_group_count {
let (group_size, ty) = <(u32, ValueType)>::parse((), code_bytes, pc).unwrap();
let n = group_size as usize;
let zero = match ty {
ValueType::I32 => Value::I32(0),
ValueType::I64 => Value::I64(0),
ValueType::F32 => Value::F32(0.0),
ValueType::F64 => Value::F64(0.0),
};
value_stack.extend(repeat(zero).take(n));
}
let locals_count = value_stack.depth() - locals_start;
Frame {
fn_index,
return_addr,
return_block_depth,
locals_start,
locals_count,
}
}
pub fn get_local(&self, values: &ValueStack<'_>, index: u32) -> Value {
debug_assert!((index as usize) < self.locals_count);
*values.get(self.locals_start + index as usize).unwrap()
}
pub fn set_local(&self, values: &mut ValueStack<'_>, index: u32, value: Value) {
debug_assert!((index as usize) < self.locals_count);
values.set(self.locals_start + index as usize, value)
}
}
pub fn write_stack_trace(
_current_frame: &Frame,
_previous_frames: &[Frame],
_module: &WasmModule<'_>,
value_stack: &ValueStack<'_>,
_pc: usize,
_buffer: &mut String,
) -> fmt::Result {
let _ = value_stack.iter();
// let divider = "-------------------";
// writeln!(buffer, "{}", divider)?;
// let mut value_stack_iter = value_stack.iter();
// for frame in 0..self.frame_offsets.len() {
// let next_frame = frame + 1;
// let op_offset = if next_frame < self.frame_offsets.len() {
// // return address of next frame = next op in this frame
// let next_op = self.return_addrs_and_block_depths[next_frame].0 as usize;
// // Call address is more intuitive than the return address when debugging. Search backward for it.
// // Skip last byte of function index to avoid a false match with CALL/CALLINDIRECT.
// // The more significant bytes won't match because of LEB-128 encoding.
// let mut call_op = next_op - 2;
// loop {
// let byte = module.code.bytes[call_op];
// if byte == OpCode::CALL as u8 || byte == OpCode::CALLINDIRECT as u8 {
// break;
// } else {
// call_op -= 1;
// }
// }
// call_op
// } else {
// pc
// };
// let fn_index = pc_to_fn_index(op_offset, module);
// let address = op_offset + module.code.section_offset as usize;
// writeln!(buffer, "function {}", fn_index)?;
// writeln!(buffer, " address {:06x}", address)?; // format matches wasm-objdump, for easy search
// write!(buffer, " args ")?;
// let arg_count = {
// let n_import_fns = module.import.imports.len();
// let signature_index = if fn_index < n_import_fns {
// match module.import.imports[fn_index].description {
// ImportDesc::Func { signature_index } => signature_index,
// _ => unreachable!(),
// }
// } else {
// module.function.signatures[fn_index - n_import_fns]
// };
// module.types.look_up_arg_type_bytes(signature_index).len()
// };
// let args_and_locals_count = {
// let frame_offset = self.frame_offsets[frame] as usize;
// let next_frame_offset = if frame == self.frame_offsets.len() - 1 {
// self.locals.len()
// } else {
// self.frame_offsets[frame + 1] as usize
// };
// next_frame_offset - frame_offset
// };
// for index in 0..args_and_locals_count {
// let value = self.get_local_help(frame, index as u32);
// if index != 0 {
// write!(buffer, ", ")?;
// }
// if index == arg_count {
// write!(buffer, "\n locals ")?;
// }
// write!(buffer, "{}: {:?}", index, value)?;
// }
// write!(buffer, "\n stack [")?;
// let frame_value_count = {
// let value_stack_base = self.value_stack_bases[frame];
// let next_value_stack_base = if frame == self.frame_offsets.len() - 1 {
// value_stack.depth() as u32
// } else {
// self.value_stack_bases[frame + 1]
// };
// next_value_stack_base - value_stack_base
// };
// for i in 0..frame_value_count {
// if i != 0 {
// write!(buffer, ", ")?;
// }
// if let Some(value) = value_stack_iter.next() {
// write!(buffer, "{:?}", value)?;
// }
// }
// writeln!(buffer, "]")?;
// writeln!(buffer, "{}", divider)?;
// }
// Ok(())
todo!()
}

View File

@ -8,7 +8,7 @@ use roc_wasm_module::sections::{ImportDesc, MemorySection};
use roc_wasm_module::{ExportType, WasmModule};
use roc_wasm_module::{Value, ValueType};
use crate::call_stack::CallStack;
use crate::frame::{write_stack_trace, Frame};
use crate::value_stack::ValueStack;
use crate::{pc_to_fn_index, Error, ImportDispatcher};
@ -33,23 +33,21 @@ struct BranchCacheEntry {
#[derive(Debug)]
pub struct Instance<'a, I: ImportDispatcher> {
module: &'a WasmModule<'a>,
pub(crate) module: &'a WasmModule<'a>,
/// Contents of the WebAssembly instance's memory
pub memory: Vec<'a, u8>,
/// Metadata for every currently-active function call
pub call_stack: CallStack<'a>,
/// The current call frame
pub(crate) current_frame: Frame,
/// Previous call frames
previous_frames: Vec<'a, Frame>,
/// The WebAssembly stack machine's stack of values
pub value_stack: ValueStack<'a>,
pub(crate) value_stack: ValueStack<'a>,
/// Values of any global variables
pub globals: Vec<'a, Value>,
pub(crate) globals: Vec<'a, Value>,
/// Index in the code section of the current instruction
pub program_counter: usize,
pub(crate) program_counter: usize,
/// One entry per nested block. For loops, stores the address of the first instruction.
blocks: Vec<'a, Block>,
/// Outermost block depth for the currently-executing function.
outermost_block: u32,
/// Current function index
current_function: usize,
/// Cache for branching instructions, split into buckets for each function.
branch_cache: Vec<'a, Vec<'a, BranchCacheEntry>>,
/// Number of imports in the module
@ -78,14 +76,13 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
Instance {
module: arena.alloc(WasmModule::new(arena)),
memory: Vec::from_iter_in(iter::repeat(0).take(mem_bytes as usize), arena),
call_stack: CallStack::new(arena),
current_frame: Frame::new(),
previous_frames: Vec::new_in(arena),
value_stack: ValueStack::new(arena),
globals: Vec::from_iter_in(globals, arena),
program_counter,
blocks: Vec::new_in(arena),
outermost_block: 0,
branch_cache: bumpalo::vec![in arena; bumpalo::vec![in arena]],
current_function: 0,
import_count: 0,
import_dispatcher,
import_arguments: Vec::new_in(arena),
@ -130,7 +127,6 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
);
let value_stack = ValueStack::new(arena);
let call_stack = CallStack::new(arena);
let debug_string = if is_debug_mode {
Some(String::new())
@ -148,13 +144,12 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
Ok(Instance {
module,
memory,
call_stack,
current_frame: Frame::new(),
previous_frames: Vec::new_in(arena),
value_stack,
globals,
program_counter: usize::MAX,
blocks: Vec::new_in(arena),
outermost_block: 0,
current_function: usize::MAX,
branch_cache,
import_count,
import_dispatcher,
@ -167,7 +162,8 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
where
A: IntoIterator<Item = Value>,
{
let arg_type_bytes = self.prepare_to_call_export(self.module, fn_name)?;
let (fn_index, arg_type_bytes) =
self.call_export_help_before_arg_load(self.module, fn_name)?;
for (i, (value, type_byte)) in arg_values
.into_iter()
@ -185,7 +181,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value);
}
self.call_export_help(self.module, arg_type_bytes)
self.call_export_help_after_arg_load(self.module, fn_index, arg_type_bytes)
}
pub fn call_export_from_cli(
@ -207,7 +203,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
// Implement the "basic numbers" CLI
// Check if the called Wasm function takes numeric arguments, and if so, try to parse them from the CLI.
let arg_type_bytes = self.prepare_to_call_export(module, fn_name)?;
let (fn_index, arg_type_bytes) = self.call_export_help_before_arg_load(module, fn_name)?;
for (value_bytes, type_byte) in arg_strings
.iter()
.skip(1) // first string is the .wasm filename
@ -224,15 +220,15 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.value_stack.push(value);
}
self.call_export_help(module, arg_type_bytes)
self.call_export_help_after_arg_load(module, fn_index, arg_type_bytes)
}
fn prepare_to_call_export<'m>(
fn call_export_help_before_arg_load<'m>(
&mut self,
module: &'m WasmModule<'a>,
fn_name: &str,
) -> Result<&'m [u8], String> {
self.current_function = {
) -> Result<(usize, &'m [u8]), String> {
let fn_index = {
let mut export_iter = module.export.exports.iter();
export_iter
// First look up the name in exports
@ -266,7 +262,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
})? as usize
};
let internal_fn_index = self.current_function - self.import_count;
let internal_fn_index = fn_index - self.import_count;
self.program_counter = {
let mut cursor = module.code.function_offsets[internal_fn_index] as usize;
@ -282,30 +278,32 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
if self.debug_string.is_some() {
println!(
"Calling export func[{}] '{}' at address {:#x}",
self.current_function,
fn_index,
fn_name,
self.program_counter + module.code.section_offset as usize
);
}
Ok(arg_type_bytes)
Ok((fn_index, arg_type_bytes))
}
fn call_export_help(
fn call_export_help_after_arg_load(
&mut self,
module: &WasmModule<'a>,
fn_index: usize,
arg_type_bytes: &[u8],
) -> Result<Option<Value>, String> {
self.call_stack
.push_frame(
0, // return_addr
0, // return_block_depth
arg_type_bytes,
&mut self.value_stack,
&module.code.bytes,
&mut self.program_counter,
)
.map_err(|e| e.to_string_at(self.program_counter))?;
self.previous_frames.clear();
self.blocks.clear();
self.current_frame = Frame::enter(
fn_index,
0, // return_addr
0, // return_block_depth
arg_type_bytes,
&module.code.bytes,
&mut self.value_stack,
&mut self.program_counter,
);
loop {
match self.execute_next_instruction(module) {
@ -316,14 +314,15 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
Err(e) => {
let file_offset = self.program_counter + module.code.section_offset as usize;
let mut message = e.to_string_at(file_offset);
self.call_stack
.dump_trace(
module,
&self.value_stack,
self.program_counter,
&mut message,
)
.unwrap();
write_stack_trace(
&self.current_frame,
&self.previous_frames,
self.module,
&self.value_stack,
self.program_counter,
&mut message,
)
.unwrap();
return Err(message);
}
};
@ -347,18 +346,39 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
}
fn do_return(&mut self) -> Action {
self.blocks.truncate(self.outermost_block as usize);
if let Some((return_addr, block_depth)) = self.call_stack.pop_frame() {
if self.call_stack.is_empty() {
// We just popped the stack frame for the entry function. Terminate the program.
Action::Break
} else {
self.program_counter = return_addr as usize;
self.outermost_block = block_depth;
Action::Continue
}
let Frame {
return_addr,
return_block_depth,
locals_start,
..
} = self.current_frame;
// Check where in the value stack the current block started
let current_block_base = match self.blocks.last() {
Some(Block::Loop { vstack, .. } | Block::Normal { vstack }) => *vstack,
_ => 0,
};
// If there's a value on the stack in this block, we should return it
let return_value = if self.value_stack.depth() > current_block_base {
Some(self.value_stack.pop())
} else {
// We should never get here with real programs, but maybe in tests. Terminate the program.
None
};
self.value_stack.truncate(locals_start);
if let Some(val) = return_value {
self.value_stack.push(val);
}
self.blocks.truncate(return_block_depth);
self.program_counter = return_addr;
if let Some(caller_frame) = self.previous_frames.pop() {
self.current_frame = caller_frame;
Action::Continue
} else {
// We just popped the stack frame for the entry function. Terminate the program.
Action::Break
}
}
@ -411,7 +431,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
use OpCode::*;
let addr = self.program_counter as u32;
let cache_result = self.branch_cache[self.current_function]
let cache_result = self.branch_cache[self.current_frame.fn_index]
.iter()
.find(|entry| entry.addr == addr && entry.argument == relative_blocks_outward);
@ -437,7 +457,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
_ => {}
}
}
self.branch_cache[self.current_function].push(BranchCacheEntry {
self.branch_cache[self.current_frame.fn_index].push(BranchCacheEntry {
addr,
argument: relative_blocks_outward,
target: self.program_counter as u32,
@ -499,25 +519,28 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
write!(debug_string, " {}.{}", import.module, import.name).unwrap();
}
} else {
let return_addr = self.program_counter as u32;
let return_addr = self.program_counter;
let return_block_depth = self.blocks.len();
// set PC to start of function bytes
let internal_fn_index = fn_index - self.import_count;
self.program_counter = module.code.function_offsets[internal_fn_index] as usize;
// advance PC to the start of the local variable declarations
u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap();
let return_block_depth = self.outermost_block;
self.outermost_block = self.blocks.len() as u32;
let _function_byte_length =
u32::parse((), &module.code.bytes, &mut self.program_counter).unwrap();
self.call_stack.push_frame(
let mut swap_frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&mut self.value_stack,
&module.code.bytes,
&mut self.value_stack,
&mut self.program_counter,
)?;
);
std::mem::swap(&mut swap_frame, &mut self.current_frame);
self.previous_frames.push(swap_frame);
}
self.current_function = fn_index;
Ok(())
}
@ -565,7 +588,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
});
if condition == 0 {
let addr = self.program_counter as u32;
let cache_result = self.branch_cache[self.current_function]
let cache_result = self.branch_cache[self.current_frame.fn_index]
.iter()
.find(|entry| entry.addr == addr);
if let Some(entry) = cache_result {
@ -598,7 +621,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
_ => {}
}
}
self.branch_cache[self.current_function].push(BranchCacheEntry {
self.branch_cache[self.current_frame.fn_index].push(BranchCacheEntry {
addr,
argument: 0,
target: self.program_counter as u32,
@ -613,7 +636,7 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
self.do_break(0, module);
}
END => {
if self.blocks.len() == self.outermost_block as usize {
if self.blocks.len() == self.current_frame.return_block_depth {
// implicit RETURN at end of function
action = self.do_return();
implicit_return = true;
@ -692,18 +715,20 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
}
GETLOCAL => {
let index = self.fetch_immediate_u32(module);
let value = self.call_stack.get_local(index);
let value = self.current_frame.get_local(&self.value_stack, index);
self.value_stack.push(value);
}
SETLOCAL => {
let index = self.fetch_immediate_u32(module);
let value = self.value_stack.pop();
self.call_stack.set_local(index, value)?;
self.current_frame
.set_local(&mut self.value_stack, index, value);
}
TEELOCAL => {
let index = self.fetch_immediate_u32(module);
let value = self.value_stack.peek();
self.call_stack.set_local(index, value)?;
self.current_frame
.set_local(&mut self.value_stack, index, value);
}
GETGLOBAL => {
let index = self.fetch_immediate_u32(module);
@ -1618,12 +1643,15 @@ impl<'a, I: ImportDispatcher> Instance<'a, I> {
}
if let Some(debug_string) = &self.debug_string {
let base = self.call_stack.value_stack_base();
let base = self.current_frame.locals_start + self.current_frame.locals_count;
let slice = self.value_stack.get_slice(base as usize);
eprintln!("{:06x} {:17} {:?}", file_offset, debug_string, slice);
if op_code == RETURN || (op_code == END && implicit_return) {
let fn_index = pc_to_fn_index(self.program_counter, module);
eprintln!("returning to function {}\n", fn_index);
eprintln!(
"returning to function {} at pc={:06x}\n",
fn_index, self.program_counter
);
} else if op_code == CALL || op_code == CALLINDIRECT {
eprintln!();
}

View File

@ -1,4 +1,4 @@
mod call_stack;
mod frame;
mod instance;
mod tests;
mod value_stack;
@ -10,7 +10,6 @@ pub use wasi::{WasiDispatcher, WasiFile};
pub use roc_wasm_module::Value;
use roc_wasm_module::{ValueType, WasmModule};
use value_stack::ValueStack;
pub trait ImportDispatcher {
/// Dispatch a call from WebAssembly to your own code, based on module and function name.

View File

@ -92,7 +92,7 @@ where
}
let mut inst =
Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap();
Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), true).unwrap();
let return_val = inst.call_export("test", []).unwrap().unwrap();

View File

@ -1,7 +1,8 @@
#![cfg(test)]
use super::{const_value, create_exported_function_no_locals, default_state};
use crate::{instance::Action, DefaultImportDispatcher, ImportDispatcher, Instance, ValueStack};
use crate::frame::Frame;
use crate::{instance::Action, DefaultImportDispatcher, ImportDispatcher, Instance};
use bumpalo::{collections::Vec, Bump};
use roc_wasm_module::sections::{Import, ImportDesc};
use roc_wasm_module::{
@ -17,88 +18,95 @@ fn test_loop() {
fn test_loop_help(end: i32, expected: i32) {
let arena = Bump::new();
let mut module = WasmModule::new(&arena);
let buf = &mut module.code.bytes;
{
let buf = &mut module.code.bytes;
// Loop from 0 to end, adding the loop variable to a total
let var_i = 0;
let var_total = 1;
// Loop from 0 to end, adding the loop variable to a total
let var_i = 0;
let var_total = 1;
// (local i32 i32)
buf.push(1); // one group of the given type
buf.push(2); // two locals in the group
buf.push(ValueType::I32 as u8);
let fn_len_index = buf.encode_padded_u32(0);
// loop <void>
buf.push(OpCode::LOOP as u8);
buf.push(ValueType::VOID as u8);
// (local i32 i32)
buf.push(1); // one group of the given type
buf.push(2); // two locals in the group
buf.push(ValueType::I32 as u8);
// local.get $i
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_i);
// loop <void>
buf.push(OpCode::LOOP as u8);
buf.push(ValueType::VOID as u8);
// i32.const 1
buf.push(OpCode::I32CONST as u8);
buf.encode_i32(1);
// local.get $i
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_i);
// i32.add
buf.push(OpCode::I32ADD as u8);
// i32.const 1
buf.push(OpCode::I32CONST as u8);
buf.encode_i32(1);
// local.tee $i
buf.push(OpCode::TEELOCAL as u8);
buf.encode_u32(var_i);
// i32.add
buf.push(OpCode::I32ADD as u8);
// local.get $total
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_total);
// local.tee $i
buf.push(OpCode::TEELOCAL as u8);
buf.encode_u32(var_i);
// i32.add
buf.push(OpCode::I32ADD as u8);
// local.get $total
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_total);
// local.set $total
buf.push(OpCode::SETLOCAL as u8);
buf.encode_u32(var_total);
// i32.add
buf.push(OpCode::I32ADD as u8);
// local.get $i
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_i);
// local.set $total
buf.push(OpCode::SETLOCAL as u8);
buf.encode_u32(var_total);
// i32.const $end
buf.push(OpCode::I32CONST as u8);
buf.encode_i32(end);
// local.get $i
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_i);
// i32.lt_s
buf.push(OpCode::I32LTS as u8);
// i32.const $end
buf.push(OpCode::I32CONST as u8);
buf.encode_i32(end);
// br_if 0
buf.push(OpCode::BRIF as u8);
buf.encode_u32(0);
// i32.lt_s
buf.push(OpCode::I32LTS as u8);
// end
buf.push(OpCode::END as u8);
// br_if 0
buf.push(OpCode::BRIF as u8);
buf.encode_u32(0);
// local.get $total
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_total);
// end
buf.push(OpCode::END as u8);
// end function
buf.push(OpCode::END as u8);
// local.get $total
buf.push(OpCode::GETLOCAL as u8);
buf.encode_u32(var_total);
let mut state = default_state(&arena);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
// end function
buf.push(OpCode::END as u8);
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
buf.overwrite_padded_u32(fn_len_index, (buf.len() - fn_len_index) as u32);
}
module.code.function_offsets.push(0);
module.code.function_count = 1;
assert_eq!(state.value_stack.pop_i32(), Ok(expected));
module.add_function_signature(Signature {
param_types: Vec::new_in(&arena),
ret_type: Some(ValueType::I32),
});
module.export.append(Export {
name: "test",
ty: ExportType::Func,
index: 0,
});
let mut inst =
Instance::for_module(&arena, &module, DefaultImportDispatcher::default(), false).unwrap();
let return_val = inst.call_export("test", []).unwrap().unwrap();
assert_eq!(return_val, Value::I32(expected));
}
#[test]
@ -155,17 +163,20 @@ fn test_if_else_help(condition: i32, expected: i32) {
buf.push(OpCode::END as u8);
let mut state = default_state(&arena);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
let frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buf,
&mut state.value_stack,
&mut state.program_counter,
);
state.current_frame = frame;
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
@ -247,17 +258,20 @@ fn test_br() {
buf.push(OpCode::END as u8);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
let frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buf,
&mut state.value_stack,
&mut state.program_counter,
);
state.current_frame = frame;
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
@ -348,17 +362,20 @@ fn test_br_if_help(condition: i32, expected: i32) {
buf.push(OpCode::END as u8);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
let frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buf,
&mut state.value_stack,
&mut state.program_counter,
);
state.current_frame = frame;
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
@ -455,17 +472,20 @@ fn test_br_table_help(condition: i32, expected: i32) {
println!("{:02x?}", buf);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
let frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buf,
&mut state.value_stack,
&mut state.program_counter,
);
state.current_frame = frame;
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
@ -489,7 +509,6 @@ impl ImportDispatcher for TestDispatcher {
assert_eq!(arguments.len(), 1);
let val = arguments[0].expect_i32().unwrap();
self.internal_state += val;
dbg!(val, self.internal_state);
Some(Value::I32(self.internal_state))
}
}
@ -789,30 +808,33 @@ fn test_select_help(first: Value, second: Value, condition: i32, expected: Value
buf.push(OpCode::SELECT as u8);
buf.push(OpCode::END as u8);
let mut state = default_state(&arena);
state
.call_stack
.push_frame(
0,
0,
&[],
&mut state.value_stack,
&module.code.bytes,
&mut state.program_counter,
)
.unwrap();
let mut inst = default_state(&arena);
while let Ok(Action::Continue) = state.execute_next_instruction(&module) {}
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
let frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buf,
&mut inst.value_stack,
&mut inst.program_counter,
);
inst.current_frame = frame;
assert_eq!(state.value_stack.pop(), expected);
while let Ok(Action::Continue) = inst.execute_next_instruction(&module) {}
assert_eq!(inst.value_stack.pop(), expected);
}
#[test]
fn test_set_get_local() {
let arena = Bump::new();
let mut state = default_state(&arena);
let mut inst = default_state(&arena);
let mut module = WasmModule::new(&arena);
let mut vs = ValueStack::new(&arena);
let mut buffer = vec![];
let mut cursor = 0;
@ -823,10 +845,20 @@ fn test_set_get_local() {
(1u32, ValueType::I64),
]
.serialize(&mut buffer);
state
.call_stack
.push_frame(0x1234, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
inst.current_frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buffer,
&mut inst.value_stack,
&mut cursor,
);
module.code.bytes.push(OpCode::I32CONST as u8);
module.code.bytes.encode_i32(12345);
@ -836,19 +868,18 @@ fn test_set_get_local() {
module.code.bytes.push(OpCode::GETLOCAL as u8);
module.code.bytes.encode_u32(2);
state.execute_next_instruction(&module).unwrap();
state.execute_next_instruction(&module).unwrap();
state.execute_next_instruction(&module).unwrap();
assert_eq!(state.value_stack.depth(), 1);
assert_eq!(state.value_stack.pop(), Value::I32(12345));
inst.execute_next_instruction(&module).unwrap();
inst.execute_next_instruction(&module).unwrap();
inst.execute_next_instruction(&module).unwrap();
assert_eq!(inst.value_stack.depth(), 5);
assert_eq!(inst.value_stack.pop(), Value::I32(12345));
}
#[test]
fn test_tee_get_local() {
let arena = Bump::new();
let mut state = default_state(&arena);
let mut inst = default_state(&arena);
let mut module = WasmModule::new(&arena);
let mut vs = ValueStack::new(&arena);
let mut buffer = vec![];
let mut cursor = 0;
@ -859,10 +890,20 @@ fn test_tee_get_local() {
(1u32, ValueType::I64),
]
.serialize(&mut buffer);
state
.call_stack
.push_frame(0x1234, 0, &[], &mut vs, &buffer, &mut cursor)
.unwrap();
let fn_index = 0;
let return_addr = 0x1234;
let return_block_depth = 0;
let arg_type_bytes = &[];
inst.current_frame = Frame::enter(
fn_index,
return_addr,
return_block_depth,
arg_type_bytes,
&buffer,
&mut inst.value_stack,
&mut cursor,
);
module.code.bytes.push(OpCode::I32CONST as u8);
module.code.bytes.encode_i32(12345);
@ -872,12 +913,12 @@ fn test_tee_get_local() {
module.code.bytes.push(OpCode::GETLOCAL as u8);
module.code.bytes.encode_u32(2);
state.execute_next_instruction(&module).unwrap();
state.execute_next_instruction(&module).unwrap();
state.execute_next_instruction(&module).unwrap();
assert_eq!(state.value_stack.depth(), 2);
assert_eq!(state.value_stack.pop(), Value::I32(12345));
assert_eq!(state.value_stack.pop(), Value::I32(12345));
inst.execute_next_instruction(&module).unwrap();
inst.execute_next_instruction(&module).unwrap();
inst.execute_next_instruction(&module).unwrap();
assert_eq!(inst.value_stack.depth(), 6);
assert_eq!(inst.value_stack.pop(), Value::I32(12345));
assert_eq!(inst.value_stack.pop(), Value::I32(12345));
}
#[test]

View File

@ -38,6 +38,18 @@ impl<'a> ValueStack<'a> {
*self.values.last().unwrap()
}
pub(crate) fn get(&self, index: usize) -> Option<&Value> {
self.values.get(index)
}
pub(crate) fn set(&mut self, index: usize, value: Value) {
self.values[index] = value;
}
pub(crate) fn extend<I: Iterator<Item = Value>>(&mut self, values: I) {
self.values.extend(values)
}
/// Memory addresses etc
pub(crate) fn pop_u32(&mut self) -> Result<u32, Error> {
match self.values.pop() {