Merge pull request #2834 from rtfeldman/wasm-list-map-n

Wasm List.mapN
This commit is contained in:
Brian Carroll 2022-04-20 14:47:04 +01:00 committed by GitHub
commit f8156ffd53
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 181 additions and 84 deletions

View File

@ -7,7 +7,7 @@ use roc_collections::all::MutMap;
use roc_module::ident::Ident;
use roc_module::low_level::{LowLevel, LowLevelWrapperType};
use roc_module::symbol::{Interns, Symbol};
use roc_mono::code_gen_help::{CodeGenHelp, REFCOUNT_MAX};
use roc_mono::code_gen_help::{CodeGenHelp, HelperOp, REFCOUNT_MAX};
use roc_mono::ir::{
BranchInfo, CallType, Expr, JoinPointId, ListLiteralElement, Literal, ModifyRc, Param, Proc,
ProcLayout, Stmt,
@ -270,6 +270,13 @@ impl<'a> WasmBackend<'a> {
self.storage.stack_frame_size,
self.storage.stack_frame_pointer,
);
if DEBUG_LOG_SETTINGS.storage_map {
println!("\nStorage:");
for (sym, storage) in self.storage.symbol_storage_map.iter() {
println!("{:?} => {:?}", sym, storage);
}
}
}
fn append_proc_debug_name(&mut self, sym: Symbol) {
@ -1609,8 +1616,9 @@ impl<'a> WasmBackend<'a> {
);
}
/// Generate a refcount increment procedure and return its Wasm function index
pub fn gen_refcount_inc_for_zig(&mut self, layout: Layout<'a>) -> u32 {
/// Generate a refcount helper procedure and return a pointer (table index) to it
/// This allows it to be indirectly called from Zig code
pub fn get_refcount_fn_ptr(&mut self, layout: Layout<'a>, op: HelperOp) -> i32 {
let ident_ids = self
.interns
.all_ident_ids
@ -1619,7 +1627,7 @@ impl<'a> WasmBackend<'a> {
let (proc_symbol, new_specializations) = self
.helper_proc_gen
.gen_refcount_inc_proc(ident_ids, layout);
.gen_refcount_proc(ident_ids, layout, op);
// If any new specializations were created, register their symbol data
for (spec_sym, spec_layout) in new_specializations.into_iter() {
@ -1632,6 +1640,7 @@ impl<'a> WasmBackend<'a> {
.position(|lookup| lookup.name == proc_symbol && lookup.layout.arguments[0] == layout)
.unwrap();
self.fn_index_offset + proc_index as u32
let wasm_fn_index = self.fn_index_offset + proc_index as u32;
self.get_fn_table_index(wasm_fn_index)
}
}

View File

@ -253,6 +253,7 @@ pub struct WasmDebugLogSettings {
helper_procs_ir: bool,
let_stmt_ir: bool,
instructions: bool,
storage_map: bool,
pub keep_test_binary: bool,
}
@ -262,5 +263,6 @@ pub const DEBUG_LOG_SETTINGS: WasmDebugLogSettings = WasmDebugLogSettings {
helper_procs_ir: false && cfg!(debug_assertions),
let_stmt_ir: false && cfg!(debug_assertions),
instructions: false && cfg!(debug_assertions),
storage_map: false && cfg!(debug_assertions),
keep_test_binary: false && cfg!(debug_assertions),
};

View File

@ -3,6 +3,7 @@ use roc_builtins::bitcode::{self, FloatWidth, IntWidth};
use roc_error_macros::internal_error;
use roc_module::low_level::LowLevel;
use roc_module::symbol::Symbol;
use roc_mono::code_gen_help::HelperOp;
use roc_mono::ir::{HigherOrderLowLevel, PassedFunction, ProcLayout};
use roc_mono::layout::{Builtin, Layout, UnionLayout};
use roc_mono::low_level::HigherOrder;
@ -1014,58 +1015,73 @@ pub fn call_higher_order_lowlevel<'a>(
};
let wrapper_fn_idx = backend.register_helper_proc(wrapper_sym, wrapper_layout, source);
let inc_fn_idx = backend.gen_refcount_inc_for_zig(closure_data_layout);
let wrapper_fn_ptr = backend.get_fn_table_index(wrapper_fn_idx);
let inc_fn_ptr = backend.get_fn_table_index(inc_fn_idx);
let inc_fn_ptr = match closure_data_layout {
Layout::Struct {
field_layouts: &[], ..
} => {
// Our code gen would ignore the Unit arg, but the Zig builtin passes a pointer for it!
// That results in an exception (type signature mismatch in indirect call).
// The workaround is to use I32 layout, treating the (ignored) pointer as an integer.
backend.get_refcount_fn_ptr(Layout::Builtin(Builtin::Int(IntWidth::I32)), HelperOp::Inc)
}
_ => backend.get_refcount_fn_ptr(closure_data_layout, HelperOp::Inc),
};
match op {
// List.map : List elem_x, (elem_x -> elem_ret) -> List elem_ret
ListMap { xs } => {
let list_layout_in = backend.storage.symbol_layouts[xs];
ListMap { xs } => list_map_n(
bitcode::LIST_MAP,
backend,
&[*xs],
return_sym,
*return_layout,
wrapper_fn_ptr,
inc_fn_ptr,
closure_data_exists,
*captured_environment,
*owns_captured_environment,
),
let (elem_x, elem_ret) = match (list_layout_in, return_layout) {
(
Layout::Builtin(Builtin::List(elem_x)),
Layout::Builtin(Builtin::List(elem_ret)),
) => (elem_x, elem_ret),
_ => unreachable!("invalid layout for List.map arguments"),
};
let elem_x_size = elem_x.stack_size(TARGET_INFO);
let (elem_ret_size, elem_ret_align) = elem_ret.stack_size_and_alignment(TARGET_INFO);
ListMap2 { xs, ys } => list_map_n(
bitcode::LIST_MAP2,
backend,
&[*xs, *ys],
return_sym,
*return_layout,
wrapper_fn_ptr,
inc_fn_ptr,
closure_data_exists,
*captured_environment,
*owns_captured_environment,
),
let cb = &mut backend.code_builder;
ListMap3 { xs, ys, zs } => list_map_n(
bitcode::LIST_MAP3,
backend,
&[*xs, *ys, *zs],
return_sym,
*return_layout,
wrapper_fn_ptr,
inc_fn_ptr,
closure_data_exists,
*captured_environment,
*owns_captured_environment,
),
// Load return pointer & argument values
// Wasm signature: (i32, i64, i64, i32, i32, i32, i32, i32, i32, i32) -> nil
backend.storage.load_symbols(cb, &[return_sym]);
backend.storage.load_symbol_zig(cb, *xs); // list with capacity = 2 x i64 args
cb.i32_const(wrapper_fn_ptr);
if closure_data_exists {
backend.storage.load_symbols(cb, &[*captured_environment]);
} else {
// Normally, a zero-size arg would be eliminated in code gen, but Zig expects one!
cb.i32_const(0); // null pointer
}
cb.i32_const(inc_fn_ptr);
cb.i32_const(*owns_captured_environment as i32);
cb.i32_const(elem_ret_align as i32); // used for allocating the new list
cb.i32_const(elem_x_size as i32);
cb.i32_const(elem_ret_size as i32);
ListMap4 { xs, ys, zs, ws } => list_map_n(
bitcode::LIST_MAP4,
backend,
&[*xs, *ys, *zs, *ws],
return_sym,
*return_layout,
wrapper_fn_ptr,
inc_fn_ptr,
closure_data_exists,
*captured_environment,
*owns_captured_environment,
),
let num_wasm_args = 10; // 1 return pointer + 8 Zig args + list 2nd i64
let has_return_val = false;
backend.call_zig_builtin_after_loading_args(
bitcode::LIST_MAP,
num_wasm_args,
has_return_val,
);
}
ListMap2 { .. }
| ListMap3 { .. }
| ListMap4 { .. }
| ListMapWithIndex { .. }
ListMapWithIndex { .. }
| ListKeepIf { .. }
| ListWalk { .. }
| ListWalkUntil { .. }
@ -1079,3 +1095,71 @@ pub fn call_higher_order_lowlevel<'a>(
| DictWalk { .. } => todo!("{:?}", op),
}
}
fn unwrap_list_elem_layout(list_layout: Layout<'_>) -> &Layout<'_> {
match list_layout {
Layout::Builtin(Builtin::List(x)) => x,
e => internal_error!("expected List layout, got {:?}", e),
}
}
#[allow(clippy::too_many_arguments)]
fn list_map_n<'a>(
zig_fn_name: &'static str,
backend: &mut WasmBackend<'a>,
arg_symbols: &[Symbol],
return_sym: Symbol,
return_layout: Layout<'a>,
wrapper_fn_ptr: i32,
inc_fn_ptr: i32,
closure_data_exists: bool,
captured_environment: Symbol,
owns_captured_environment: bool,
) {
let arg_elem_layouts = Vec::from_iter_in(
arg_symbols
.iter()
.map(|sym| *unwrap_list_elem_layout(backend.storage.symbol_layouts[sym])),
backend.env.arena,
);
let elem_ret = unwrap_list_elem_layout(return_layout);
let (elem_ret_size, elem_ret_align) = elem_ret.stack_size_and_alignment(TARGET_INFO);
let cb = &mut backend.code_builder;
backend.storage.load_symbols(cb, &[return_sym]);
for s in arg_symbols {
backend.storage.load_symbol_zig(cb, *s);
}
cb.i32_const(wrapper_fn_ptr);
if closure_data_exists {
backend.storage.load_symbols(cb, &[captured_environment]);
} else {
// load_symbols assumes that a zero-size arg should be eliminated in code gen,
// but that's a specialization that our Zig code doesn't have! Pass a null pointer.
cb.i32_const(0);
}
cb.i32_const(inc_fn_ptr);
cb.i32_const(owns_captured_environment as i32);
cb.i32_const(elem_ret_align as i32);
for el in arg_elem_layouts.iter() {
cb.i32_const(el.stack_size(TARGET_INFO) as i32);
}
cb.i32_const(elem_ret_size as i32);
// If we have lists of different lengths, we may need to decrement
let num_wasm_args = if arg_elem_layouts.len() > 1 {
for el in arg_elem_layouts.iter() {
let ptr = backend.get_refcount_fn_ptr(*el, HelperOp::Dec);
backend.code_builder.i32_const(ptr);
}
7 + arg_elem_layouts.len() * 4
} else {
7 + arg_elem_layouts.len() * 3
};
let has_return_val = false;
backend.call_zig_builtin_after_loading_args(zig_fn_name, num_wasm_args, has_return_val);
}

View File

@ -62,7 +62,7 @@ struct VmBlock<'a> {
impl std::fmt::Debug for VmBlock<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("{:?}", self.opcode))
f.write_fmt(format_args!("{:?} {:?}", self.opcode, self.value_stack))
}
}
@ -608,7 +608,7 @@ impl<'a> CodeBuilder<'a> {
log_instruction!(
"{:10}\t\t{:?}",
format!("{:?}", opcode),
self.current_stack()
self.vm_block_stack
);
}
@ -635,7 +635,7 @@ impl<'a> CodeBuilder<'a> {
"{:10}\t{}\t{:?}",
format!("{:?}", opcode),
immediate,
self.current_stack()
self.vm_block_stack
);
}
@ -648,7 +648,7 @@ impl<'a> CodeBuilder<'a> {
format!("{:?}", opcode),
align,
offset,
self.current_stack()
self.vm_block_stack
);
}
@ -752,7 +752,7 @@ impl<'a> CodeBuilder<'a> {
"{:10}\t{}\t{:?}",
format!("{:?}", CALL),
function_index,
self.current_stack()
self.vm_block_stack
);
}
@ -823,7 +823,7 @@ impl<'a> CodeBuilder<'a> {
"{:10}\t{}\t{:?}",
format!("{:?}", opcode),
x,
self.current_stack()
self.vm_block_stack
);
}
pub fn i32_const(&mut self, x: i32) {

View File

@ -244,22 +244,22 @@ impl LowLevelWrapperType {
Symbol::LIST_JOIN => CanBeReplacedBy(ListJoin),
Symbol::LIST_RANGE => CanBeReplacedBy(ListRange),
Symbol::LIST_MAP => WrapperIsRequired,
Symbol::LIST_MAP2 => CanBeReplacedBy(ListMap2),
Symbol::LIST_MAP3 => CanBeReplacedBy(ListMap3),
Symbol::LIST_MAP4 => CanBeReplacedBy(ListMap4),
Symbol::LIST_MAP_WITH_INDEX => CanBeReplacedBy(ListMapWithIndex),
Symbol::LIST_KEEP_IF => CanBeReplacedBy(ListKeepIf),
Symbol::LIST_WALK => CanBeReplacedBy(ListWalk),
Symbol::LIST_WALK_UNTIL => CanBeReplacedBy(ListWalkUntil),
Symbol::LIST_WALK_BACKWARDS => CanBeReplacedBy(ListWalkBackwards),
Symbol::LIST_KEEP_OKS => CanBeReplacedBy(ListKeepOks),
Symbol::LIST_KEEP_ERRS => CanBeReplacedBy(ListKeepErrs),
Symbol::LIST_SORT_WITH => CanBeReplacedBy(ListSortWith),
Symbol::LIST_MAP2 => WrapperIsRequired,
Symbol::LIST_MAP3 => WrapperIsRequired,
Symbol::LIST_MAP4 => WrapperIsRequired,
Symbol::LIST_MAP_WITH_INDEX => WrapperIsRequired,
Symbol::LIST_KEEP_IF => WrapperIsRequired,
Symbol::LIST_WALK => WrapperIsRequired,
Symbol::LIST_WALK_UNTIL => WrapperIsRequired,
Symbol::LIST_WALK_BACKWARDS => WrapperIsRequired,
Symbol::LIST_KEEP_OKS => WrapperIsRequired,
Symbol::LIST_KEEP_ERRS => WrapperIsRequired,
Symbol::LIST_SORT_WITH => WrapperIsRequired,
Symbol::LIST_SUBLIST => CanBeReplacedBy(ListSublist),
Symbol::LIST_DROP_AT => CanBeReplacedBy(ListDropAt),
Symbol::LIST_SWAP => CanBeReplacedBy(ListSwap),
Symbol::LIST_ANY => CanBeReplacedBy(ListAny),
Symbol::LIST_ALL => CanBeReplacedBy(ListAll),
Symbol::LIST_ANY => WrapperIsRequired,
Symbol::LIST_ALL => WrapperIsRequired,
Symbol::LIST_FIND => WrapperIsRequired,
Symbol::DICT_LEN => CanBeReplacedBy(DictSize),
Symbol::DICT_EMPTY => CanBeReplacedBy(DictEmpty),
@ -272,7 +272,7 @@ impl LowLevelWrapperType {
Symbol::DICT_UNION => CanBeReplacedBy(DictUnion),
Symbol::DICT_INTERSECTION => CanBeReplacedBy(DictIntersection),
Symbol::DICT_DIFFERENCE => CanBeReplacedBy(DictDifference),
Symbol::DICT_WALK => CanBeReplacedBy(DictWalk),
Symbol::DICT_WALK => WrapperIsRequired,
Symbol::SET_FROM_LIST => CanBeReplacedBy(SetFromList),
Symbol::NUM_ADD => CanBeReplacedBy(NumAdd),
Symbol::NUM_ADD_WRAP => CanBeReplacedBy(NumAddWrap),

View File

@ -25,7 +25,7 @@ const ARG_2: Symbol = Symbol::ARG_2;
pub const REFCOUNT_MAX: usize = 0;
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum HelperOp {
pub enum HelperOp {
Inc,
Dec,
DecRef(JoinPointId),
@ -185,16 +185,16 @@ impl<'a> CodeGenHelp<'a> {
/// Generate a refcount increment procedure, *without* a Call expression.
/// *This method should be rarely used* - only when the proc is to be called from Zig.
/// Otherwise you want to generate the Proc and the Call together, using another method.
/// This only supports the 'inc' operation, as it's the only real use case we have.
pub fn gen_refcount_inc_proc(
pub fn gen_refcount_proc(
&mut self,
ident_ids: &mut IdentIds,
layout: Layout<'a>,
op: HelperOp,
) -> (Symbol, Vec<'a, (Symbol, ProcLayout<'a>)>) {
let mut ctx = Context {
new_linker_data: Vec::new_in(self.arena),
recursive_union: None,
op: HelperOp::Inc,
op,
};
let proc_name = self.find_or_create_proc(ident_ids, &mut ctx, layout);

View File

@ -107,7 +107,9 @@ pub fn refcount_generic<'a>(
match layout {
Layout::Builtin(Builtin::Int(_) | Builtin::Float(_) | Builtin::Bool | Builtin::Decimal) => {
unreachable!("Not refcounted: {:?}", layout)
// Generate a dummy function that immediately returns Unit
// Some higher-order Zig builtins *always* call an RC function on List elements.
rc_return_stmt(root, ident_ids, ctx)
}
Layout::Builtin(Builtin::Str) => refcount_str(root, ident_ids, ctx),
Layout::Builtin(Builtin::List(elem_layout)) => {

View File

@ -1104,7 +1104,7 @@ fn list_map_closure() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map4_group() {
assert_evals_to!(
indoc!(
@ -1112,13 +1112,13 @@ fn list_map4_group() {
List.map4 [1,2,3] [3,2,1] [2,1,3] [3,1,2] (\a, b, c, d -> Group a b c d)
"#
),
RocList::from_slice(&[(1, 3, 2, 3), (2, 2, 1, 1), (3, 1, 3, 2)]),
RocList<(i64, i64, i64, i64)>
RocList::from_slice(&[[1, 3, 2, 3], [2, 2, 1, 1], [3, 1, 3, 2]]),
RocList<[i64; 4]>
);
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map4_different_length() {
assert_evals_to!(
indoc!(
@ -1137,7 +1137,7 @@ fn list_map4_different_length() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map3_group() {
assert_evals_to!(
indoc!(
@ -1151,7 +1151,7 @@ fn list_map3_group() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map3_different_length() {
assert_evals_to!(
indoc!(
@ -1169,7 +1169,7 @@ fn list_map3_different_length() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map2_pair() {
assert_evals_to!(
indoc!(
@ -1184,13 +1184,13 @@ fn list_map2_pair() {
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn list_map2_different_lengths() {
assert_evals_to!(
indoc!(
r#"
List.map2
["a", "b", "lllllllllllllongnggg" ]
["a", "b", "lllllllllllllooooooooongnggg" ]
["b"]
(\a, b -> Str.concat a b)
"#