diff --git a/compiler/gen_llvm/src/llvm/build.rs b/compiler/gen_llvm/src/llvm/build.rs index 201acf20d6..a6707a4884 100644 --- a/compiler/gen_llvm/src/llvm/build.rs +++ b/compiler/gen_llvm/src/llvm/build.rs @@ -22,7 +22,7 @@ use crate::llvm::convert::{ basic_type_from_builtin, basic_type_from_layout, block_of_memory_slices, ptr_int, }; use crate::llvm::refcounting::{ - decrement_refcount_layout, increment_refcount_layout, PointerToRefcount, + build_reset, decrement_refcount_layout, increment_refcount_layout, PointerToRefcount, }; use bumpalo::collections::Vec; use bumpalo::Bump; @@ -893,7 +893,7 @@ pub fn build_exp_call<'a, 'ctx, 'env>( } } -pub const TAG_ID_INDEX: u32 = 1; +const TAG_ID_INDEX: u32 = 1; pub const TAG_DATA_INDEX: u32 = 0; pub fn struct_from_fields<'a, 'ctx, 'env, I>( @@ -919,6 +919,34 @@ where struct_value.into_struct_value() } +fn struct_pointer_from_fields<'a, 'ctx, 'env, I>( + env: &Env<'a, 'ctx, 'env>, + struct_type: StructType<'ctx>, + input_pointer: PointerValue<'ctx>, + values: I, +) where + I: Iterator)>, +{ + let struct_ptr = env + .builder + .build_bitcast( + input_pointer, + struct_type.ptr_type(AddressSpace::Generic), + "struct_ptr", + ) + .into_pointer_value(); + + // Insert field exprs into struct_val + for (index, field_val) in values { + let field_ptr = env + .builder + .build_struct_gep(struct_ptr, index as u32, "field_struct_gep") + .unwrap(); + + env.builder.build_store(field_ptr, field_val); + } +} + pub fn build_exp_expr<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, layout_ids: &mut LayoutIds<'a>, @@ -969,16 +997,87 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( struct_from_fields(env, struct_type, field_vals.into_iter().enumerate()).into() } + Reuse { + arguments, + tag_layout: union_layout, + tag_id, + symbol, + .. 
+ } => { + let reset = load_symbol(scope, symbol).into_pointer_value(); + build_tag( + env, + scope, + union_layout, + *tag_id, + arguments, + Some(reset), + parent, + ) + } + Tag { arguments, tag_layout: union_layout, - union_size, tag_id, .. - } => build_tag(env, scope, union_layout, *union_size, *tag_id, arguments), + } => build_tag(env, scope, union_layout, *tag_id, arguments, None, parent), - Reset(_) => todo!(), - Reuse { .. } => todo!(), + Reset(symbol) => { + let (tag_ptr, layout) = load_symbol_and_layout(scope, symbol); + let tag_ptr = tag_ptr.into_pointer_value(); + + // reset is only generated for union values + let union_layout = match layout { + Layout::Union(ul) => ul, + _ => unreachable!(), + }; + + let ctx = env.context; + let then_block = ctx.append_basic_block(parent, "then_reset"); + let else_block = ctx.append_basic_block(parent, "else_decref"); + let cont_block = ctx.append_basic_block(parent, "cont"); + + let refcount_ptr = + PointerToRefcount::from_ptr_to_data(env, tag_pointer_clear_tag_id(env, tag_ptr)); + let is_unique = refcount_ptr.is_1(env); + + env.builder + .build_conditional_branch(is_unique, then_block, else_block); + + { + // reset, when used on a unique reference, eagerly decrements the components of the + // referenced value, and returns the location of the now-invalid cell + env.builder.position_at_end(then_block); + + let reset_function = build_reset(env, layout_ids, *union_layout); + let call = env + .builder + .build_call(reset_function, &[tag_ptr.into()], "call_reset"); + + call.set_call_convention(FAST_CALL_CONV); + + let _ = call.try_as_basic_value(); + + env.builder.build_unconditional_branch(cont_block); + } + { + // If reset is used on a shared, non-reusable reference, it behaves + // like dec and returns NULL, which instructs reuse to behave like ctor + env.builder.position_at_end(else_block); + refcount_ptr.decrement(env, layout); + env.builder.build_unconditional_branch(cont_block); + } + { + 
env.builder.position_at_end(cont_block); + let phi = env.builder.build_phi(tag_ptr.get_type(), "branch"); + + let null_ptr = tag_ptr.get_type().const_null(); + phi.add_incoming(&[(&tag_ptr, then_block), (&null_ptr, else_block)]); + + phi.as_basic_value() + } + } StructAtIndex { index, structure, .. @@ -1084,13 +1183,15 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let tag_id_type = basic_type_from_layout(env, &union_layout.tag_id_layout()).into_int_type(); + let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value()); + lookup_at_index_ptr2( env, union_layout, tag_id_type, field_layouts, *index as usize, - argument.into_pointer_value(), + ptr, ) } UnionLayout::NonNullableUnwrapped(field_layouts) => { @@ -1125,13 +1226,14 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let tag_id_type = basic_type_from_layout(env, &union_layout.tag_id_layout()).into_int_type(); + let ptr = tag_pointer_clear_tag_id(env, argument.into_pointer_value()); lookup_at_index_ptr2( env, union_layout, tag_id_type, field_layouts, *index as usize, - argument.into_pointer_value(), + ptr, ) } UnionLayout::NullableUnwrapped { @@ -1171,15 +1273,103 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( } } +#[allow(clippy::too_many_arguments)] +fn build_wrapped_tag<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + scope: &Scope<'a, 'ctx>, + union_layout: &UnionLayout<'a>, + tag_id: u8, + arguments: &[Symbol], + tag_field_layouts: &[Layout<'a>], + tags: &[&[Layout<'a>]], + reuse_allocation: Option>, + parent: FunctionValue<'ctx>, +) -> BasicValueEnum<'ctx> { + let ctx = env.context; + let builder = env.builder; + + let tag_id_layout = union_layout.tag_id_layout(); + + // Determine types + let num_fields = arguments.len() + 1; + let mut field_types = Vec::with_capacity_in(num_fields, env.arena); + let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); + + for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { + let (val, _val_layout) = load_symbol_and_layout(scope, 
field_symbol); + + let field_type = basic_type_from_layout(env, tag_field_layout); + + field_types.push(field_type); + + if let Layout::RecursivePointer = tag_field_layout { + debug_assert!(val.is_pointer_value()); + + // we store recursive pointers as `i64*` + let ptr = env.builder.build_bitcast( + val, + ctx.i64_type().ptr_type(AddressSpace::Generic), + "cast_recursive_pointer", + ); + + field_vals.push(ptr); + } else { + // this check fails for recursive tag unions, but can be helpful while debugging + // debug_assert_eq!(tag_field_layout, val_layout); + + field_vals.push(val); + } + } + + // Create the struct_type + let raw_data_ptr = allocate_tag(env, parent, reuse_allocation, union_layout, tags); + let struct_type = env.context.struct_type(&field_types, false); + + if union_layout.stores_tag_id_as_data(env.ptr_bytes) { + let tag_id_ptr = builder + .build_struct_gep(raw_data_ptr, TAG_ID_INDEX, "tag_id_index") + .unwrap(); + + let tag_id_type = basic_type_from_layout(env, &tag_id_layout).into_int_type(); + + env.builder + .build_store(tag_id_ptr, tag_id_type.const_int(tag_id as u64, false)); + + let opaque_struct_ptr = builder + .build_struct_gep(raw_data_ptr, TAG_DATA_INDEX, "tag_data_index") + .unwrap(); + + struct_pointer_from_fields( + env, + struct_type, + opaque_struct_ptr, + field_vals.into_iter().enumerate(), + ); + + raw_data_ptr.into() + } else { + struct_pointer_from_fields( + env, + struct_type, + raw_data_ptr, + field_vals.into_iter().enumerate(), + ); + + tag_pointer_set_tag_id(env, tag_id, raw_data_ptr).into() + } +} + pub fn build_tag<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, scope: &Scope<'a, 'ctx>, union_layout: &UnionLayout<'a>, - union_size: u8, tag_id: u8, arguments: &[Symbol], + reuse_allocation: Option>, + parent: FunctionValue<'ctx>, ) -> BasicValueEnum<'ctx> { let tag_id_layout = union_layout.tag_id_layout(); + let union_size = union_layout.number_of_tags(); match union_layout { UnionLayout::NonRecursive(tags) => { @@ -1269,79 
+1459,51 @@ pub fn build_tag<'a, 'ctx, 'env>( UnionLayout::Recursive(tags) => { debug_assert!(union_size > 1); - let ctx = env.context; - let builder = env.builder; - - // Determine types - let num_fields = arguments.len() + 1; - let mut field_types = Vec::with_capacity_in(num_fields, env.arena); - let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - let tag_field_layouts = &tags[tag_id as usize]; - for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { - let (val, _val_layout) = load_symbol_and_layout(scope, field_symbol); + build_wrapped_tag( + env, + scope, + union_layout, + tag_id, + arguments, + &tag_field_layouts, + tags, + reuse_allocation, + parent, + ) + } + UnionLayout::NullableWrapped { + nullable_id, + other_tags: tags, + } => { + let tag_field_layouts = { + use std::cmp::Ordering::*; + match tag_id.cmp(&(*nullable_id as u8)) { + Equal => { + let layout = Layout::Union(*union_layout); - let field_type = basic_type_from_layout(env, tag_field_layout); - - field_types.push(field_type); - - if let Layout::RecursivePointer = tag_field_layout { - debug_assert!(val.is_pointer_value()); - - // we store recursive pointers as `i64*` - let ptr = env.builder.build_bitcast( - val, - ctx.i64_type().ptr_type(AddressSpace::Generic), - "cast_recursive_pointer", - ); - - field_vals.push(ptr); - } else { - // this check fails for recursive tag unions, but can be helpful while debugging - // debug_assert_eq!(tag_field_layout, val_layout); - - field_vals.push(val); + return basic_type_from_layout(env, &layout) + .into_pointer_type() + .const_null() + .into(); + } + Less => &tags[tag_id as usize], + Greater => &tags[tag_id as usize - 1], } - } + }; - // Create the struct_type - let raw_data_ptr = - reserve_with_refcount_union_as_block_of_memory(env, *union_layout, tags); - - let tag_id_ptr = builder - .build_struct_gep(raw_data_ptr, TAG_ID_INDEX, "tag_id_index") - .unwrap(); - - let tag_id_type = basic_type_from_layout(env, 
&tag_id_layout).into_int_type(); - - env.builder - .build_store(tag_id_ptr, tag_id_type.const_int(tag_id as u64, false)); - - let opaque_struct_ptr = builder - .build_struct_gep(raw_data_ptr, TAG_DATA_INDEX, "tag_data_index") - .unwrap(); - - let struct_type = env.context.struct_type(&field_types, false); - let struct_ptr = env - .builder - .build_bitcast( - opaque_struct_ptr, - struct_type.ptr_type(AddressSpace::Generic), - "struct_ptr", - ) - .into_pointer_value(); - - // Insert field exprs into struct_val - for (index, field_val) in field_vals.into_iter().enumerate() { - let field_ptr = builder - .build_struct_gep(struct_ptr, index as u32, "field_struct_gep") - .unwrap(); - - builder.build_store(field_ptr, field_val); - } - - raw_data_ptr.into() + build_wrapped_tag( + env, + scope, + union_layout, + tag_id, + arguments, + &tag_field_layouts, + tags, + reuse_allocation, + parent, + ) } UnionLayout::NonNullableUnwrapped(fields) => { debug_assert_eq!(union_size, 1); @@ -1349,7 +1511,6 @@ pub fn build_tag<'a, 'ctx, 'env>( debug_assert_eq!(arguments.len(), fields.len()); let ctx = env.context; - let builder = env.builder; // Determine types let num_fields = arguments.len() + 1; @@ -1386,126 +1547,16 @@ pub fn build_tag<'a, 'ctx, 'env>( reserve_with_refcount_union_as_block_of_memory(env, *union_layout, &[fields]); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); - let struct_ptr = env - .builder - .build_bitcast( - data_ptr, - struct_type.ptr_type(AddressSpace::Generic), - "block_of_memory_to_tag", - ) - .into_pointer_value(); - // Insert field exprs into struct_val - for (index, field_val) in field_vals.into_iter().enumerate() { - let field_ptr = builder - .build_struct_gep(struct_ptr, index as u32, "struct_gep") - .unwrap(); - - builder.build_store(field_ptr, field_val); - } + struct_pointer_from_fields( + env, + struct_type, + data_ptr, + field_vals.into_iter().enumerate(), + ); data_ptr.into() } - UnionLayout::NullableWrapped { - 
nullable_id, - other_tags: tags, - } => { - if tag_id == *nullable_id as u8 { - let layout = Layout::Union(*union_layout); - - return basic_type_from_layout(env, &layout) - .into_pointer_type() - .const_null() - .into(); - } - - debug_assert!(union_size > 1); - - let ctx = env.context; - let builder = env.builder; - - // Determine types - let num_fields = arguments.len() + 1; - let mut field_types = Vec::with_capacity_in(num_fields, env.arena); - let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - - let tag_field_layouts = { - use std::cmp::Ordering::*; - match tag_id.cmp(&(*nullable_id as u8)) { - Equal => unreachable!("early return above"), - Less => &tags[tag_id as usize], - Greater => &tags[tag_id as usize - 1], - } - }; - - for (field_symbol, tag_field_layout) in arguments.iter().zip(tag_field_layouts.iter()) { - let val = load_symbol(scope, field_symbol); - - // Zero-sized fields have no runtime representation. - // The layout of the struct expects them to be dropped! 
- if !tag_field_layout.is_dropped_because_empty() { - let field_type = basic_type_from_layout(env, tag_field_layout); - - field_types.push(field_type); - - if let Layout::RecursivePointer = tag_field_layout { - debug_assert!(val.is_pointer_value()); - - // we store recursive pointers as `i64*` - let ptr = env.builder.build_bitcast( - val, - ctx.i64_type().ptr_type(AddressSpace::Generic), - "cast_recursive_pointer", - ); - - field_vals.push(ptr); - } else { - // this check fails for recursive tag unions, but can be helpful while debugging - // debug_assert_eq!(tag_field_layout, val_layout); - - field_vals.push(val); - } - } - } - - // Create the struct_type - let raw_data_ptr = - reserve_with_refcount_union_as_block_of_memory(env, *union_layout, tags); - - let tag_id_ptr = builder - .build_struct_gep(raw_data_ptr, TAG_ID_INDEX, "tag_id_index") - .unwrap(); - - let tag_id_type = basic_type_from_layout(env, &tag_id_layout).into_int_type(); - - env.builder - .build_store(tag_id_ptr, tag_id_type.const_int(tag_id as u64, false)); - - let opaque_struct_ptr = builder - .build_struct_gep(raw_data_ptr, TAG_DATA_INDEX, "tag_data_index") - .unwrap(); - - let struct_type = env.context.struct_type(&field_types, false); - let struct_ptr = env - .builder - .build_bitcast( - opaque_struct_ptr, - struct_type.ptr_type(AddressSpace::Generic), - "struct_ptr", - ) - .into_pointer_value(); - - // Insert field exprs into struct_val - for (index, field_val) in field_vals.into_iter().enumerate() { - let field_ptr = builder - .build_struct_gep(struct_ptr, index as u32, "field_struct_gep") - .unwrap(); - - builder.build_store(field_ptr, field_val); - } - - raw_data_ptr.into() - } UnionLayout::NullableUnwrapped { nullable_id, other_fields, @@ -1526,7 +1577,6 @@ pub fn build_tag<'a, 'ctx, 'env>( debug_assert!(union_size == 2); let ctx = env.context; - let builder = env.builder; // Determine types let num_fields = arguments.len() + 1; @@ -1567,32 +1617,128 @@ pub fn build_tag<'a, 'ctx, 'env>( // 
Create the struct_type let data_ptr = - reserve_with_refcount_union_as_block_of_memory(env, *union_layout, &[other_fields]); + allocate_tag(env, parent, reuse_allocation, union_layout, &[other_fields]); let struct_type = ctx.struct_type(field_types.into_bump_slice(), false); - let struct_ptr = env - .builder - .build_bitcast( - data_ptr, - struct_type.ptr_type(AddressSpace::Generic), - "block_of_memory_to_tag", - ) - .into_pointer_value(); - // Insert field exprs into struct_val - for (index, field_val) in field_vals.into_iter().enumerate() { - let field_ptr = builder - .build_struct_gep(struct_ptr, index as u32, "struct_gep") - .unwrap(); - - builder.build_store(field_ptr, field_val); - } + struct_pointer_from_fields( + env, + struct_type, + data_ptr, + field_vals.into_iter().enumerate(), + ); data_ptr.into() } } } +fn tag_pointer_set_tag_id<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + tag_id: u8, + pointer: PointerValue<'ctx>, +) -> PointerValue<'ctx> { + // we only have 3 bits, so can encode only 0..7 + debug_assert!(tag_id < 8); + + let ptr_int = env.ptr_int(); + + let as_int = env.builder.build_ptr_to_int(pointer, ptr_int, "to_int"); + + let tag_id_intval = ptr_int.const_int(tag_id as u64, false); + let combined = env.builder.build_or(as_int, tag_id_intval, "store_tag_id"); + + env.builder + .build_int_to_ptr(combined, pointer.get_type(), "to_ptr") +} + +pub fn tag_pointer_read_tag_id<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + pointer: PointerValue<'ctx>, +) -> IntValue<'ctx> { + let mask: u64 = 0b0000_0111; + + let ptr_int = env.ptr_int(); + + let as_int = env.builder.build_ptr_to_int(pointer, ptr_int, "to_int"); + let mask_intval = env.ptr_int().const_int(mask, false); + + let masked = env.builder.build_and(as_int, mask_intval, "mask"); + + env.builder + .build_int_cast(masked, env.context.i8_type(), "to_u8") +} + +pub fn tag_pointer_clear_tag_id<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + pointer: PointerValue<'ctx>, +) -> PointerValue<'ctx> { + 
let ptr_int = env.ptr_int(); + + let as_int = env.builder.build_ptr_to_int(pointer, ptr_int, "to_int"); + + let mask = { + let a = env.ptr_int().const_all_ones(); + let tag_id_bits = env.ptr_int().const_int(3, false); + env.builder.build_left_shift(a, tag_id_bits, "make_mask") + }; + + let masked = env.builder.build_and(as_int, mask, "masked"); + + env.builder + .build_int_to_ptr(masked, pointer.get_type(), "to_ptr") +} + +fn allocate_tag<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + parent: FunctionValue<'ctx>, + reuse_allocation: Option>, + union_layout: &UnionLayout<'a>, + tags: &[&[Layout<'a>]], +) -> PointerValue<'ctx> { + match reuse_allocation { + Some(ptr) => { + // check if its a null pointer + let is_null_ptr = env.builder.build_is_null(ptr, "is_null_ptr"); + let ctx = env.context; + let then_block = ctx.append_basic_block(parent, "then_allocate_fresh"); + let else_block = ctx.append_basic_block(parent, "else_reuse"); + let cont_block = ctx.append_basic_block(parent, "cont"); + + env.builder + .build_conditional_branch(is_null_ptr, then_block, else_block); + + let raw_ptr = { + env.builder.position_at_end(then_block); + let raw_ptr = + reserve_with_refcount_union_as_block_of_memory(env, *union_layout, tags); + env.builder.build_unconditional_branch(cont_block); + raw_ptr + }; + + let reuse_ptr = { + env.builder.position_at_end(else_block); + + let cleared = tag_pointer_clear_tag_id(env, ptr); + + env.builder.build_unconditional_branch(cont_block); + + cleared + }; + + { + env.builder.position_at_end(cont_block); + let phi = env.builder.build_phi(raw_ptr.get_type(), "branch"); + + phi.add_incoming(&[(&raw_ptr, then_block), (&reuse_ptr, else_block)]); + + phi.as_basic_value().into_pointer_value() + } + } + None => reserve_with_refcount_union_as_block_of_memory(env, *union_layout, tags), + } +} + pub fn get_tag_id<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, parent: FunctionValue<'ctx>, @@ -1610,7 +1756,15 @@ pub fn get_tag_id<'a, 'ctx, 'env>( 
get_tag_id_non_recursive(env, tag) } - UnionLayout::Recursive(_) => get_tag_id_wrapped(env, argument.into_pointer_value()), + UnionLayout::Recursive(_) => { + let argument_ptr = argument.into_pointer_value(); + + if union_layout.stores_tag_id_as_data(env.ptr_bytes) { + get_tag_id_wrapped(env, argument_ptr) + } else { + tag_pointer_read_tag_id(env, argument_ptr) + } + } UnionLayout::NonNullableUnwrapped(_) => tag_id_int_type.const_zero(), UnionLayout::NullableWrapped { nullable_id, .. } => { let argument_ptr = argument.into_pointer_value(); @@ -1635,7 +1789,12 @@ pub fn get_tag_id<'a, 'ctx, 'env>( { env.builder.position_at_end(else_block); - let tag_id = get_tag_id_wrapped(env, argument_ptr); + + let tag_id = if union_layout.stores_tag_id_as_data(env.ptr_bytes) { + get_tag_id_wrapped(env, argument_ptr) + } else { + tag_pointer_read_tag_id(env, argument_ptr) + }; env.builder.build_store(result, tag_id); env.builder.build_unconditional_branch(cont_block); } @@ -1771,9 +1930,11 @@ fn reserve_with_refcount_union_as_block_of_memory<'a, 'ctx, 'env>( union_layout: UnionLayout<'a>, fields: &[&[Layout<'a>]], ) -> PointerValue<'ctx> { + let ptr_bytes = env.ptr_bytes; + let block_type = block_of_memory_slices(env.context, fields, env.ptr_bytes); - let basic_type = if union_layout.stores_tag_id() { + let basic_type = if union_layout.stores_tag_id_as_data(ptr_bytes) { let tag_id_type = basic_type_from_layout(env, &union_layout.tag_id_layout()); env.context @@ -1789,7 +1950,7 @@ fn reserve_with_refcount_union_as_block_of_memory<'a, 'ctx, 'env>( .max() .unwrap_or_default(); - if union_layout.stores_tag_id() { + if union_layout.stores_tag_id_as_data(ptr_bytes) { stack_size += union_layout.tag_id_layout().stack_size(env.ptr_bytes); } @@ -2412,11 +2573,26 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( _ if layout.is_refcounted() => { if value.is_pointer_value() { - // BasicValueEnum::PointerValue(value_ptr) => { let value_ptr = value.into_pointer_value(); - let refcount_ptr = - 
PointerToRefcount::from_ptr_to_data(env, value_ptr); - refcount_ptr.decrement(env, layout); + + let then_block = env.context.append_basic_block(parent, "then"); + let done_block = env.context.append_basic_block(parent, "done"); + + let condition = + env.builder.build_is_not_null(value_ptr, "box_is_not_null"); + env.builder + .build_conditional_branch(condition, then_block, done_block); + + { + env.builder.position_at_end(then_block); + let refcount_ptr = + PointerToRefcount::from_ptr_to_data(env, value_ptr); + refcount_ptr.decrement(env, layout); + + env.builder.build_unconditional_branch(done_block); + } + + env.builder.position_at_end(done_block); } else { eprint!("we're likely leaking memory; see issue #985 for details"); } diff --git a/compiler/gen_llvm/src/llvm/build_hash.rs b/compiler/gen_llvm/src/llvm/build_hash.rs index 61a74dd232..99af6b9e11 100644 --- a/compiler/gen_llvm/src/llvm/build_hash.rs +++ b/compiler/gen_llvm/src/llvm/build_hash.rs @@ -1,5 +1,6 @@ use crate::debug_info_init; use crate::llvm::bitcode::call_bitcode_fn; +use crate::llvm::build::tag_pointer_clear_tag_id; use crate::llvm::build::Env; use crate::llvm::build::{cast_block_of_memory_to_tag, get_tag_id, FAST_CALL_CONV, TAG_DATA_INDEX}; use crate::llvm::build_str; @@ -493,14 +494,9 @@ fn hash_tag<'a, 'ctx, 'env>( ); // hash the tag data - let answer = hash_ptr_to_struct( - env, - layout_ids, - union_layout, - field_layouts, - seed, - tag.into_pointer_value(), - ); + let tag = tag_pointer_clear_tag_id(env, tag.into_pointer_value()); + let answer = + hash_ptr_to_struct(env, layout_ids, union_layout, field_layouts, seed, tag); merge_phi.add_incoming(&[(&answer, block)]); env.builder.build_unconditional_branch(merge_block); @@ -598,6 +594,7 @@ fn hash_tag<'a, 'ctx, 'env>( ); // hash tag data + let tag = tag_pointer_clear_tag_id(env, tag); let answer = hash_ptr_to_struct( env, layout_ids, diff --git a/compiler/gen_llvm/src/llvm/compare.rs b/compiler/gen_llvm/src/llvm/compare.rs index 
2a9d1d3fa3..ffc0b04e99 100644 --- a/compiler/gen_llvm/src/llvm/compare.rs +++ b/compiler/gen_llvm/src/llvm/compare.rs @@ -1,5 +1,5 @@ -use crate::llvm::build::Env; use crate::llvm::build::{cast_block_of_memory_to_tag, get_tag_id, FAST_CALL_CONV}; +use crate::llvm::build::{tag_pointer_clear_tag_id, Env}; use crate::llvm::build_list::{list_len, load_list_ptr}; use crate::llvm::build_str::str_equal; use crate::llvm::convert::basic_type_from_layout; @@ -925,6 +925,10 @@ fn build_tag_eq_help<'a, 'ctx, 'env>( let id1 = get_tag_id(env, parent, union_layout, tag1); let id2 = get_tag_id(env, parent, union_layout, tag2); + // clear the tag_id so we get a pointer to the actual data + let tag1 = tag_pointer_clear_tag_id(env, tag1.into_pointer_value()); + let tag2 = tag_pointer_clear_tag_id(env, tag2.into_pointer_value()); + let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields"); let same_tag = @@ -944,14 +948,8 @@ fn build_tag_eq_help<'a, 'ctx, 'env>( let block = env.context.append_basic_block(parent, "tag_id_modify"); env.builder.position_at_end(block); - let answer = eq_ptr_to_struct( - env, - layout_ids, - union_layout, - field_layouts, - tag1.into_pointer_value(), - tag2.into_pointer_value(), - ); + let answer = + eq_ptr_to_struct(env, layout_ids, union_layout, field_layouts, tag1, tag2); env.builder.build_return(Some(&answer)); @@ -1073,6 +1071,10 @@ fn build_tag_eq_help<'a, 'ctx, 'env>( let id1 = get_tag_id(env, parent, union_layout, tag1); let id2 = get_tag_id(env, parent, union_layout, tag2); + // clear the tag_id so we get a pointer to the actual data + let tag1 = tag_pointer_clear_tag_id(env, tag1.into_pointer_value()); + let tag2 = tag_pointer_clear_tag_id(env, tag2.into_pointer_value()); + let compare_tag_fields = ctx.append_basic_block(parent, "compare_tag_fields"); let same_tag = @@ -1093,14 +1095,8 @@ fn build_tag_eq_help<'a, 'ctx, 'env>( let block = env.context.append_basic_block(parent, "tag_id_modify"); 
env.builder.position_at_end(block); - let answer = eq_ptr_to_struct( - env, - layout_ids, - union_layout, - field_layouts, - tag1.into_pointer_value(), - tag2.into_pointer_value(), - ); + let answer = + eq_ptr_to_struct(env, layout_ids, union_layout, field_layouts, tag1, tag2); env.builder.build_return(Some(&answer)); diff --git a/compiler/gen_llvm/src/llvm/convert.rs b/compiler/gen_llvm/src/llvm/convert.rs index 1cb299c0ae..729c118626 100644 --- a/compiler/gen_llvm/src/llvm/convert.rs +++ b/compiler/gen_llvm/src/llvm/convert.rs @@ -31,21 +31,31 @@ pub fn basic_type_from_layout<'a, 'ctx, 'env>( basic_type_from_layout(env, &closure_data_layout) } Struct(sorted_fields) => basic_type_from_record(env, sorted_fields), - Union(variant) => { + Union(union_layout) => { use UnionLayout::*; - let tag_id_type = basic_type_from_layout(env, &variant.tag_id_layout()); + let tag_id_type = basic_type_from_layout(env, &union_layout.tag_id_layout()); - match variant { - NullableWrapped { + match union_layout { + NonRecursive(tags) => { + let data = block_of_memory_slices(env.context, tags, env.ptr_bytes); + + env.context.struct_type(&[data, tag_id_type], false).into() + } + Recursive(tags) + | NullableWrapped { other_tags: tags, .. } => { let data = block_of_memory_slices(env.context, tags, env.ptr_bytes); - env.context - .struct_type(&[data, tag_id_type], false) - .ptr_type(AddressSpace::Generic) - .into() + if union_layout.stores_tag_id_as_data(env.ptr_bytes) { + env.context + .struct_type(&[data, tag_id_type], false) + .ptr_type(AddressSpace::Generic) + .into() + } else { + data.ptr_type(AddressSpace::Generic).into() + } } NullableUnwrapped { other_fields, .. 
} => { let block = @@ -56,19 +66,6 @@ pub fn basic_type_from_layout<'a, 'ctx, 'env>( let block = block_of_memory_slices(env.context, &[fields], env.ptr_bytes); block.ptr_type(AddressSpace::Generic).into() } - Recursive(tags) => { - let data = block_of_memory_slices(env.context, tags, env.ptr_bytes); - - env.context - .struct_type(&[data, tag_id_type], false) - .ptr_type(AddressSpace::Generic) - .into() - } - NonRecursive(tags) => { - let data = block_of_memory_slices(env.context, tags, env.ptr_bytes); - - env.context.struct_type(&[data, tag_id_type], false).into() - } } } RecursivePointer => { @@ -145,16 +142,6 @@ pub fn union_data_is_struct_type<'ctx>( context.struct_type(&[struct_type.into(), tag_id_type.into()], false) } -pub fn union_data_block_of_memory<'ctx>( - context: &'ctx Context, - tag_id_int_type: IntType<'ctx>, - layouts: &[&[Layout<'_>]], - ptr_bytes: u32, -) -> StructType<'ctx> { - let data_type = block_of_memory_slices(context, layouts, ptr_bytes); - context.struct_type(&[data_type, tag_id_int_type.into()], false) -} - pub fn block_of_memory<'ctx>( context: &'ctx Context, layout: &Layout<'_>, diff --git a/compiler/gen_llvm/src/llvm/refcounting.rs b/compiler/gen_llvm/src/llvm/refcounting.rs index 4dd072f622..9c09469a97 100644 --- a/compiler/gen_llvm/src/llvm/refcounting.rs +++ b/compiler/gen_llvm/src/llvm/refcounting.rs @@ -1,12 +1,10 @@ use crate::debug_info_init; use crate::llvm::build::{ add_func, cast_basic_basic, cast_block_of_memory_to_tag, get_tag_id, get_tag_id_non_recursive, - Env, FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, TAG_DATA_INDEX, + tag_pointer_clear_tag_id, Env, FAST_CALL_CONV, LLVM_SADD_WITH_OVERFLOW_I64, TAG_DATA_INDEX, }; use crate::llvm::build_list::{incrementing_elem_loop, list_len, load_list}; -use crate::llvm::convert::{ - basic_type_from_layout, block_of_memory_slices, ptr_int, union_data_block_of_memory, -}; +use crate::llvm::convert::{basic_type_from_layout, ptr_int}; use bumpalo::collections::Vec; use 
inkwell::basic_block::BasicBlock; use inkwell::context::Context; @@ -644,70 +642,21 @@ fn modify_refcount_layout_build_function<'a, 'ctx, 'env>( Union(variant) => { use UnionLayout::*; - match variant { - NullableWrapped { - other_tags: tags, .. - } => { - let function = build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - *variant, - tags, - true, - ); + if let NonRecursive(tags) = variant { + let function = modify_refcount_union(env, layout_ids, mode, when_recursive, tags); - Some(function) - } - - NullableUnwrapped { other_fields, .. } => { - let function = build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - *variant, - env.arena.alloc([*other_fields]), - true, - ); - - Some(function) - } - - NonNullableUnwrapped(fields) => { - let function = build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - *variant, - &*env.arena.alloc([*fields]), - true, - ); - Some(function) - } - - Recursive(tags) => { - let function = build_rec_union( - env, - layout_ids, - mode, - &WhenRecursive::Loop(*variant), - *variant, - tags, - false, - ); - Some(function) - } - - NonRecursive(tags) => { - let function = - modify_refcount_union(env, layout_ids, mode, when_recursive, tags); - - Some(function) - } + return Some(function); } + + let function = build_rec_union( + env, + layout_ids, + mode, + &WhenRecursive::Loop(*variant), + *variant, + ); + + Some(function) } Closure(_, lambda_set, _) => { @@ -1208,10 +1157,8 @@ fn build_rec_union<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, union_layout: UnionLayout<'a>, - tags: &'a [&'a [Layout<'a>]], - is_nullable: bool, ) -> FunctionValue<'ctx> { - let layout = Layout::Union(UnionLayout::Recursive(tags)); + let layout = Layout::Union(union_layout); let (_, fn_name) = function_name_from_mode( layout_ids, @@ -1228,7 +1175,7 @@ fn build_rec_union<'a, 'ctx, 'env>( let block = env.builder.get_insert_block().expect("to be in a function"); let 
di_location = env.builder.get_current_debug_location().unwrap(); - let basic_type = basic_type_from_layout(env, &Layout::Union(union_layout)); + let basic_type = basic_type_from_layout(env, &layout); let function_value = build_header(env, basic_type, mode, &fn_name); build_rec_union_help( @@ -1237,9 +1184,7 @@ fn build_rec_union<'a, 'ctx, 'env>( mode, when_recursive, union_layout, - tags, function_value, - is_nullable, ); env.builder.position_at_end(block); @@ -1260,10 +1205,10 @@ fn build_rec_union_help<'a, 'ctx, 'env>( mode: Mode, when_recursive: &WhenRecursive<'a>, union_layout: UnionLayout<'a>, - tags: &'a [&'a [roc_mono::layout::Layout<'a>]], fn_val: FunctionValue<'ctx>, - is_nullable: bool, ) { + let tags = union_layout_tags(env.arena, &union_layout); + let is_nullable = union_layout.is_nullable(); debug_assert!(!tags.is_empty()); let context = &env.context; @@ -1286,7 +1231,8 @@ fn build_rec_union_help<'a, 'ctx, 'env>( let parent = fn_val; debug_assert!(arg_val.is_pointer_value()); - let value_ptr = arg_val.into_pointer_value(); + let current_tag_id = get_tag_id(env, fn_val, &union_layout, arg_val); + let value_ptr = tag_pointer_clear_tag_id(env, arg_val.into_pointer_value()); // to increment/decrement the cons-cell itself let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); @@ -1351,14 +1297,21 @@ fn build_rec_union_help<'a, 'ctx, 'env>( union_layout, tags, value_ptr, + current_tag_id, refcount_ptr, do_recurse_block, + DecOrReuse::Dec, ) } } } } +enum DecOrReuse { + Dec, + Reuse, +} + #[allow(clippy::too_many_arguments)] fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, @@ -1369,8 +1322,10 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>( union_layout: UnionLayout<'a>, tags: &[&[Layout<'a>]], value_ptr: PointerValue<'ctx>, + current_tag_id: IntValue<'ctx>, refcount_ptr: PointerToRefcount<'ctx>, match_block: BasicBlock<'ctx>, + decrement_or_reuse: DecOrReuse, ) { let mode = Mode::Dec; let call_mode 
= mode_to_call_mode(decrement_fn, mode); @@ -1442,28 +1397,8 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>( debug_assert!(ptr_as_i64_ptr.is_pointer_value()); // therefore we must cast it to our desired type - let union_type = match union_layout { - UnionLayout::NonRecursive(_) => unreachable!(), - UnionLayout::Recursive(_) | UnionLayout::NullableWrapped { .. } => { - union_data_block_of_memory( - env.context, - tag_id_int_type, - tags, - env.ptr_bytes, - ) - .into() - } - UnionLayout::NonNullableUnwrapped { .. } - | UnionLayout::NullableUnwrapped { .. } => { - block_of_memory_slices(env.context, tags, env.ptr_bytes) - } - }; - - let recursive_field_ptr = cast_basic_basic( - env.builder, - ptr_as_i64_ptr, - union_type.ptr_type(AddressSpace::Generic).into(), - ); + let union_type = basic_type_from_layout(env, &Layout::Union(union_layout)); + let recursive_field_ptr = cast_basic_basic(env.builder, ptr_as_i64_ptr, union_type); deferred_rec.push(recursive_field_ptr); } else if field_layout.contains_refcounted() { @@ -1486,7 +1421,13 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>( // lists. To achieve it, we must first load all fields that we want to inc/dec (done above) // and store them on the stack, then modify (and potentially free) the current cell, then // actually inc/dec the fields. 
- refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env); + + match decrement_or_reuse { + DecOrReuse::Reuse => {} + DecOrReuse::Dec => { + refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env); + } + } for (field, field_layout) in deferred_nonrec { modify_refcount_layout_help( @@ -1524,25 +1465,182 @@ fn build_rec_union_recursive_decrement<'a, 'ctx, 'env>( let (_, only_branch) = cases.pop().unwrap(); env.builder.build_unconditional_branch(only_branch); } else { - // read the tag_id - let current_tag_id = get_tag_id(env, parent, &union_layout, value_ptr.into()); - - let merge_block = env.context.append_basic_block(parent, "decrement_merge"); + let default_block = env.context.append_basic_block(parent, "switch_default"); // switch on it env.builder - .build_switch(current_tag_id, merge_block, &cases); + .build_switch(current_tag_id, default_block, &cases); - env.builder.position_at_end(merge_block); + { + env.builder.position_at_end(default_block); - // increment/decrement the cons-cell itself - refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env); + // increment/decrement the cons-cell itself + if let DecOrReuse::Dec = decrement_or_reuse { + refcount_ptr.modify(call_mode, &Layout::Union(union_layout), env); + } + } // this function returns void builder.build_return(None); } } +fn union_layout_tags<'a>( + arena: &'a bumpalo::Bump, + union_layout: &UnionLayout<'a>, +) -> &'a [&'a [Layout<'a>]] { + use UnionLayout::*; + + match union_layout { + NullableWrapped { + other_tags: tags, .. + } => *tags, + NullableUnwrapped { other_fields, .. 
} => arena.alloc([*other_fields]), + NonNullableUnwrapped(fields) => arena.alloc([*fields]), + Recursive(tags) => tags, + NonRecursive(tags) => tags, + } +} + +pub fn build_reset<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + union_layout: UnionLayout<'a>, +) -> FunctionValue<'ctx> { + let mode = Mode::Dec; + + let layout_id = layout_ids.get(Symbol::DEC, &Layout::Union(union_layout)); + let fn_name = layout_id.to_symbol_string(Symbol::DEC, &env.interns); + let fn_name = format!("{}_reset", fn_name); + + let when_recursive = WhenRecursive::Loop(union_layout); + let dec_function = build_rec_union(env, layout_ids, Mode::Dec, &when_recursive, union_layout); + + let function = match env.module.get_function(fn_name.as_str()) { + Some(function_value) => function_value, + None => { + let block = env.builder.get_insert_block().expect("to be in a function"); + let di_location = env.builder.get_current_debug_location().unwrap(); + + let basic_type = basic_type_from_layout(env, &Layout::Union(union_layout)); + let function_value = build_header(env, basic_type, mode, &fn_name); + + build_reuse_rec_union_help( + env, + layout_ids, + &when_recursive, + union_layout, + function_value, + dec_function, + ); + + env.builder.position_at_end(block); + env.builder + .set_current_debug_location(env.context, di_location); + + function_value + } + }; + + function +} + +#[allow(clippy::too_many_arguments)] +fn build_reuse_rec_union_help<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout_ids: &mut LayoutIds<'a>, + when_recursive: &WhenRecursive<'a>, + union_layout: UnionLayout<'a>, + reset_function: FunctionValue<'ctx>, + dec_function: FunctionValue<'ctx>, +) { + let tags = union_layout_tags(env.arena, &union_layout); + let is_nullable = union_layout.is_nullable(); + + debug_assert!(!tags.is_empty()); + + let context = &env.context; + let builder = env.builder; + + // Add a basic block for the entry point + let entry = 
context.append_basic_block(reset_function, "entry"); + + builder.position_at_end(entry); + + debug_info_init!(env, reset_function); + + // Add args to scope + let arg_symbol = Symbol::ARG_1; + + let arg_val = reset_function.get_param_iter().next().unwrap(); + + arg_val.set_name(arg_symbol.ident_string(&env.interns)); + + let parent = reset_function; + + debug_assert!(arg_val.is_pointer_value()); + let current_tag_id = get_tag_id(env, reset_function, &union_layout, arg_val); + let value_ptr = tag_pointer_clear_tag_id(env, arg_val.into_pointer_value()); + + // to increment/decrement the cons-cell itself + let refcount_ptr = PointerToRefcount::from_ptr_to_data(env, value_ptr); + let call_mode = CallMode::Dec; + + let should_recurse_block = env.context.append_basic_block(parent, "should_recurse"); + + let ctx = env.context; + if is_nullable { + let is_null = env.builder.build_is_null(value_ptr, "is_null"); + + let then_block = ctx.append_basic_block(parent, "then"); + + env.builder + .build_conditional_branch(is_null, then_block, should_recurse_block); + + { + env.builder.position_at_end(then_block); + env.builder.build_return(None); + } + } else { + env.builder.build_unconditional_branch(should_recurse_block); + } + + env.builder.position_at_end(should_recurse_block); + + let layout = Layout::Union(union_layout); + + let do_recurse_block = env.context.append_basic_block(parent, "do_recurse"); + let no_recurse_block = env.context.append_basic_block(parent, "no_recurse"); + + builder.build_conditional_branch(refcount_ptr.is_1(env), do_recurse_block, no_recurse_block); + + { + env.builder.position_at_end(no_recurse_block); + + refcount_ptr.modify(call_mode, &layout, env); + env.builder.build_return(None); + } + + { + env.builder.position_at_end(do_recurse_block); + + build_rec_union_recursive_decrement( + env, + layout_ids, + when_recursive, + parent, + dec_function, + union_layout, + tags, + value_ptr, + current_tag_id, + refcount_ptr, + do_recurse_block, + 
DecOrReuse::Reuse, + ) + } +} + fn function_name_from_mode<'a>( layout_ids: &mut LayoutIds<'a>, interns: &Interns, diff --git a/compiler/load/src/file.rs b/compiler/load/src/file.rs index 24d07b1a47..fd8fd1dc4b 100644 --- a/compiler/load/src/file.rs +++ b/compiler/load/src/file.rs @@ -2047,7 +2047,7 @@ fn update<'a>( } MadeSpecializations { module_id, - ident_ids, + mut ident_ids, subs, procedures, external_specializations_requested, @@ -2070,6 +2070,15 @@ fn update<'a>( && state.dependencies.solved_all() && state.goal_phase == Phase::MakeSpecializations { + Proc::insert_reset_reuse_operations( + arena, + module_id, + &mut ident_ids, + &mut state.procedures, + ); + + Proc::insert_refcount_operations(arena, &mut state.procedures); + // display the mono IR of the module, for debug purposes if roc_mono::ir::PRETTY_PRINT_IR_SYMBOLS { let procs_string = state @@ -2083,8 +2092,6 @@ fn update<'a>( println!("{}", result); } - Proc::insert_refcount_operations(arena, &mut state.procedures); - // This is not safe with the new non-recursive RC updates that we do for tag unions // // Proc::optimize_refcount_operations( diff --git a/compiler/mono/src/alias_analysis.rs b/compiler/mono/src/alias_analysis.rs index 0e6626a04b..e8b04ca1e9 100644 --- a/compiler/mono/src/alias_analysis.rs +++ b/compiler/mono/src/alias_analysis.rs @@ -803,15 +803,20 @@ fn lowlevel_spec( builder.add_sub_block(block, sub_block) } + NumToFloat => { + // just dream up a unit value + builder.add_make_tuple(block, &[]) + } Eq | NotEq => { // just dream up a unit value builder.add_make_tuple(block, &[]) } - NumLte | NumLt | NumGt | NumGte => { + NumLte | NumLt | NumGt | NumGte | NumCompare => { // just dream up a unit value builder.add_make_tuple(block, &[]) } - ListLen => { + ListLen | DictSize => { + // TODO should this touch the heap cell? 
// just dream up a unit value builder.add_make_tuple(block, &[]) } @@ -839,7 +844,70 @@ fn lowlevel_spec( Ok(list) } + ListAppend => { + let list = env.symbols[&arguments[0]]; + let to_insert = env.symbols[&arguments[1]]; + + let bag = builder.add_get_tuple_field(block, list, LIST_BAG_INDEX)?; + let cell = builder.add_get_tuple_field(block, list, LIST_CELL_INDEX)?; + + let _unit = builder.add_update(block, update_mode_var, cell)?; + + builder.add_bag_insert(block, bag, to_insert)?; + + Ok(list) + } + DictEmpty => { + match layout { + Layout::Builtin(Builtin::EmptyDict) => { + // just make up an element type + let type_id = builder.add_tuple_type(&[])?; + new_dict(builder, block, type_id, type_id) + } + Layout::Builtin(Builtin::Dict(key_layout, value_layout)) => { + let key_id = layout_spec(builder, key_layout)?; + let value_id = layout_spec(builder, value_layout)?; + new_dict(builder, block, key_id, value_id) + } + _ => unreachable!("empty array does not have a list layout"), + } + } + DictGetUnsafe => { + // NOTE DictGetUnsafe returns a { flag: Bool, value: v } + // when the flag is True, the value is found and defined; + // otherwise it is not and `Dict.get` should return `Err ...` + + let dict = env.symbols[&arguments[0]]; + let key = env.symbols[&arguments[1]]; + + // indicate that we use the key + builder.add_recursive_touch(block, key)?; + + let bag = builder.add_get_tuple_field(block, dict, DICT_BAG_INDEX)?; + let cell = builder.add_get_tuple_field(block, dict, DICT_CELL_INDEX)?; + + let _unit = builder.add_touch(block, cell)?; + + builder.add_bag_get(block, bag) + } + DictInsert => { + let dict = env.symbols[&arguments[0]]; + let key = env.symbols[&arguments[1]]; + let value = env.symbols[&arguments[2]]; + + let key_value = builder.add_make_tuple(block, &[key, value])?; + + let bag = builder.add_get_tuple_field(block, dict, DICT_BAG_INDEX)?; + let cell = builder.add_get_tuple_field(block, dict, DICT_CELL_INDEX)?; + + let _unit = builder.add_update(block, 
update_mode_var, cell)?; + + builder.add_bag_insert(block, bag, key_value)?; + + Ok(dict) + } _other => { + // println!("missing {:?}", _other); // TODO overly pessimstic let arguments: Vec<_> = arguments.iter().map(|symbol| env.symbols[symbol]).collect(); @@ -945,11 +1013,17 @@ fn expr_spec<'a>( match expr { Literal(literal) => literal_spec(builder, block, literal), Call(call) => call_spec(builder, env, block, layout, call), - Tag { + Reuse { + tag_layout, + tag_name: _, + tag_id, + arguments, + .. + } + | Tag { tag_layout, tag_name: _, tag_id, - union_size: _, arguments, } => { let variant_types = build_variant_types(builder, tag_layout)?; @@ -1095,8 +1169,12 @@ fn expr_spec<'a>( Err(()) => unreachable!("empty array does not have a list layout"), } } - Reuse { .. } => todo!("currently unused"), - Reset(_) => todo!("currently unused"), + Reset(symbol) => { + let type_id = layout_spec(builder, layout)?; + let value_id = env.symbols[symbol]; + + builder.add_unknown_with(block, &[value_id], type_id) + } RuntimeErrorFunction(_) => { let type_id = layout_spec(builder, layout)?; @@ -1253,6 +1331,18 @@ fn new_list(builder: &mut FuncDefBuilder, block: BlockId, element_type: TypeId) builder.add_make_tuple(block, &[cell, bag]) } +fn new_dict( + builder: &mut FuncDefBuilder, + block: BlockId, + key_type: TypeId, + value_type: TypeId, +) -> Result { + let cell = builder.add_new_heap_cell(block)?; + let element_type = builder.add_tuple_type(&[key_type, value_type])?; + let bag = builder.add_empty_bag(block, element_type)?; + builder.add_make_tuple(block, &[cell, bag]) +} + fn new_static_string(builder: &mut FuncDefBuilder, block: BlockId) -> Result { let module = MOD_APP; diff --git a/compiler/mono/src/inc_dec.rs b/compiler/mono/src/inc_dec.rs index a9583856e4..9954707853 100644 --- a/compiler/mono/src/inc_dec.rs +++ b/compiler/mono/src/inc_dec.rs @@ -154,11 +154,12 @@ struct VarInfo { reference: bool, // true if the variable may be a reference (aka pointer) at runtime 
persistent: bool, // true if the variable is statically known to be marked a Persistent at runtime consume: bool, // true if the variable RC must be "consumed" + reset: bool, // true if the variable is the result of a Reset operation } type VarMap = MutMap; -type LiveVarSet = MutSet; -type JPLiveVarMap = MutMap; +pub type LiveVarSet = MutSet; +pub type JPLiveVarMap = MutMap; #[derive(Clone, Debug)] struct Context<'a> { @@ -254,6 +255,7 @@ impl<'a> Context<'a> { reference: false, // assume function symbols are global constants persistent: true, // assume function symbols are global constants consume: false, // no need to consume this variable + reset: false, // reset symbols cannot be passed as function arguments }, ); } @@ -310,7 +312,12 @@ impl<'a> Context<'a> { return stmt; } - let modify = ModifyRc::Dec(symbol); + let modify = if info.reset { + ModifyRc::DecRef(symbol) + } else { + ModifyRc::Dec(symbol) + }; + self.arena.alloc(Stmt::Refcounting(modify, stmt)) } @@ -753,12 +760,6 @@ impl<'a> Context<'a> { arguments, }) => self.visit_call(z, call_type, arguments, l, b, b_live_vars), - EmptyArray | Literal(_) | Reset(_) | RuntimeErrorFunction(_) => { - // EmptyArray is always stack-allocated - // function pointers are persistent - self.arena.alloc(Stmt::Let(z, v, l, b)) - } - StructAtIndex { structure: x, .. } => { let b = self.add_dec_if_needed(x, b, b_live_vars); let info_x = self.get_var_info(x); @@ -794,6 +795,12 @@ impl<'a> Context<'a> { self.arena.alloc(Stmt::Let(z, v, l, b)) } + + EmptyArray | Literal(_) | Reset(_) | RuntimeErrorFunction(_) => { + // EmptyArray is always stack-allocated + // function pointers are persistent + self.arena.alloc(Stmt::Let(z, v, l, b)) + } }; (new_b, live_vars) @@ -812,7 +819,7 @@ impl<'a> Context<'a> { // must this value be consumed? 
let consume = consume_call(&self.vars, call); - self.update_var_info_help(symbol, layout, persistent, consume) + self.update_var_info_help(symbol, layout, persistent, consume, false) } fn update_var_info(&self, symbol: Symbol, layout: &Layout<'a>, expr: &Expr<'a>) -> Self { @@ -823,7 +830,9 @@ impl<'a> Context<'a> { // must this value be consumed? let consume = consume_expr(&self.vars, expr); - self.update_var_info_help(symbol, layout, persistent, consume) + let reset = matches!(expr, Expr::Reset(_)); + + self.update_var_info_help(symbol, layout, persistent, consume, reset) } fn update_var_info_help( @@ -832,6 +841,7 @@ impl<'a> Context<'a> { layout: &Layout<'a>, persistent: bool, consume: bool, + reset: bool, ) -> Self { // should we perform incs and decs on this value? let reference = layout.contains_refcounted(); @@ -840,6 +850,7 @@ impl<'a> Context<'a> { reference, persistent, consume, + reset, }; let mut ctx = self.clone(); @@ -857,6 +868,7 @@ impl<'a> Context<'a> { reference: p.layout.contains_refcounted(), consume: !p.borrow, persistent: false, + reset: false, }; ctx.vars.insert(p.symbol, info); } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 9f8f08efb3..48cba39860 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -227,6 +227,19 @@ impl<'a> Proc<'a> { } } + pub fn insert_reset_reuse_operations<'i>( + arena: &'a Bump, + home: ModuleId, + ident_ids: &'i mut IdentIds, + procs: &mut MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, + ) { + for (_, proc) in procs.iter_mut() { + let new_proc = + crate::reset_reuse::insert_reset_reuse(arena, home, ident_ids, proc.clone()); + *proc = new_proc; + } + } + pub fn optimize_refcount_operations<'i, T>( arena: &'a Bump, home: ModuleId, @@ -1129,7 +1142,6 @@ pub enum Expr<'a> { tag_layout: UnionLayout<'a>, tag_name: TagName, tag_id: u8, - union_size: u8, arguments: &'a [Symbol], }, Struct(&'a [Symbol]), @@ -1160,6 +1172,9 @@ pub enum Expr<'a> { Reuse { symbol: Symbol, + update_tag_id: 
bool, + // normal Tag fields + tag_layout: UnionLayout<'a>, tag_name: TagName, tag_id: u8, arguments: &'a [Symbol], @@ -1273,11 +1288,12 @@ impl<'a> Expr<'a> { alloc .text("Reuse ") .append(symbol_to_doc(alloc, *symbol)) + .append(alloc.space()) .append(doc_tag) .append(alloc.space()) .append(alloc.intersperse(it, " ")) } - Reset(symbol) => alloc.text("Reuse ").append(symbol_to_doc(alloc, *symbol)), + Reset(symbol) => alloc.text("Reset ").append(symbol_to_doc(alloc, *symbol)), Struct(args) => { let it = args.iter().map(|s| symbol_to_doc(alloc, *s)); @@ -4036,14 +4052,12 @@ fn construct_closure_data<'a>( ClosureRepresentation::Union { tag_id, tag_layout: _, - union_size, tag_name, union_layout, } => { let expr = Expr::Tag { tag_id, tag_layout: union_layout, - union_size, tag_name, arguments: symbols, }; @@ -4172,7 +4186,6 @@ fn convert_tag_union<'a>( assign_to_symbols(env, procs, layout_cache, iter, stmt) } Wrapped(variant) => { - let union_size = variant.number_of_tags() as u8; let (tag_id, _) = variant.tag_name_to_id(&tag_name); let field_symbols_temp = sorted_field_symbols(env, procs, layout_cache, args); @@ -4215,7 +4228,6 @@ fn convert_tag_union<'a>( tag_layout: union_layout, tag_name, tag_id: tag_id as u8, - union_size, arguments: field_symbols, }; @@ -4239,7 +4251,6 @@ fn convert_tag_union<'a>( tag_layout: union_layout, tag_name, tag_id: tag_id as u8, - union_size, arguments: field_symbols, }; @@ -4265,7 +4276,6 @@ fn convert_tag_union<'a>( tag_layout: union_layout, tag_name, tag_id: tag_id as u8, - union_size, arguments: field_symbols, }; @@ -4293,7 +4303,6 @@ fn convert_tag_union<'a>( tag_layout: union_layout, tag_name, tag_id: tag_id as u8, - union_size, arguments: field_symbols, }; @@ -4312,7 +4321,6 @@ fn convert_tag_union<'a>( tag_layout: union_layout, tag_name, tag_id: tag_id as u8, - union_size, arguments: field_symbols, }; @@ -5346,7 +5354,6 @@ fn substitute_in_expr<'a>( tag_layout, tag_name, tag_id, - union_size, arguments: args, } => { let mut 
did_change = false; @@ -5368,7 +5375,6 @@ fn substitute_in_expr<'a>( tag_layout: *tag_layout, tag_name: tag_name.clone(), tag_id: *tag_id, - union_size: *union_size, arguments, }) } else { diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 42e8793b78..48584c947d 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -148,6 +148,16 @@ impl<'a> UnionLayout<'a> { } } + pub fn number_of_tags(&'a self) -> usize { + match self { + UnionLayout::NonRecursive(tags) | UnionLayout::Recursive(tags) => tags.len(), + + UnionLayout::NullableWrapped { other_tags, .. } => other_tags.len() + 1, + UnionLayout::NonNullableUnwrapped(_) => 1, + UnionLayout::NullableUnwrapped { .. } => 2, + } + } + fn tag_id_builtin_help(union_size: usize) -> Builtin<'a> { if union_size <= u8::MAX as usize { Builtin::Int8 @@ -178,12 +188,40 @@ impl<'a> UnionLayout<'a> { Layout::Builtin(self.tag_id_builtin()) } - pub fn stores_tag_id(&self) -> bool { + fn stores_tag_id_in_pointer_bits(tags: &[&[Layout<'a>]], ptr_bytes: u32) -> bool { + tags.len() <= ptr_bytes as usize + } + + // i.e. it is not implicit and not stored in the pointer bits + pub fn stores_tag_id_as_data(&self, ptr_bytes: u32) -> bool { + match self { + UnionLayout::NonRecursive(_) => true, + UnionLayout::Recursive(tags) + | UnionLayout::NullableWrapped { + other_tags: tags, .. + } => !Self::stores_tag_id_in_pointer_bits(tags, ptr_bytes), + UnionLayout::NonNullableUnwrapped(_) | UnionLayout::NullableUnwrapped { .. } => false, + } + } + + pub fn stores_tag_id_in_pointer(&self, ptr_bytes: u32) -> bool { + match self { + UnionLayout::NonRecursive(_) => false, + UnionLayout::Recursive(tags) + | UnionLayout::NullableWrapped { + other_tags: tags, .. + } => Self::stores_tag_id_in_pointer_bits(tags, ptr_bytes), + UnionLayout::NonNullableUnwrapped(_) | UnionLayout::NullableUnwrapped { .. 
} => false, + } + } + + pub fn tag_is_null(&self, tag_id: u8) -> bool { match self { UnionLayout::NonRecursive(_) - | UnionLayout::Recursive(_) - | UnionLayout::NullableWrapped { .. } => true, - UnionLayout::NonNullableUnwrapped(_) | UnionLayout::NullableUnwrapped { .. } => false, + | UnionLayout::NonNullableUnwrapped(_) + | UnionLayout::Recursive(_) => false, + UnionLayout::NullableWrapped { nullable_id, .. } => *nullable_id == tag_id as i64, + UnionLayout::NullableUnwrapped { nullable_id, .. } => *nullable_id == (tag_id != 0), } } @@ -213,7 +251,6 @@ pub enum ClosureRepresentation<'a> { tag_layout: &'a [Layout<'a>], tag_name: TagName, tag_id: u8, - union_size: u8, union_layout: UnionLayout<'a>, }, /// the representation is anything but a union @@ -252,7 +289,6 @@ impl<'a> LambdaSet<'a> { .unwrap(); ClosureRepresentation::Union { - union_size: self.set.len() as u8, tag_id: index as u8, tag_layout: tags[index], tag_name: TagName::Closure(function_symbol), @@ -713,6 +749,7 @@ impl<'a> Layout<'a> { } } RecursivePointer => true, + Closure(_, closure_layout, _) => closure_layout.contains_refcounted(), } } @@ -1509,6 +1546,18 @@ fn get_recursion_var(subs: &Subs, var: Variable) -> Option { } } +fn is_recursive_tag_union(layout: &Layout) -> bool { + matches!( + layout, + Layout::Union( + UnionLayout::NullableUnwrapped { .. } + | UnionLayout::Recursive(_) + | UnionLayout::NullableWrapped { .. } + | UnionLayout::NonNullableUnwrapped { .. }, + ) + ) +} + pub fn union_sorted_tags_help<'a>( arena: &'a Bump, mut tags_vec: std::vec::Vec<(TagName, std::vec::Vec)>, @@ -1624,10 +1673,17 @@ pub fn union_sorted_tags_help<'a>( for var in arguments { match Layout::from_var(&mut env, var) { Ok(layout) => { - // Drop any zero-sized arguments like {} - if !layout.is_dropped_because_empty() { - has_any_arguments = true; + has_any_arguments = true; + // make sure to not unroll recursive types! 
+ let self_recursion = opt_rec_var.is_some() + && subs.get_root_key_without_compacting(var) + == subs.get_root_key_without_compacting(opt_rec_var.unwrap()) + && is_recursive_tag_union(&layout); + + if self_recursion { + arg_layouts.push(Layout::RecursivePointer); + } else { arg_layouts.push(layout); } } diff --git a/compiler/mono/src/lib.rs b/compiler/mono/src/lib.rs index 831ff3d8c1..d377342bed 100644 --- a/compiler/mono/src/lib.rs +++ b/compiler/mono/src/lib.rs @@ -8,6 +8,7 @@ pub mod expand_rc; pub mod inc_dec; pub mod ir; pub mod layout; +pub mod reset_reuse; pub mod tail_recursion; // Temporary, while we can build up test cases and optimize the exhaustiveness checking. diff --git a/compiler/mono/src/reset_reuse.rs b/compiler/mono/src/reset_reuse.rs new file mode 100644 index 0000000000..3a437e9d50 --- /dev/null +++ b/compiler/mono/src/reset_reuse.rs @@ -0,0 +1,679 @@ +use crate::inc_dec::{collect_stmt, occurring_variables_expr, JPLiveVarMap, LiveVarSet}; +use crate::ir::{BranchInfo, Call, Expr, Proc, Stmt}; +use crate::layout::{Layout, UnionLayout}; +use bumpalo::collections::Vec; +use bumpalo::Bump; +use roc_collections::all::MutSet; +use roc_module::symbol::{IdentIds, ModuleId, Symbol}; + +pub fn insert_reset_reuse<'a, 'i>( + arena: &'a Bump, + home: ModuleId, + ident_ids: &'i mut IdentIds, + mut proc: Proc<'a>, +) -> Proc<'a> { + let mut env = Env { + arena, + home, + ident_ids, + jp_live_vars: Default::default(), + }; + + let new_body = function_r(&mut env, arena.alloc(proc.body)); + proc.body = new_body.clone(); + + proc +} + +#[derive(Debug)] +struct CtorInfo<'a> { + id: u8, + layout: UnionLayout<'a>, +} + +fn may_reuse(tag_layout: UnionLayout, tag_id: u8, other: &CtorInfo) -> bool { + if tag_layout != other.layout { + return false; + } + + // we should not get here if the tag we matched on is represented as NULL + debug_assert!(!tag_layout.tag_is_null(other.id)); + + // furthermore, we can only use the memory if the tag we're creating is non-NULL + 
!tag_layout.tag_is_null(tag_id) +} + +#[derive(Debug)] +struct Env<'a, 'i> { + arena: &'a Bump, + + /// required for creating new `Symbol`s + home: ModuleId, + ident_ids: &'i mut IdentIds, + + jp_live_vars: JPLiveVarMap, +} + +impl<'a, 'i> Env<'a, 'i> { + fn unique_symbol(&mut self) -> Symbol { + let ident_id = self.ident_ids.gen_unique(); + + self.home.register_debug_idents(&self.ident_ids); + + Symbol::new(self.home, ident_id) + } +} + +fn function_s<'a, 'i>( + env: &mut Env<'a, 'i>, + w: Symbol, + c: &CtorInfo<'a>, + stmt: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + use Stmt::*; + + let arena = env.arena; + + match stmt { + Let(symbol, expr, layout, continuation) => match expr { + Expr::Tag { + tag_layout, + tag_id, + tag_name, + arguments, + } if may_reuse(*tag_layout, *tag_id, c) => { + // for now, always overwrite the tag ID just to be sure + let update_tag_id = true; + + let new_expr = Expr::Reuse { + symbol: w, + update_tag_id, + tag_layout: *tag_layout, + tag_id: *tag_id, + tag_name: tag_name.clone(), + arguments, + }; + let new_stmt = Let(*symbol, new_expr, *layout, continuation); + + arena.alloc(new_stmt) + } + _ => { + let rest = function_s(env, w, c, continuation); + let new_stmt = Let(*symbol, expr.clone(), *layout, rest); + + arena.alloc(new_stmt) + } + }, + Join { + id, + parameters, + body, + remainder, + } => { + let id = *id; + let body: &Stmt = *body; + let new_body = function_s(env, w, c, body); + + let new_join = if std::ptr::eq(body, new_body) || body == new_body { + // the join point body will consume w + Join { + id, + parameters, + body: new_body, + remainder, + } + } else { + let new_remainder = function_s(env, w, c, remainder); + + Join { + id, + parameters, + body, + remainder: new_remainder, + } + }; + + arena.alloc(new_join) + } + Invoke { + symbol, + call, + layout, + pass, + fail, + exception_id, + } => { + let new_pass = function_s(env, w, c, pass); + let new_fail = function_s(env, w, c, fail); + + let new_invoke = Invoke { + symbol: 
*symbol, + call: call.clone(), + layout: *layout, + pass: new_pass, + fail: new_fail, + exception_id: *exception_id, + }; + + arena.alloc(new_invoke) + } + Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout, + } => { + let mut new_branches = Vec::with_capacity_in(branches.len(), arena); + new_branches.extend(branches.iter().map(|(tag, info, body)| { + let new_body = function_s(env, w, c, body); + + (*tag, info.clone(), new_body.clone()) + })); + + let new_default = function_s(env, w, c, default_branch.1); + + let new_switch = Switch { + cond_symbol: *cond_symbol, + cond_layout: *cond_layout, + branches: new_branches.into_bump_slice(), + default_branch: (default_branch.0.clone(), new_default), + ret_layout: *ret_layout, + }; + + arena.alloc(new_switch) + } + Refcounting(op, continuation) => { + let continuation: &Stmt = *continuation; + let new_continuation = function_s(env, w, c, continuation); + + if std::ptr::eq(continuation, new_continuation) || continuation == new_continuation { + stmt + } else { + let new_refcounting = Refcounting(*op, new_continuation); + + arena.alloc(new_refcounting) + } + } + Resume(_) | Ret(_) | Jump(_, _) | RuntimeError(_) => stmt, + } +} + +fn try_function_s<'a, 'i>( + env: &mut Env<'a, 'i>, + x: Symbol, + c: &CtorInfo<'a>, + stmt: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + let w = env.unique_symbol(); + + let new_stmt = function_s(env, w, c, stmt); + + if std::ptr::eq(stmt, new_stmt) || stmt == new_stmt { + stmt + } else { + insert_reset(env, w, x, c.layout, new_stmt) + } +} + +fn insert_reset<'a>( + env: &mut Env<'a, '_>, + w: Symbol, + x: Symbol, + union_layout: UnionLayout<'a>, + mut stmt: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + use crate::ir::Expr::*; + + let mut stack = vec![]; + + while let Stmt::Let(symbol, expr, expr_layout, rest) = stmt { + match &expr { + StructAtIndex { .. } | GetTagId { .. } | UnionAtIndex { .. 
} => { + stack.push((symbol, expr, expr_layout)); + stmt = rest; + } + Literal(_) + | Call(_) + | Tag { .. } + | Struct(_) + | Array { .. } + | EmptyArray + | Reuse { .. } + | Reset(_) + | RuntimeErrorFunction(_) => break, + } + } + + let reset_expr = Expr::Reset(x); + + // const I64: Layout<'static> = Layout::Builtin(crate::layout::Builtin::Int64); + + let layout = Layout::Union(union_layout); + + stmt = env.arena.alloc(Stmt::Let(w, reset_expr, layout, stmt)); + + for (symbol, expr, expr_layout) in stack.into_iter().rev() { + stmt = env + .arena + .alloc(Stmt::Let(*symbol, expr.clone(), *expr_layout, stmt)); + } + + stmt +} + +fn function_d_finalize<'a, 'i>( + env: &mut Env<'a, 'i>, + x: Symbol, + c: &CtorInfo<'a>, + output: (&'a Stmt<'a>, bool), +) -> &'a Stmt<'a> { + let (stmt, x_live_in_stmt) = output; + if x_live_in_stmt { + stmt + } else { + try_function_s(env, x, c, stmt) + } +} + +fn function_d_main<'a, 'i>( + env: &mut Env<'a, 'i>, + x: Symbol, + c: &CtorInfo<'a>, + stmt: &'a Stmt<'a>, +) -> (&'a Stmt<'a>, bool) { + use Stmt::*; + + let arena = env.arena; + + match stmt { + Let(symbol, expr, layout, continuation) => { + match expr { + Expr::Tag { arguments, .. } if arguments.iter().any(|s| *s == x) => { + // If the scrutinee `x` (the one that is providing memory) is being + // stored in a constructor, then reuse will probably not be able to reuse memory at runtime. + // It may work only if the new cell is consumed, but we ignore this case. 
+ (stmt, true) + } + _ => { + let (b, found) = function_d_main(env, x, c, continuation); + + // NOTE the &b != continuation is not found in the Lean source, but is required + // otherwise we observe the same symbol being reset twice + let mut result = MutSet::default(); + if found + || { + occurring_variables_expr(expr, &mut result); + !result.contains(&x) + } + || &b != continuation + { + let let_stmt = Let(*symbol, expr.clone(), *layout, b); + + (arena.alloc(let_stmt), found) + } else { + let b = try_function_s(env, x, c, b); + let let_stmt = Let(*symbol, expr.clone(), *layout, b); + + (arena.alloc(let_stmt), found) + } + } + } + } + Invoke { + symbol, + call, + layout, + pass, + fail, + exception_id, + } => { + if has_live_var(&env.jp_live_vars, stmt, x) { + let new_pass = { + let temp = function_d_main(env, x, c, pass); + function_d_finalize(env, x, c, temp) + }; + let new_fail = { + let temp = function_d_main(env, x, c, fail); + function_d_finalize(env, x, c, temp) + }; + let new_switch = Invoke { + symbol: *symbol, + call: call.clone(), + layout: *layout, + pass: new_pass, + fail: new_fail, + exception_id: *exception_id, + }; + + (arena.alloc(new_switch), true) + } else { + (stmt, false) + } + } + Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout, + } => { + if has_live_var(&env.jp_live_vars, stmt, x) { + // if `x` is live in `stmt`, we recursively process each branch + let mut new_branches = Vec::with_capacity_in(branches.len(), arena); + + for (tag, info, body) in branches.iter() { + let temp = function_d_main(env, x, c, body); + let new_body = function_d_finalize(env, x, c, temp); + + new_branches.push((*tag, info.clone(), new_body.clone())); + } + + let new_default = { + let (info, body) = default_branch; + let temp = function_d_main(env, x, c, body); + let new_body = function_d_finalize(env, x, c, temp); + + (info.clone(), new_body) + }; + + let new_switch = Switch { + cond_symbol: *cond_symbol, + cond_layout: 
*cond_layout, + branches: new_branches.into_bump_slice(), + default_branch: new_default, + ret_layout: *ret_layout, + }; + + (arena.alloc(new_switch), true) + } else { + (stmt, false) + } + } + Refcounting(modify_rc, continuation) => { + let (b, found) = function_d_main(env, x, c, continuation); + + if found || modify_rc.get_symbol() != x { + let refcounting = Refcounting(*modify_rc, b); + + (arena.alloc(refcounting), found) + } else { + let b = try_function_s(env, x, c, b); + let refcounting = Refcounting(*modify_rc, b); + + (arena.alloc(refcounting), found) + } + } + Join { + id, + parameters, + body, + remainder, + } => { + env.jp_live_vars.insert(*id, LiveVarSet::default()); + + let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default()); + + env.jp_live_vars.insert(*id, body_live_vars); + + let (b, found) = function_d_main(env, x, c, remainder); + + let (v, _found) = function_d_main(env, x, c, body); + + env.jp_live_vars.remove(id); + + // If `_found == true`, then `Dmain b` must also have returned `(b, true)` since + // we assume the IR does not have dead join points. So, if `x` is live in `j` (i.e., `v`), + // then it must also be live in `b` since `j` is reachable from `b` with a `jmp`. + // On the other hand, `x` may be live in `b` but dead in `j` (i.e., `v`).
-/ + let new_join = Join { + id: *id, + parameters, + body: v, + remainder: b, + }; + + (arena.alloc(new_join), found) + } + Ret(_) | Resume(_) | Jump(_, _) | RuntimeError(_) => { + (stmt, has_live_var(&env.jp_live_vars, stmt, x)) + } + } +} + +fn function_d<'a, 'i>( + env: &mut Env<'a, 'i>, + x: Symbol, + c: &CtorInfo<'a>, + stmt: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + let temp = function_d_main(env, x, c, stmt); + + function_d_finalize(env, x, c, temp) +} + +fn function_r_branch_body<'a, 'i>( + env: &mut Env<'a, 'i>, + info: &BranchInfo<'a>, + body: &'a Stmt<'a>, +) -> &'a Stmt<'a> { + let temp = function_r(env, body); + + match info { + BranchInfo::None => temp, + BranchInfo::Constructor { + scrutinee, + layout, + tag_id, + } => match layout { + Layout::Union(UnionLayout::NonRecursive(_)) => temp, + Layout::Union(union_layout) if !union_layout.tag_is_null(*tag_id) => { + let ctor_info = CtorInfo { + layout: *union_layout, + id: *tag_id, + }; + function_d(env, *scrutinee, &ctor_info, temp) + } + _ => temp, + }, + } +} + +fn function_r<'a, 'i>(env: &mut Env<'a, 'i>, stmt: &'a Stmt<'a>) -> &'a Stmt<'a> { + use Stmt::*; + + let arena = env.arena; + + match stmt { + Switch { + cond_symbol, + cond_layout, + branches, + default_branch, + ret_layout, + } => { + let mut new_branches = Vec::with_capacity_in(branches.len(), arena); + + for (tag, info, body) in branches.iter() { + let new_body = function_r_branch_body(env, info, body); + + new_branches.push((*tag, info.clone(), new_body.clone())); + } + + let new_default = { + let (info, body) = default_branch; + + let new_body = function_r_branch_body(env, info, body); + + (info.clone(), new_body) + }; + + let new_switch = Switch { + cond_symbol: *cond_symbol, + cond_layout: *cond_layout, + branches: new_branches.into_bump_slice(), + default_branch: new_default, + ret_layout: *ret_layout, + }; + + arena.alloc(new_switch) + } + + Join { + id, + parameters, + body, + remainder, + } => { + env.jp_live_vars.insert(*id, 
LiveVarSet::default()); + + let body_live_vars = collect_stmt(body, &env.jp_live_vars, LiveVarSet::default()); + + env.jp_live_vars.insert(*id, body_live_vars); + + let b = function_r(env, remainder); + + let v = function_r(env, body); + + env.jp_live_vars.remove(id); + + let join = Join { + id: *id, + parameters, + body: v, + remainder: b, + }; + + arena.alloc(join) + } + + Let(symbol, expr, layout, continuation) => { + let b = function_r(env, continuation); + + arena.alloc(Let(*symbol, expr.clone(), *layout, b)) + } + Invoke { + symbol, + call, + layout, + pass, + fail, + exception_id, + } => { + let branch_info = BranchInfo::None; + let new_pass = function_r_branch_body(env, &branch_info, pass); + let new_fail = function_r_branch_body(env, &branch_info, fail); + + let invoke = Invoke { + symbol: *symbol, + call: call.clone(), + layout: *layout, + pass: new_pass, + fail: new_fail, + exception_id: *exception_id, + }; + + arena.alloc(invoke) + } + Refcounting(modify_rc, continuation) => { + let b = function_r(env, continuation); + + arena.alloc(Refcounting(*modify_rc, b)) + } + + Resume(_) | Ret(_) | Jump(_, _) | RuntimeError(_) => { + // terminals + stmt + } + } +} + +fn has_live_var<'a>(jp_live_vars: &JPLiveVarMap, stmt: &'a Stmt<'a>, needle: Symbol) -> bool { + use Stmt::*; + + match stmt { + Let(s, e, _, c) => { + debug_assert_ne!(*s, needle); + has_live_var_expr(e, needle) || has_live_var(jp_live_vars, c, needle) + } + Invoke { + symbol, + call, + pass, + fail, + .. + } => { + debug_assert_ne!(*symbol, needle); + + has_live_var_call(call, needle) + || has_live_var(jp_live_vars, pass, needle) + || has_live_var(jp_live_vars, fail, needle) + } + Switch { cond_symbol, .. } if *cond_symbol == needle => true, + Switch { + branches, + default_branch, + .. 
+ } => { + has_live_var(jp_live_vars, default_branch.1, needle) + || branches + .iter() + .any(|(_, _, body)| has_live_var(jp_live_vars, body, needle)) + } + Ret(s) => *s == needle, + Refcounting(modify_rc, cont) => { + modify_rc.get_symbol() == needle || has_live_var(jp_live_vars, cont, needle) + } + Join { + id, + parameters, + body, + remainder, + } => { + debug_assert!(parameters.iter().all(|p| p.symbol != needle)); + + let mut jp_live_vars = jp_live_vars.clone(); + + jp_live_vars.insert(*id, LiveVarSet::default()); + + let body_live_vars = collect_stmt(body, &jp_live_vars, LiveVarSet::default()); + + if body_live_vars.contains(&needle) { + return true; + } + + jp_live_vars.insert(*id, body_live_vars); + + has_live_var(&jp_live_vars, remainder, needle) + } + Jump(id, arguments) => { + arguments.iter().any(|s| *s == needle) || jp_live_vars[id].contains(&needle) + } + Resume(_) | RuntimeError(_) => false, + } +} + +fn has_live_var_expr<'a>(expr: &'a Expr<'a>, needle: Symbol) -> bool { + match expr { + Expr::Literal(_) => false, + Expr::Call(call) => has_live_var_call(call, needle), + Expr::Array { elems: fields, .. } + | Expr::Tag { + arguments: fields, .. + } + | Expr::Struct(fields) => fields.iter().any(|s| *s == needle), + Expr::StructAtIndex { structure, .. } + | Expr::GetTagId { structure, .. } + | Expr::UnionAtIndex { structure, .. } => *structure == needle, + Expr::EmptyArray => false, + Expr::Reuse { + symbol, arguments, .. 
+ } => needle == *symbol || arguments.iter().any(|s| *s == needle), + Expr::Reset(symbol) => needle == *symbol, + Expr::RuntimeErrorFunction(_) => false, + } +} + +fn has_live_var_call<'a>(call: &'a Call<'a>, needle: Symbol) -> bool { + call.arguments.iter().any(|s| *s == needle) +} diff --git a/compiler/test_gen/src/gen_primitives.rs b/compiler/test_gen/src/gen_primitives.rs index 3c932db913..77b4f7cb24 100644 --- a/compiler/test_gen/src/gen_primitives.rs +++ b/compiler/test_gen/src/gen_primitives.rs @@ -1555,9 +1555,9 @@ fn rbtree_balance_full() { balance Red 0 0 Empty Empty "# ), - false, - *const i64, - |x: *const i64| x.is_null() + true, + usize, + |x| x != 0 ); } diff --git a/compiler/test_gen/src/gen_tags.rs b/compiler/test_gen/src/gen_tags.rs index 38c28b9fff..247bebde87 100644 --- a/compiler/test_gen/src/gen_tags.rs +++ b/compiler/test_gen/src/gen_tags.rs @@ -1079,8 +1079,8 @@ fn nested_recursive_literal() { #" ), 0, - &(i64, i64, u8), - |x: &(i64, i64, u8)| x.2 + usize, + |_| 0 ); } diff --git a/examples/benchmarks/CFold.roc b/examples/benchmarks/CFold.roc index fb4fa4eeca..7c17abf066 100644 --- a/examples/benchmarks/CFold.roc +++ b/examples/benchmarks/CFold.roc @@ -97,4 +97,3 @@ constFolding = \e -> Pair y1 y2 -> Add y1 y2 _ -> e - diff --git a/examples/benchmarks/RBTreeDel.roc b/examples/benchmarks/RBTreeDel.roc index d9b780fa39..ccb1753042 100644 --- a/examples/benchmarks/RBTreeDel.roc +++ b/examples/benchmarks/RBTreeDel.roc @@ -38,12 +38,12 @@ makeMapHelp = \total, n, m -> n1 = n - 1 powerOf10 = - (n % 10 |> resultWithDefault 0) == 0 + n |> Num.isMultipleOf 10 t1 = insert m n powerOf10 isFrequency = - (n % 4 |> resultWithDefault 0) == 0 + n |> Num.isMultipleOf 4 key = n1 + ((total - n1) // 5 |> resultWithDefault 0) t2 = if isFrequency then delete t1 key else t1 @@ -85,8 +85,6 @@ isRed = \tree -> Node Red _ _ _ _ -> True _ -> False -lt = \x, y -> x < y - ins : Tree I64 Bool, I64, Bool -> Tree I64 Bool ins = \tree, kx, vx -> when tree is @@ -94,19 
+92,24 @@ ins = \tree, kx, vx -> Node Red Leaf kx vx Leaf Node Red a ky vy b -> - if lt kx ky then - Node Red (ins a kx vx) ky vy b - else if lt ky kx then - Node Red a ky vy (ins b kx vx) - else - Node Red a ky vy (ins b kx vx) + when Num.compare kx ky is + LT -> Node Red (ins a kx vx) ky vy b + GT -> Node Red a ky vy (ins b kx vx) + EQ -> Node Red a ky vy (ins b kx vx) Node Black a ky vy b -> - if lt kx ky then - (if isRed a then balanceLeft (ins a kx vx) ky vy b else Node Black (ins a kx vx) ky vy b) - else if lt ky kx then - (if isRed b then balanceRight a ky vy (ins b kx vx) else Node Black a ky vy (ins b kx vx)) - else Node Black a kx vx b + when Num.compare kx ky is + LT -> + when isRed a is + True -> balanceLeft (ins a kx vx) ky vy b + False -> Node Black (ins a kx vx) ky vy b + + GT -> + when isRed b is + True -> balanceRight a ky vy (ins b kx vx) + False -> Node Black a ky vy (ins b kx vx) + EQ -> + Node Black a kx vx b balanceLeft : Tree a b, a, b, Tree a b -> Tree a b balanceLeft = \l, k, v, r -> diff --git a/www/public/styles.css b/www/public/styles.css index 27f88aacd1..b90cfb58a2 100644 --- a/www/public/styles.css +++ b/www/public/styles.css @@ -282,8 +282,9 @@ code { font-family: var(--font-mono); color: var(--code-color); background-color: var(--code-bg-color); - padding: 2px 8px; + padding: 0 8px; display: inline-block; + line-height: 28px; } code a {