diff --git a/crates/compiler/alias_analysis/src/lib.rs b/crates/compiler/alias_analysis/src/lib.rs index 723c78333b..27a3be9947 100644 --- a/crates/compiler/alias_analysis/src/lib.rs +++ b/crates/compiler/alias_analysis/src/lib.rs @@ -1422,6 +1422,38 @@ fn expr_spec<'a>( builder.add_get_tuple_field(block, variant_id, index) } }, + UnionFieldPtrAtIndex { + index, + tag_id, + structure, + union_layout, + } => { + let index = (*index) as u32; + let tag_value_id = env.symbols[structure]; + + let type_name_bytes = recursive_tag_union_name_bytes(union_layout).as_bytes(); + let type_name = TypeName(&type_name_bytes); + + // unwrap the named wrapper + let union_id = builder.add_unwrap_named(block, MOD_APP, type_name, tag_value_id)?; + + // now we have a tuple (cell, union { ... }); decompose + let heap_cell = builder.add_get_tuple_field(block, union_id, TAG_CELL_INDEX)?; + let union_data = builder.add_get_tuple_field(block, union_id, TAG_DATA_INDEX)?; + + // we're reading from this value, so touch the heap cell + builder.add_touch(block, heap_cell)?; + + // next, unwrap the union at the tag id that we've got + let variant_id = builder.add_unwrap_union(block, union_data, *tag_id as u32)?; + + let value = builder.add_get_tuple_field(block, variant_id, index)?; + + // construct the box. Here the heap_cell of the tag is re-used, I'm hoping that that + // conveys to morphic that we're borrowing into the existing tag?! + builder.add_make_tuple(block, &[heap_cell, value]) + } + StructAtIndex { index, structure, .. } => { diff --git a/crates/compiler/can/src/builtins.rs b/crates/compiler/can/src/builtins.rs index 9f05e61ddf..d59e23d523 100644 --- a/crates/compiler/can/src/builtins.rs +++ b/crates/compiler/can/src/builtins.rs @@ -85,7 +85,9 @@ macro_rules! 
map_symbol_to_lowlevel_and_arity { // these are used internally and not tied to a symbol LowLevel::Hash => unimplemented!(), LowLevel::PtrCast => unimplemented!(), - LowLevel::PtrWrite => unimplemented!(), + LowLevel::PtrStore => unimplemented!(), + LowLevel::PtrLoad => unimplemented!(), + LowLevel::PtrToZeroed => unimplemented!(), LowLevel::RefCountIncRcPtr => unimplemented!(), LowLevel::RefCountDecRcPtr=> unimplemented!(), LowLevel::RefCountIncDataPtr => unimplemented!(), diff --git a/crates/compiler/debug_flags/src/lib.rs b/crates/compiler/debug_flags/src/lib.rs index 3852100c71..ede77bd317 100644 --- a/crates/compiler/debug_flags/src/lib.rs +++ b/crates/compiler/debug_flags/src/lib.rs @@ -135,6 +135,10 @@ flags! { /// instructions. ROC_PRINT_IR_AFTER_REFCOUNT + /// Writes a pretty-printed mono IR to stderr after the tail recursion (modulo cons) + /// has been applied. + ROC_PRINT_IR_AFTER_TRMC + /// Writes a pretty-printed mono IR to stderr after performing dropspecialization. /// Which inlines drop functions to remove pairs of alloc/dealloc instructions of its children. ROC_PRINT_IR_AFTER_DROP_SPECIALIZATION diff --git a/crates/compiler/gen_dev/src/lib.rs b/crates/compiler/gen_dev/src/lib.rs index 648ab8247c..fd63a1d95f 100644 --- a/crates/compiler/gen_dev/src/lib.rs +++ b/crates/compiler/gen_dev/src/lib.rs @@ -163,6 +163,9 @@ impl<'a> LastSeenMap<'a> { Expr::UnionAtIndex { structure, .. } => { self.set_last_seen(*structure, stmt); } + Expr::UnionFieldPtrAtIndex { structure, .. } => { + self.set_last_seen(*structure, stmt); + } Expr::Array { elems, .. 
} => { for elem in *elems { if let ListLiteralElement::Symbol(sym) = elem { @@ -794,6 +797,14 @@ trait Backend<'a> { } => { self.load_union_at_index(sym, structure, *tag_id, *index, union_layout); } + Expr::UnionFieldPtrAtIndex { + structure, + tag_id, + union_layout, + index, + } => { + todo!(); + } Expr::GetTagId { structure, union_layout, @@ -1581,7 +1592,7 @@ trait Backend<'a> { self.build_ptr_cast(sym, &args[0]) } - LowLevel::PtrWrite => { + LowLevel::PtrStore => { let element_layout = match self.interner().get_repr(*ret_layout) { LayoutRepr::Boxed(boxed) => boxed, _ => unreachable!("cannot write to {:?}", self.interner().dbg(*ret_layout)), @@ -1589,6 +1600,10 @@ trait Backend<'a> { self.build_ptr_write(*sym, args[0], args[1], element_layout); } + LowLevel::PtrLoad => { + // + todo!() + } LowLevel::RefCountDecRcPtr => self.build_fn_call( sym, bitcode::UTILS_DECREF_RC_PTR.to_string(), diff --git a/crates/compiler/gen_llvm/src/llvm/build.rs b/crates/compiler/gen_llvm/src/llvm/build.rs index c400bcce26..b7f91ddff0 100644 --- a/crates/compiler/gen_llvm/src/llvm/build.rs +++ b/crates/compiler/gen_llvm/src/llvm/build.rs @@ -23,7 +23,7 @@ use inkwell::passes::{PassManager, PassManagerBuilder}; use inkwell::types::{ AnyType, BasicMetadataTypeEnum, BasicType, BasicTypeEnum, FunctionType, IntType, StructType, }; -use inkwell::values::BasicValueEnum::{self}; +use inkwell::values::BasicValueEnum; use inkwell::values::{ BasicMetadataValueEnum, CallSiteValue, FunctionValue, InstructionValue, IntValue, PointerValue, StructValue, @@ -1379,12 +1379,13 @@ pub(crate) fn build_exp_expr<'a, 'ctx>( layout_interner.get_repr(layout), ); - lookup_at_index_ptr2( + lookup_at_index_ptr( env, layout_interner, field_layouts, *index as usize, ptr, + None, target_loaded_type, ) } @@ -1404,7 +1405,7 @@ pub(crate) fn build_exp_expr<'a, 'ctx>( field_layouts, *index as usize, argument.into_pointer_value(), - struct_type.into_struct_type(), + Some(struct_type.into_struct_type()), 
target_loaded_type, ) } @@ -1430,12 +1431,13 @@ pub(crate) fn build_exp_expr<'a, 'ctx>( layout_interner.get_repr(layout), ); - lookup_at_index_ptr2( + lookup_at_index_ptr( env, layout_interner, field_layouts, *index as usize, ptr, + None, target_loaded_type, ) } @@ -1463,13 +1465,117 @@ pub(crate) fn build_exp_expr<'a, 'ctx>( // the tag id is not stored *index as usize, argument.into_pointer_value(), - struct_type.into_struct_type(), + Some(struct_type.into_struct_type()), target_loaded_type, ) } } } + UnionFieldPtrAtIndex { + tag_id, + structure, + index, + union_layout, + } => { + // cast the argument bytes into the desired shape for this tag + let (argument, structure_layout) = scope.load_symbol_and_layout(structure); + let ret_repr = layout_interner.get_repr(layout); + + let pointer_value = match union_layout { + UnionLayout::NonRecursive(_) => unreachable!(), + UnionLayout::Recursive(tag_layouts) => { + debug_assert!(argument.is_pointer_value()); + + let field_layouts = tag_layouts[*tag_id as usize]; + + let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value()); + let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr); + + union_field_at_index( + env, + layout_interner, + field_layouts, + None, + *index as usize, + ptr, + target_loaded_type, + ) + } + UnionLayout::NonNullableUnwrapped(field_layouts) => { + let struct_layout = LayoutRepr::struct_(field_layouts); + + let struct_type = basic_type_from_layout(env, layout_interner, struct_layout); + let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr); + + union_field_at_index( + env, + layout_interner, + field_layouts, + Some(struct_type.into_struct_type()), + *index as usize, + argument.into_pointer_value(), + target_loaded_type, + ) + } + UnionLayout::NullableWrapped { + nullable_id, + other_tags, + } => { + debug_assert!(argument.is_pointer_value()); + debug_assert_ne!(*tag_id, *nullable_id); + + let tag_index = if *tag_id < *nullable_id { + 
*tag_id + } else { + tag_id - 1 + }; + + let field_layouts = other_tags[tag_index as usize]; + + let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value()); + let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr); + + union_field_at_index( + env, + layout_interner, + field_layouts, + None, + *index as usize, + ptr, + target_loaded_type, + ) + .into() + } + UnionLayout::NullableUnwrapped { + nullable_id, + other_fields, + } => { + debug_assert!(argument.is_pointer_value()); + debug_assert_ne!(*tag_id != 0, *nullable_id); + + let field_layouts = other_fields; + let struct_layout = LayoutRepr::struct_(field_layouts); + + let struct_type = basic_type_from_layout(env, layout_interner, struct_layout); + let target_loaded_type = basic_type_from_layout(env, layout_interner, ret_repr); + + union_field_at_index( + env, + layout_interner, + field_layouts, + Some(struct_type.into_struct_type()), + // the tag id is not stored + *index as usize, + argument.into_pointer_value(), + target_loaded_type, + ) + } + }; + + pointer_value.into() + } + GetTagId { structure, union_layout, @@ -2025,21 +2131,20 @@ fn lookup_at_index_ptr<'a, 'ctx>( field_layouts: &[InLayout<'a>], index: usize, value: PointerValue<'ctx>, - struct_type: StructType<'ctx>, + struct_type: Option>, target_loaded_type: BasicTypeEnum<'ctx>, ) -> BasicValueEnum<'ctx> { let builder = env.builder; - let ptr = env.builder.build_pointer_cast( + let elem_ptr = union_field_at_index_help( + env, + layout_interner, + field_layouts, + struct_type, + index, value, - struct_type.ptr_type(AddressSpace::default()), - "cast_lookup_at_index_ptr", ); - let elem_ptr = builder - .new_build_struct_gep(struct_type, ptr, index as u32, "at_index_struct_gep") - .unwrap(); - let field_layout = field_layouts[index]; let result = load_roc_value( env, @@ -2054,19 +2159,23 @@ fn lookup_at_index_ptr<'a, 'ctx>( cast_if_necessary_for_opaque_recursive_pointers(env.builder, result, target_loaded_type) } -fn 
lookup_at_index_ptr2<'a, 'ctx>( +fn union_field_at_index_help<'a, 'ctx>( env: &Env<'a, 'ctx, '_>, layout_interner: &STLayoutInterner<'a>, field_layouts: &'a [InLayout<'a>], + opt_struct_type: Option>, index: usize, value: PointerValue<'ctx>, - target_loaded_type: BasicTypeEnum<'ctx>, -) -> BasicValueEnum<'ctx> { +) -> PointerValue<'ctx> { let builder = env.builder; - let struct_layout = LayoutRepr::struct_(field_layouts); - let struct_type = - basic_type_from_layout(env, layout_interner, struct_layout).into_struct_type(); + let struct_type = match opt_struct_type { + Some(st) => st, + None => { + let struct_layout = LayoutRepr::struct_(field_layouts); + basic_type_from_layout(env, layout_interner, struct_layout).into_struct_type() + } + }; let data_ptr = env.builder.build_pointer_cast( value, @@ -2074,27 +2183,40 @@ fn lookup_at_index_ptr2<'a, 'ctx>( "cast_lookup_at_index_ptr", ); - let elem_ptr = builder + builder .new_build_struct_gep( struct_type, data_ptr, index as u32, "at_index_struct_gep_data", ) - .unwrap(); + .unwrap() +} - let field_layout = field_layouts[index]; - let result = load_roc_value( +fn union_field_at_index<'a, 'ctx>( + env: &Env<'a, 'ctx, '_>, + layout_interner: &STLayoutInterner<'a>, + field_layouts: &'a [InLayout<'a>], + opt_struct_type: Option>, + index: usize, + value: PointerValue<'ctx>, + target_loaded_type: BasicTypeEnum<'ctx>, +) -> PointerValue<'ctx> { + let result = union_field_at_index_help( env, layout_interner, - layout_interner.get_repr(field_layout), - elem_ptr, - "load_at_index_ptr", + field_layouts, + opt_struct_type, + index, + value, ); // A recursive pointer in the loaded structure is stored as a `i64*`, but the loaded layout // might want a more precise structure. As such, cast it to the refined type if needed. 
- cast_if_necessary_for_opaque_recursive_pointers(env.builder, result, target_loaded_type) + let from_value: BasicValueEnum = result.into(); + let to_type: BasicTypeEnum = target_loaded_type; + cast_if_necessary_for_opaque_recursive_pointers(env.builder, from_value, to_type) + .into_pointer_value() } pub fn reserve_with_refcount<'a, 'ctx>( @@ -3071,7 +3193,7 @@ pub fn cast_if_necessary_for_opaque_recursive_pointers<'ctx>( to_type: BasicTypeEnum<'ctx>, ) -> BasicValueEnum<'ctx> { if from_value.get_type() != to_type - // Only perform the cast if the target types are transumatble. + // Only perform the cast if the target types are transmutable. && equivalent_type_constructors(&from_value.get_type(), &to_type) { complex_bitcast( diff --git a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs index 4bfb429ce9..cdc3af28cb 100644 --- a/crates/compiler/gen_llvm/src/llvm/lowlevel.rs +++ b/crates/compiler/gen_llvm/src/llvm/lowlevel.rs @@ -1304,8 +1304,28 @@ pub(crate) fn run_low_level<'a, 'ctx>( .into() } - PtrStore | PtrLoad | PtrToZeroed | RefCountIncRcPtr | RefCountDecRcPtr - | RefCountIncDataPtr | RefCountDecDataPtr => { + PtrStore => { + arguments!(ptr, value); + + env.builder.build_store(ptr.into_pointer_value(), value); + + // ptr + env.context.struct_type(&[], false).const_zero().into() + } + + PtrLoad => { + arguments!(ptr); + + let ret_repr = layout_interner.get_repr(layout); + let element_type = basic_type_from_layout(env, layout_interner, ret_repr); + + env.builder + .new_build_load(element_type, ptr.into_pointer_value(), "ptr_load") + } + + PtrToZeroed => todo!(), + + RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr => { unreachable!("Not used in LLVM backend: {:?}", op); } diff --git a/crates/compiler/gen_wasm/src/backend.rs b/crates/compiler/gen_wasm/src/backend.rs index 0476c9d19a..dde79f02bd 100644 --- a/crates/compiler/gen_wasm/src/backend.rs +++ 
b/crates/compiler/gen_wasm/src/backend.rs @@ -1079,6 +1079,13 @@ impl<'a, 'r> WasmBackend<'a, 'r> { index, } => self.expr_union_at_index(*structure, *tag_id, union_layout, *index, sym), + Expr::UnionFieldPtrAtIndex { + structure, + tag_id, + union_layout, + index, + } => todo!(), + Expr::ExprBox { symbol: arg_sym } => self.expr_box(sym, *arg_sym, layout, storage), Expr::ExprUnbox { symbol: arg_sym } => self.expr_unbox(sym, *arg_sym), diff --git a/crates/compiler/load_internal/src/file.rs b/crates/compiler/load_internal/src/file.rs index 5cef4dd2ce..86d9805786 100644 --- a/crates/compiler/load_internal/src/file.rs +++ b/crates/compiler/load_internal/src/file.rs @@ -16,7 +16,7 @@ use roc_can::module::{ }; use roc_collections::{default_hasher, BumpMap, MutMap, MutSet, VecMap, VecSet}; use roc_constrain::module::constrain_module; -use roc_debug_flags::dbg_do; +use roc_debug_flags::{dbg_do, ROC_PRINT_IR_AFTER_TRMC}; #[cfg(debug_assertions)] use roc_debug_flags::{ ROC_CHECK_MONO_IR, ROC_PRINT_IR_AFTER_DROP_SPECIALIZATION, ROC_PRINT_IR_AFTER_REFCOUNT, @@ -3104,6 +3104,16 @@ fn update<'a>( let ident_ids = state.constrained_ident_ids.get_mut(&module_id).unwrap(); + roc_mono::tail_recursion::apply_trmc( + arena, + &mut layout_interner, + module_id, + ident_ids, + &mut state.procedures, + ); + + debug_print_ir!(state, &layout_interner, ROC_PRINT_IR_AFTER_TRMC); + inc_dec::insert_inc_dec_operations( arena, &layout_interner, diff --git a/crates/compiler/module/src/low_level.rs b/crates/compiler/module/src/low_level.rs index 32fb65ed6d..76f02bc51c 100644 --- a/crates/compiler/module/src/low_level.rs +++ b/crates/compiler/module/src/low_level.rs @@ -230,7 +230,9 @@ macro_rules! 
map_symbol_to_lowlevel { // these are used internally and not tied to a symbol LowLevel::Hash => unimplemented!(), LowLevel::PtrCast => unimplemented!(), - LowLevel::PtrWrite => unimplemented!(), + LowLevel::PtrStore => unimplemented!(), + LowLevel::PtrLoad => unimplemented!(), + LowLevel::PtrToZeroed => unimplemented!(), LowLevel::RefCountIncRcPtr => unimplemented!(), LowLevel::RefCountDecRcPtr=> unimplemented!(), LowLevel::RefCountIncDataPtr => unimplemented!(), diff --git a/crates/compiler/mono/src/borrow.rs b/crates/compiler/mono/src/borrow.rs index 6978ccae3e..b71ce34ae9 100644 --- a/crates/compiler/mono/src/borrow.rs +++ b/crates/compiler/mono/src/borrow.rs @@ -741,6 +741,14 @@ impl<'a> BorrowInfState<'a> { self.if_is_owned_then_own(z, *x); } + UnionFieldPtrAtIndex { structure: x, .. } => { + // if the structure (record/tag/array) is owned, the extracted value is + self.if_is_owned_then_own(*x, z); + + // if the extracted value is owned, the structure must be too + self.if_is_owned_then_own(z, *x); + } + GetTagId { structure: x, .. 
} => { // if the structure (record/tag/array) is owned, the extracted value is self.if_is_owned_then_own(*x, z); @@ -1035,7 +1043,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[Ownership] { unreachable!("These lowlevel operations are turned into mono Expr's") } - PtrWrite => arena.alloc_slice_copy(&[irrelevant, irrelevant]), + PtrStore => arena.alloc_slice_copy(&[owned, borrowed]), + PtrLoad => arena.alloc_slice_copy(&[owned]), + PtrToZeroed => arena.alloc_slice_copy(&[owned]), PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr | RefCountIsUnique => { diff --git a/crates/compiler/mono/src/debug/checker.rs b/crates/compiler/mono/src/debug/checker.rs index 9362121262..d6f872ffc1 100644 --- a/crates/compiler/mono/src/debug/checker.rs +++ b/crates/compiler/mono/src/debug/checker.rs @@ -429,6 +429,15 @@ impl<'a, 'r> Ctx<'a, 'r> { } => self.with_sym_layout(structure, |ctx, _def_line, layout| { ctx.check_union_at_index(structure, layout, union_layout, tag_id, index) }), + &Expr::UnionFieldPtrAtIndex { + structure, + tag_id, + union_layout, + index, + } => self.with_sym_layout(structure, |ctx, _def_line, layout| { + // TODO: I suspect this will fail because the output layout has an extra Box layer? + ctx.check_union_at_index(structure, layout, union_layout, tag_id, index) + }), Expr::Array { elem_layout, elems } => { for elem in elems.iter() { match elem { diff --git a/crates/compiler/mono/src/drop_specialization.rs b/crates/compiler/mono/src/drop_specialization.rs index 82a4978bd5..a8aa12247f 100644 --- a/crates/compiler/mono/src/drop_specialization.rs +++ b/crates/compiler/mono/src/drop_specialization.rs @@ -211,7 +211,20 @@ fn specialize_drops_stmt<'a, 'i>( // TODO perhaps we need the union_layout later as well? if so, create a new function/map to store it. environment.add_union_child(*structure, *binding, *tag_id, *index); // Generated code might know the tag of the union without switching on it. 
- // So if we unionAtIndex, we must know the tag and we can use it to specialize the drop. + // So if we UnionAtIndex, we must know the tag and we can use it to specialize the drop. + environment.symbol_tag.insert(*structure, *tag_id); + alloc_let_with_continuation!(environment) + } + Expr::UnionFieldPtrAtIndex { + structure, + tag_id, + union_layout: _, + index, + } => { + // TODO perhaps we need the union_layout later as well? if so, create a new function/map to store it. + environment.add_union_child(*structure, *binding, *tag_id, *index); + // Generated code might know the tag of the union without switching on it. + // So if we UnionFieldPtrAtIndex, we must know the tag and we can use it to specialize the drop. environment.symbol_tag.insert(*structure, *tag_id); alloc_let_with_continuation!(environment) } @@ -1666,8 +1679,13 @@ fn low_level_no_rc(lowlevel: &LowLevel) -> RC { unreachable!("These lowlevel operations are turned into mono Expr's") } - PtrCast | PtrWrite | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr - | RefCountDecDataPtr | RefCountIsUnique => { + // only inserted for internal purposes. RC should not touch it + PtrStore => RC::NoRc, + PtrLoad => RC::NoRc, + PtrToZeroed => RC::NoRc, + + PtrCast | RefCountIncRcPtr | RefCountDecRcPtr | RefCountIncDataPtr | RefCountDecDataPtr + | RefCountIsUnique => { unreachable!("Only inserted *after* borrow checking: {:?}", lowlevel); } } diff --git a/crates/compiler/mono/src/inc_dec.rs b/crates/compiler/mono/src/inc_dec.rs index e03f9a8ea7..e02b53e1f2 100644 --- a/crates/compiler/mono/src/inc_dec.rs +++ b/crates/compiler/mono/src/inc_dec.rs @@ -345,17 +345,16 @@ impl<'v> RefcountEnvironment<'v> { // A groupby or something similar would be nice here. 
let mut symbol_usage = MutMap::default(); for symbol in symbols { - match { - self.symbols_rc_types - .get(&symbol) - .expect("Expected symbol to be in the map") - } { + match self.symbols_rc_types.get(&symbol) { // If the symbol is reference counted, we need to increment the usage count. - VarRcType::ReferenceCounted => { + Some(VarRcType::ReferenceCounted) => { *symbol_usage.entry(symbol).or_default() += 1; } // If the symbol is not reference counted, we don't need to do anything. - VarRcType::NotReferenceCounted => continue, + Some(VarRcType::NotReferenceCounted) => continue, + None => { + internal_error!("symbol {symbol:?} does not have an rc type") + } } } symbol_usage @@ -891,6 +890,7 @@ fn insert_refcount_operations_binding<'a>( Expr::GetTagId { structure, .. } | Expr::StructAtIndex { structure, .. } | Expr::UnionAtIndex { structure, .. } + | Expr::UnionFieldPtrAtIndex { structure, .. } | Expr::ExprUnbox { symbol: structure } => { // All structures are alive at this point and don't have to be copied in order to take an index out/get tag id/copy values to the stack. // But we do want to make sure to decrement this item if it is the last reference. @@ -904,6 +904,7 @@ fn insert_refcount_operations_binding<'a>( match expr { Expr::StructAtIndex { .. } | Expr::UnionAtIndex { .. } + | Expr::UnionFieldPtrAtIndex { .. } | Expr::ExprUnbox { .. } => insert_inc_stmt(arena, *binding, 1, new_stmt), // No usage of an element of a reference counted symbol. No need to increment. Expr::GetTagId { .. 
} => new_stmt, diff --git a/crates/compiler/mono/src/ir.rs b/crates/compiler/mono/src/ir.rs index c8553bdaa3..39a25f6815 100644 --- a/crates/compiler/mono/src/ir.rs +++ b/crates/compiler/mono/src/ir.rs @@ -403,41 +403,6 @@ impl<'a> Proc<'a> { w.push(b'\n'); String::from_utf8(w).unwrap() } - - fn make_tail_recursive(&mut self, interner: &mut I, env: &mut Env<'a, '_>) - where - I: LayoutInterner<'a>, - { - let mut args = Vec::with_capacity_in(self.args.len(), env.arena); - let mut proc_args = Vec::with_capacity_in(self.args.len(), env.arena); - - for (layout, symbol) in self.args { - let new = env.unique_symbol(); - args.push((*layout, *symbol, new)); - proc_args.push((*layout, new)); - } - - use self::SelfRecursive::*; - if let SelfRecursive(id) = self.is_self_recursive { - if crate::tail_recursion::is_trmc_candidate(interner, self) { - *self = crate::tail_recursion::TrmcEnv::init(env, interner, self); - } else { - let transformed = crate::tail_recursion::make_tail_recursive( - env.arena, - id, - self.name, - self.body.clone(), - args.into_bump_slice(), - self.ret_layout, - ); - - if let Some(with_tco) = transformed { - self.body = with_tco; - self.args = proc_args.into_bump_slice(); - } - } - } - } } /// A host-exposed function must be specialized; it's a seed for subsequent specializations @@ -1032,7 +997,7 @@ impl<'a> Procs<'a> { MutMap::with_capacity_and_hasher(self.specialized.len(), default_hasher()); for (symbol, layout, mut proc) in self.specialized.into_iter_assert_done() { - proc.make_tail_recursive(&mut layout_cache.interner, env); + // proc.make_tail_recursive(&mut layout_cache.interner, env); let key = (symbol, layout); specialized_procs.insert(key, proc); @@ -1888,6 +1853,12 @@ pub enum Expr<'a> { union_layout: UnionLayout<'a>, index: u64, }, + UnionFieldPtrAtIndex { + structure: Symbol, + tag_id: TagIdIntType, + union_layout: UnionLayout<'a>, + index: u64, + }, Array { elem_layout: InLayout<'a>, @@ -2105,6 +2076,19 @@ impl<'a> Expr<'a> { .. 
} => text!(alloc, "UnionAtIndex (Id {}) (Index {}) ", tag_id, index) .append(symbol_to_doc(alloc, *structure, pretty)), + + UnionFieldPtrAtIndex { + tag_id, + structure, + index, + .. + } => text!( + alloc, + "UnionFieldPtrAtIndex (Id {}) (Index {}) ", + tag_id, + index + ) + .append(symbol_to_doc(alloc, *structure, pretty)), } } @@ -7678,6 +7662,21 @@ fn substitute_in_expr<'a>( }), None => None, }, + + UnionFieldPtrAtIndex { + structure, + tag_id, + index, + union_layout, + } => match substitute(subs, *structure) { + Some(structure) => Some(UnionFieldPtrAtIndex { + structure, + tag_id: *tag_id, + index: *index, + union_layout: *union_layout, + }), + None => None, + }, } } diff --git a/crates/compiler/mono/src/tail_recursion.rs b/crates/compiler/mono/src/tail_recursion.rs index 77db249119..5213181b6f 100644 --- a/crates/compiler/mono/src/tail_recursion.rs +++ b/crates/compiler/mono/src/tail_recursion.rs @@ -1,11 +1,88 @@ #![allow(clippy::manual_map)] use crate::borrow::Ownership; -use crate::ir::{Call, CallType, Env, Expr, JoinPointId, Param, Proc, SelfRecursive, Stmt}; -use crate::layout::{InLayout, LambdaName, LayoutInterner, LayoutRepr, TagIdIntType, UnionLayout}; +use crate::ir::{ + Call, CallType, Expr, JoinPointId, Param, Proc, ProcLayout, SelfRecursive, Stmt, UpdateModeId, +}; +use crate::layout::{ + InLayout, LambdaName, Layout, LayoutInterner, LayoutRepr, STLayoutInterner, TagIdIntType, + UnionLayout, +}; use bumpalo::collections::Vec; use bumpalo::Bump; -use roc_module::symbol::Symbol; +use roc_collections::MutMap; +use roc_module::low_level::LowLevel; +use roc_module::symbol::{IdentIds, ModuleId, Symbol}; + +pub struct Env<'a, 'i> { + arena: &'a Bump, + home: ModuleId, + interner: &'i mut STLayoutInterner<'a>, + ident_ids: &'i mut IdentIds, +} + +impl<'a, 'i> Env<'a, 'i> { + pub fn unique_symbol(&mut self) -> Symbol { + let ident_id = self.ident_ids.gen_unique(); + + Symbol::new(self.home, ident_id) + } + + pub fn named_unique_symbol(&mut self, name: 
&str) -> Symbol { + let ident_id = self.ident_ids.add_str(name); + Symbol::new(self.home, ident_id) + } +} + +pub fn apply_trmc<'a, 'i>( + arena: &'a Bump, + interner: &'i mut STLayoutInterner<'a>, + home: ModuleId, + ident_ids: &'i mut IdentIds, + procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, +) { + let mut env = Env { + arena, + interner, + home, + ident_ids, + }; + + let env = &mut env; + + for (_, proc) in procs { + use self::SelfRecursive::*; + if let SelfRecursive(id) = proc.is_self_recursive { + if crate::tail_recursion::is_trmc_candidate(env.interner, proc) { + let new_proc = crate::tail_recursion::TrmcEnv::init(env, proc); + *proc = new_proc; + } else { + let mut args = Vec::with_capacity_in(proc.args.len(), arena); + let mut proc_args = Vec::with_capacity_in(proc.args.len(), arena); + + for (layout, symbol) in proc.args { + let new = env.unique_symbol(); + args.push((*layout, *symbol, new)); + proc_args.push((*layout, new)); + } + + let transformed = crate::tail_recursion::make_tail_recursive( + arena, + id, + proc.name, + proc.body.clone(), + args.into_bump_slice(), + proc.ret_layout, + ); + + if let Some(with_tco) = transformed { + proc.body = with_tco; + proc.args = proc_args.into_bump_slice(); + } + } + } + } +} /// Make tail calls into loops (using join points) /// @@ -325,7 +402,7 @@ fn insert_jumps<'a>( } } -pub(crate) fn is_trmc_candidate<'a, I>(interner: &I, proc: &Proc<'a>) -> bool +pub(crate) fn is_trmc_candidate<'a, I>(interner: &'_ I, proc: &'_ Proc<'a>) -> bool where I: LayoutInterner<'a>, { @@ -338,10 +415,68 @@ where } // and return a recursive tag union - match interner.get_repr(proc.ret_layout) { - LayoutRepr::Union(union_layout) => union_layout.is_recursive(), - _ => false, + if !matches!(interner.get_repr(proc.ret_layout), LayoutRepr::Union(union_layout) if union_layout.is_recursive()) + { + return false; } + + has_cons_in_tail_position(&proc.body, proc.name) +} + +fn has_cons_in_tail_position(initial_stmt: &Stmt<'_>, 
function_name: LambdaName) -> bool { + // we are looking for code of the form + // + // let x = Tag a b c + // ret x + + let mut stack = vec![(None, initial_stmt)]; + + while let Some((recursive_call, stmt)) = stack.pop() { + match stmt { + Stmt::Let(symbol, expr, _, next) => { + if let Some(cons_info) = TrmcEnv::is_terminal_constructor(stmt) { + // must use the result of a recursive call directly as an argument + if let Some(recursive_call) = recursive_call { + if cons_info.arguments.contains(&recursive_call) { + return true; + } + } + } + + let recursive_call = recursive_call + .or_else(|| TrmcEnv::is_recursive_expr(expr, function_name).map(|_| *symbol)); + + stack.push((recursive_call, next)); + } + Stmt::Switch { + branches, + default_branch, + .. + } => { + for (_, _, stmt) in branches.iter() { + stack.push((recursive_call, stmt)); + } + stack.push((recursive_call, default_branch.1)); + } + Stmt::Refcounting(_, next) => { + stack.push((recursive_call, next)); + } + Stmt::Expect { remainder, .. } + | Stmt::ExpectFx { remainder, .. } + | Stmt::Dbg { remainder, .. } => { + stack.push((recursive_call, remainder)); + } + Stmt::Join { + body, remainder, .. 
+ } => { + stack.push((recursive_call, body)); + stack.push((recursive_call, remainder)); + } + Stmt::Ret(_) | Stmt::Jump(_, _) | Stmt::Crash(_, _) => { /* terminal */ } + } + } + + false } #[derive(Clone)] @@ -358,6 +493,7 @@ pub(crate) struct TrmcEnv<'a> { recursive_call: Option<(Symbol, Call<'a>)>, } +#[derive(Debug)] struct ConstructorInfo<'a> { tag_layout: UnionLayout<'a>, tag_id: TagIdIntType, @@ -365,25 +501,18 @@ struct ConstructorInfo<'a> { } impl<'a> TrmcEnv<'a> { - fn is_recursive_expr(&mut self, expr: &Expr<'a>) -> Option> { - if let Expr::Call(call) = expr { - self.is_recursive_call(call).then_some(call.clone()) - } else { - None - } - } - - fn is_terminal_constructor(&mut self, stmt: &Stmt<'a>) -> Option> { + #[inline(always)] + fn is_terminal_constructor(stmt: &Stmt<'a>) -> Option> { match stmt { Stmt::Let(s1, expr, _layout, Stmt::Ret(s2)) if s1 == s2 => { - self.get_contructor_info(expr) + Self::get_contructor_info(expr) } _ => None, } } - fn get_contructor_info(&mut self, expr: &Expr<'a>) -> Option> { + fn get_contructor_info(expr: &Expr<'a>) -> Option> { if let Expr::Tag { tag_layout, tag_id, @@ -402,16 +531,19 @@ impl<'a> TrmcEnv<'a> { } } - fn is_recursive_call(&mut self, call: &Call<'a>) -> bool { + fn is_recursive_expr(expr: &Expr<'a>, lambda_name: LambdaName<'_>) -> Option> { + if let Expr::Call(call) = expr { + Self::is_recursive_call(call, lambda_name).then_some(call.clone()) + } else { + None + } + } + + fn is_recursive_call(call: &Call<'a>, lambda_name: LambdaName<'_>) -> bool { match call.call_type { - CallType::ByName { - name, - ret_layout, - arg_layouts, - specialization_id, - } => { + CallType::ByName { name, .. } => { // TODO are there other restrictions? - name == self.function_name + name == lambda_name } CallType::Foreign { .. } | CallType::LowLevel { .. 
} | CallType::HigherOrder(_) => { false @@ -421,16 +553,16 @@ impl<'a> TrmcEnv<'a> { fn ptr_write( env: &mut Env<'a, '_>, - interner: &mut impl LayoutInterner<'a>, - return_layout: InLayout<'a>, + _return_layout: InLayout<'a>, ptr: Symbol, value: Symbol, next: &'a Stmt<'a>, ) -> Stmt<'a> { let box_write = Call { call_type: crate::ir::CallType::LowLevel { - op: roc_module::low_level::LowLevel::PtrWrite, - update_mode: env.next_update_mode_id(), + op: LowLevel::PtrStore, + // update_mode: env.next_update_mode_id(), + update_mode: UpdateModeId::BACKEND_DUMMY, }, arguments: env.arena.alloc([ptr, value]), }; @@ -438,16 +570,13 @@ impl<'a> TrmcEnv<'a> { Stmt::Let( env.named_unique_symbol("_ptr_write_unit"), Expr::Call(box_write), - interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)), + // interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)), + Layout::UNIT, next, ) } - pub fn init( - env: &mut Env<'a, '_>, - interner: &mut impl LayoutInterner<'a>, - proc: &Proc<'a>, - ) -> Proc<'a> { + pub fn init<'i>(env: &mut Env<'a, 'i>, proc: &Proc<'a>) -> Proc<'a> { let arena = env.arena; let return_layout = proc.ret_layout; @@ -475,8 +604,9 @@ impl<'a> TrmcEnv<'a> { let null_symbol = env.named_unique_symbol("null"); let let_null = |next| Stmt::Let(null_symbol, Expr::NullPointer, return_layout, next); - let box_return_layout = - interner.insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)); + let box_return_layout = env + .interner + .insert_direct_no_semantic(LayoutRepr::Boxed(return_layout)); let box_null = Expr::ExprBox { symbol: null_symbol, }; @@ -508,7 +638,7 @@ impl<'a> TrmcEnv<'a> { let joinpoint = Stmt::Join { id: joinpoint_id, parameters: joinpoint_parameters.into_bump_slice(), - body: arena.alloc(this.walk_stmt(env, interner, &proc.body)), + body: arena.alloc(this.walk_stmt(env, &proc.body)), remainder: arena.alloc(jump_stmt), }; @@ -534,24 +664,19 @@ impl<'a> TrmcEnv<'a> { } } - fn walk_stmt( - &mut self, - env: &mut Env<'a, '_>, - 
interner: &mut impl LayoutInterner<'a>, - stmt: &Stmt<'a>, - ) -> Stmt<'a> { + fn walk_stmt(&mut self, env: &mut Env<'a, '_>, stmt: &Stmt<'a>) -> Stmt<'a> { let arena = env.arena; match stmt { Stmt::Let(symbol, expr, layout, next) => { if self.recursive_call.is_none() { - if let Some(call) = self.is_recursive_expr(expr) { + if let Some(call) = Self::is_recursive_expr(expr, self.function_name) { self.recursive_call = Some((*symbol, call)); - return self.walk_stmt(env, interner, next); + return self.walk_stmt(env, next); } } - if let Some(cons_info) = self.is_terminal_constructor(stmt) { + if let Some(cons_info) = Self::is_terminal_constructor(stmt) { match &self.recursive_call { None => { // this control flow path did not encounter a recursive call. Just @@ -561,7 +686,7 @@ impl<'a> TrmcEnv<'a> { let output = define_tag(arena.alloc( // - self.non_trmc_return(env, interner, *symbol), + self.non_trmc_return(env, *symbol), )); return output; @@ -571,11 +696,21 @@ impl<'a> TrmcEnv<'a> { // branch. // TODO remove unwrap. also what if the symbol occurs more than once? 
- let recursive_field_index = cons_info - .arguments - .iter() - .position(|s| *s == *call_symbol) - .unwrap(); + let opt_recursive_field_index = + cons_info.arguments.iter().position(|s| *s == *call_symbol); + + let recursive_field_index = match opt_recursive_field_index { + None => { + let next = self.walk_stmt(env, next); + return Stmt::Let( + *symbol, + expr.clone(), + *layout, + arena.alloc(next), + ); + } + Some(v) => v, + }; let mut arguments = Vec::from_iter_in(cons_info.arguments.iter().copied(), env.arena); @@ -589,8 +724,11 @@ impl<'a> TrmcEnv<'a> { let let_tag = |next| Stmt::Let(*symbol, tag_expr, *layout, next); - let get_reference_expr = Expr::ExprBox { - symbol: self.null_symbol, + let get_reference_expr = Expr::UnionFieldPtrAtIndex { + structure: *symbol, + tag_id: cons_info.tag_id, + union_layout: cons_info.tag_layout, + index: recursive_field_index as _, }; let new_hole_symbol = env.named_unique_symbol("newHole"); @@ -616,7 +754,6 @@ impl<'a> TrmcEnv<'a> { // Self::ptr_write( env, - interner, *layout, self.hole_symbol, *symbol, @@ -630,7 +767,7 @@ impl<'a> TrmcEnv<'a> { } } - let next = self.walk_stmt(env, interner, next); + let next = self.walk_stmt(env, next); Stmt::Let(*symbol, expr.clone(), *layout, arena.alloc(next)) } Stmt::Switch { @@ -646,14 +783,13 @@ impl<'a> TrmcEnv<'a> { for (id, info, stmt) in branches.iter() { self.recursive_call = opt_recursive_call.clone(); - let new_stmt = self.walk_stmt(env, interner, stmt); + let new_stmt = self.walk_stmt(env, stmt); new_branches.push((*id, info.clone(), new_stmt)); } self.recursive_call = opt_recursive_call; - let new_default_branch = - &*arena.alloc(self.walk_stmt(env, interner, default_branch.1)); + let new_default_branch = &*arena.alloc(self.walk_stmt(env, default_branch.1)); Stmt::Switch { cond_symbol: *cond_symbol, @@ -666,42 +802,92 @@ impl<'a> TrmcEnv<'a> { Stmt::Ret(symbol) => { // write the symbol we're supposed to return into the hole // then read initial_symbol and return its 
contents - self.non_trmc_return(env, interner, *symbol) + self.non_trmc_return(env, *symbol) } - Stmt::Refcounting(_, _) => todo!(), - Stmt::Expect { .. } => todo!(), - Stmt::ExpectFx { .. } => todo!(), - Stmt::Dbg { .. } => todo!(), - Stmt::Join { .. } => todo!(), - Stmt::Jump(_, _) => todo!(), - Stmt::Crash(_, _) => todo!(), + Stmt::Refcounting(op, next) => { + let new_next = self.walk_stmt(env, next); + Stmt::Refcounting(*op, arena.alloc(new_next)) + } + Stmt::Expect { + condition, + region, + lookups, + variables, + remainder, + } => Stmt::Expect { + condition: *condition, + region: *region, + lookups, + variables, + remainder: arena.alloc(self.walk_stmt(env, remainder)), + }, + Stmt::ExpectFx { + condition, + region, + lookups, + variables, + remainder, + } => Stmt::ExpectFx { /* rebuild the same variant that was matched; only `remainder` is rewritten */ + condition: *condition, + region: *region, + lookups, + variables, + remainder: arena.alloc(self.walk_stmt(env, remainder)), + }, + Stmt::Dbg { + symbol, + variable, + remainder, + } => Stmt::Dbg { + symbol: *symbol, + variable: *variable, + remainder: arena.alloc(self.walk_stmt(env, remainder)), + }, + Stmt::Join { + id, + parameters, + body, + remainder, + } => { + let new_body = self.walk_stmt(env, body); + let new_remainder = self.walk_stmt(env, remainder); + + Stmt::Join { + id: *id, + parameters, + body: arena.alloc(new_body), + remainder: arena.alloc(new_remainder), + } + } + Stmt::Jump(id, arguments) => Stmt::Jump(*id, arguments), + Stmt::Crash(symbol, crash_tag) => Stmt::Crash(*symbol, *crash_tag), } } - fn non_trmc_return( - &mut self, - env: &mut Env<'a, '_>, - interner: &mut impl LayoutInterner<'a>, - value_symbol: Symbol, - ) -> Stmt<'a> { + fn non_trmc_return(&mut self, env: &mut Env<'a, '_>, value_symbol: Symbol) -> Stmt<'a> { let arena = env.arena; let layout = self.return_layout; - let unbox_expr = Expr::ExprUnbox { - symbol: self.initial_box_symbol, - }; let final_symbol = env.named_unique_symbol("final"); - let unbox = |next| Stmt::Let(final_symbol, unbox_expr, layout,
next); + + let call = Call { + call_type: CallType::LowLevel { + op: LowLevel::PtrLoad, + update_mode: UpdateModeId::BACKEND_DUMMY, + }, + arguments: &*arena.alloc([self.initial_box_symbol]), + }; + + let ptr_load = |next| Stmt::Let(final_symbol, Expr::Call(call), layout, next); Self::ptr_write( env, - interner, layout, self.hole_symbol, value_symbol, arena.alloc( // - unbox(arena.alloc(Stmt::Ret(final_symbol))), + ptr_load(arena.alloc(Stmt::Ret(final_symbol))), ), ) } diff --git a/examples/platform-switching/rocLovesZig.roc b/examples/platform-switching/rocLovesZig.roc index fe838c5396..39bab4d986 100644 --- a/examples/platform-switching/rocLovesZig.roc +++ b/examples/platform-switching/rocLovesZig.roc @@ -3,4 +3,29 @@ app "rocLovesZig" imports [] provides [main] to pf -main = "Roc <3 Zig!\n" +LinkedList a : [Nil, Cons a (LinkedList a)] + +map : LinkedList a, (a -> b) -> LinkedList b +map = \list, f -> + when list is + Nil -> Nil + Cons x xs -> Cons (f x) (map xs f) + +unfold : a, Nat -> LinkedList a +unfold = \value, n -> + when n is + 0 -> Nil + _ -> Cons value (unfold value (n - 1)) + +length : LinkedList a -> I64 +length = \list -> + when list is + Nil -> 0 + Cons _ rest -> 1 + length rest + +main : Str +main = + unfold 42 5 + |> map (\x -> x + 1i64) + |> length + |> Num.toStr