diff --git a/compiler/gen/src/llvm/build.rs b/compiler/gen/src/llvm/build.rs index c5f4f05009..dc7a1ab487 100644 --- a/compiler/gen/src/llvm/build.rs +++ b/compiler/gen/src/llvm/build.rs @@ -468,13 +468,17 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( Tag { arguments, - tag_layout, + tag_layout: Layout::Union(fields), union_size, + tag_id, .. } => { + let tag_layout = Layout::Union(fields); + debug_assert!(*union_size > 1); let ptr_size = env.ptr_bytes; + dbg!(&tag_layout); let mut filler = tag_layout.stack_size(ptr_size); let ctx = env.context; @@ -485,18 +489,35 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let mut field_types = Vec::with_capacity_in(num_fields, env.arena); let mut field_vals = Vec::with_capacity_in(num_fields, env.arena); - for field_symbol in arguments.iter() { + for (field_symbol, tag_field_layout) in + arguments.iter().zip(fields[*tag_id as usize].iter()) + { + // note field_layout is the layout of the argument. + // tag_field_layout is the layout that the tag will store + // these are different for recursive tag unions let (val, field_layout) = load_symbol_and_layout(env, scope, field_symbol); - let field_size = field_layout.stack_size(ptr_size); + let field_size = tag_field_layout.stack_size(ptr_size); // Zero-sized fields have no runtime representation. // The layout of the struct expects them to be dropped! if field_size != 0 { let field_type = - basic_type_from_layout(env.arena, env.context, field_layout, ptr_size); + basic_type_from_layout(env.arena, env.context, tag_field_layout, ptr_size); field_types.push(field_type); - field_vals.push(val); + + if let Layout::RecursivePointer = tag_field_layout { + let ptr = allocate_with_refcount(env, field_layout, val).into(); + let ptr = cast_basic_basic( + builder, + ptr, + ctx.i64_type().ptr_type(AddressSpace::Generic).into(), + ); + dbg!(&ptr); + field_vals.push(ptr); + } else { + field_vals.push(val); + } filler -= field_size; } @@ -543,7 +564,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( // https://github.com/raviqqe/ssf/blob/bc32aae68940d5bddf5984128e85af75ca4f4686/ssf-llvm/src/expression_compiler.rs#L116 let internal_type = - basic_type_from_layout(env.arena, env.context, tag_layout, env.ptr_bytes); + basic_type_from_layout(env.arena, env.context, &tag_layout, env.ptr_bytes); cast_basic_basic( builder, @@ -551,6 +572,7 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( internal_type, ) } + Tag { .. } => unreachable!("tags should have a union layout"), AccessAtIndex { index, structure, @@ -611,9 +633,24 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( let struct_value = cast_struct_struct(builder, argument, struct_type); - builder + let result = builder .build_extract_value(struct_value, *index as u32, "") - .expect("desired field did not decode") + .expect("desired field did not decode"); + + if let Some(Layout::RecursivePointer) = field_layouts.get(*index as usize) { + // the value is a pointer to the actual value; load that value! + let ptr = cast_basic_basic( + builder, + result, + struct_value + .get_type() + .ptr_type(AddressSpace::Generic) + .into(), + ); + builder.build_load(ptr.into_pointer_value(), "load_recursive_field") + } else { + result + } } EmptyArray => empty_polymorphic_list(env), Array { elem_layout, elems } => list_literal(env, scope, elem_layout, elems), @@ -634,6 +671,73 @@ pub fn build_exp_expr<'a, 'ctx, 'env>( } } +pub fn allocate_with_refcount<'a, 'ctx, 'env>( + env: &Env<'a, 'ctx, 'env>, + layout: &Layout<'a>, + value: BasicValueEnum<'ctx>, +) -> PointerValue<'ctx> { + let builder = env.builder; + let ctx = env.context; + + let value_type = basic_type_from_layout(env.arena, ctx, layout, env.ptr_bytes); + let value_bytes = layout.stack_size(env.ptr_bytes) as u64; + + let len_type = env.ptr_int(); + // bytes per element + let bytes_len = len_type.const_int(value_bytes, false); + let offset = (env.ptr_bytes as u64).max(value_bytes); + + let ptr = { + let len = bytes_len; + let len = + builder.build_int_add(len, len_type.const_int(offset, false), "add_refcount_space"); + + env.builder + .build_array_malloc(ctx.i8_type(), len, "create_list_ptr") + .unwrap() + + // TODO check if malloc returned null; if so, runtime error for OOM! + }; + + // We must return a pointer to the first element: + let ptr_bytes = env.ptr_bytes; + let int_type = ptr_int(ctx, ptr_bytes); + let ptr_as_int = builder.build_ptr_to_int(ptr, int_type, "list_cast_ptr"); + let incremented = builder.build_int_add( + ptr_as_int, + ctx.i64_type().const_int(offset, false), + "increment_list_ptr", + ); + + let ptr_type = get_ptr_type(&value_type, AddressSpace::Generic); + let list_element_ptr = builder.build_int_to_ptr(incremented, ptr_type, "list_cast_ptr"); + + // subtract ptr_size, to access the refcount + let refcount_ptr = builder.build_int_sub( + incremented, + ctx.i64_type().const_int(env.ptr_bytes as u64, false), + "refcount_ptr", + ); + + let refcount_ptr = builder.build_int_to_ptr( + refcount_ptr, + int_type.ptr_type(AddressSpace::Generic), + "make ptr", + ); + + // the refcount of a new list is initially 1 + // we assume that the list is indeed used (dead variables are eliminated) + let ref_count_one = ctx + .i64_type() + .const_int(crate::llvm::build::REFCOUNT_1 as _, false); + builder.build_store(refcount_ptr, ref_count_one); + + // store the value in the pointer + builder.build_store(list_element_ptr, value); + + list_element_ptr +} + fn list_literal<'a, 'ctx, 'env>( env: &Env<'a, 'ctx, 'env>, scope: &Scope<'a, 'ctx>, @@ -702,6 +806,7 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( parent: FunctionValue<'ctx>, stmt: &roc_mono::ir::Stmt<'a>, ) -> BasicValueEnum<'ctx> { + use roc_mono::ir::Expr; use roc_mono::ir::Stmt::*; match stmt { @@ -709,7 +814,20 @@ pub fn build_exp_stmt<'a, 'ctx, 'env>( let context = &env.context; let val = build_exp_expr(env, layout_ids, &scope, parent, &expr); - let expr_bt = basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes); + let expr_bt = if let Layout::RecursivePointer = layout { + match expr { + Expr::AccessAtIndex { field_layouts, .. } => { + let layout = Layout::Struct(field_layouts); + + basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes) + } + _ => unreachable!( + "a recursive pointer can only be loaded from a recursive tag union" + ), + } + } else { + basic_type_from_layout(env.arena, context, &layout, env.ptr_bytes) + }; let alloca = create_entry_block_alloca(env, parent, expr_bt, symbol.ident_string(&env.interns)); diff --git a/compiler/gen/src/llvm/convert.rs b/compiler/gen/src/llvm/convert.rs index f3cc316a10..ccd68b3964 100644 --- a/compiler/gen/src/llvm/convert.rs +++ b/compiler/gen/src/llvm/convert.rs @@ -107,9 +107,7 @@ pub fn basic_type_from_layout<'ctx>( .struct_type(field_types.into_bump_slice(), false) .as_basic_type_enum() } - RecursiveUnion(_) => todo!("TODO implement layout of recursive tag union"), - RecursivePointer => todo!("TODO implement layout of recursive tag union"), - Union(_) => { + RecursiveUnion(_) | Union(_) => { // TODO make this dynamic let ptr_size = std::mem::size_of::(); let union_size = layout.stack_size(ptr_size as u32); @@ -140,6 +138,13 @@ pub fn basic_type_from_layout<'ctx>( .into() } } + RecursivePointer => { + // TODO make this dynamic + context + .i64_type() + .ptr_type(AddressSpace::Generic) + .as_basic_type_enum() + } Builtin(builtin) => match builtin { Int128 => context.i128_type().as_basic_type_enum(), diff --git a/compiler/gen/tests/gen_primitives.rs b/compiler/gen/tests/gen_primitives.rs index 6c310b5561..f301c9626b 100644 --- a/compiler/gen/tests/gen_primitives.rs +++ b/compiler/gen/tests/gen_primitives.rs @@ -470,4 +470,45 @@ mod gen_primitives { i64 ); } + + #[test] + fn peano() { + assert_evals_to!( + indoc!( + r#" + Peano : [ S Peano, Z ] + + three : Peano + three = S (S (S Z)) + + when three is + S _ -> 1 + Z -> 0 + "# + ), + 1, + i64 + ); + } + + #[test] + fn peano2() { + assert_evals_to!( + indoc!( + r#" + Peano : [ S Peano, Z ] + + three : Peano + three = S (S (S Z)) + + when three is + S (S _) -> 1 + S (_) -> 0 + Z -> 0 + "# + ), + 1, + i64 + ); + } } diff --git a/compiler/gen/tests/helpers/eval.rs b/compiler/gen/tests/helpers/eval.rs index 42834c10f9..ad8743c7a8 100644 --- a/compiler/gen/tests/helpers/eval.rs +++ b/compiler/gen/tests/helpers/eval.rs @@ -184,7 +184,7 @@ pub fn helper_without_uniqueness<'a>( ); // Uncomment this to see the module's un-optimized LLVM instruction output: - // env.module.print_to_stderr(); + env.module.print_to_stderr(); if main_fn.verify(true) { function_pass.run_on(&main_fn); diff --git a/compiler/mono/src/layout.rs b/compiler/mono/src/layout.rs index 035e17fddf..42a23d5e36 100644 --- a/compiler/mono/src/layout.rs +++ b/compiler/mono/src/layout.rs @@ -1,6 +1,6 @@ use bumpalo::collections::Vec; use bumpalo::Bump; -use roc_collections::all::MutMap; +use roc_collections::all::{MutMap, MutSet}; use roc_module::ident::{Lowercase, TagName}; use roc_module::symbol::Symbol; use roc_types::subs::{Content, FlatType, Subs, Variable}; @@ -57,13 +57,43 @@ pub enum Builtin<'a> { EmptySet, } +pub struct Env<'a, 'b> { + arena: &'a Bump, + seen: MutSet, + subs: &'b Subs, +} + +impl<'a, 'b> Env<'a, 'b> { + fn is_seen(&self, var: Variable) -> bool { + let var = self.subs.get_root_key_without_compacting(var); + + self.seen.contains(&var) + } + + fn insert_seen(&mut self, var: Variable) -> bool { + let var = self.subs.get_root_key_without_compacting(var); + + self.seen.insert(var) + } +} + impl<'a> Layout<'a> { pub fn new(arena: &'a Bump, content: Content, subs: &Subs) -> Result { + let mut env = Env { + arena, + subs, + seen: MutSet::default(), + }; + + Self::new_help(&mut env, content) + } + + fn new_help<'b>(env: &mut Env<'a, 'b>, content: Content) -> Result { use roc_types::subs::Content::*; match content { FlexVar(_) | RigidVar(_) => Err(LayoutProblem::UnresolvedTypeVar), - Structure(flat_type) => layout_from_flat_type(arena, flat_type, subs), + Structure(flat_type) => layout_from_flat_type(env, flat_type), Alias(Symbol::NUM_INT, args, _) => { debug_assert!(args.is_empty()); @@ -73,7 +103,7 @@ impl<'a> Layout<'a> { debug_assert!(args.is_empty()); Ok(Layout::Builtin(Builtin::Float64)) } - Alias(_, _, var) => Self::new(arena, subs.get_without_compacting(var).content, subs), + Alias(_, _, var) => Self::from_var(env, var), Error => Err(LayoutProblem::Erroneous), } } @@ -81,10 +111,14 @@ impl<'a> Layout<'a> { /// Returns Err(()) if given an error, or Ok(Layout) if given a non-erroneous Structure. /// Panics if given a FlexVar or RigidVar, since those should have been /// monomorphized away already! - fn from_var(arena: &'a Bump, var: Variable, subs: &Subs) -> Result { - let content = subs.get_without_compacting(var).content; - - Self::new(arena, content, subs) + fn from_var(env: &mut Env<'a, '_>, var: Variable) -> Result { + if env.is_seen(var) { + Ok(Layout::RecursivePointer) + } else { + let content = env.subs.get_without_compacting(var).content; + println!("{:?} {:?}", var, &content); + Self::new_help(env, content) + } } pub fn safe_to_memcpy(&self) -> bool { @@ -211,13 +245,15 @@ impl<'a> LayoutCache<'a> { // Store things according to the root Variable, to avoid duplicate work. let var = subs.get_root_key_without_compacting(var); + let mut env = Env { + arena, + subs, + seen: MutSet::default(), + }; + self.layouts .entry(var) - .or_insert_with(|| { - let content = subs.get_without_compacting(var).content; - - Layout::new(arena, content, subs) - }) + .or_insert_with(|| Layout::from_var(&mut env, var)) .clone() } } @@ -296,12 +332,14 @@ impl<'a> Builtin<'a> { } fn layout_from_flat_type<'a>( - arena: &'a Bump, + env: &mut Env<'a, '_>, flat_type: FlatType, - subs: &Subs, ) -> Result, LayoutProblem> { use roc_types::subs::FlatType::*; + let arena = env.arena; + let subs = env.subs; + match flat_type { Apply(symbol, args) => { match symbol { @@ -323,7 +361,7 @@ fn layout_from_flat_type<'a>( layout_from_num_content(content) } Symbol::STR_STR => Ok(Layout::Builtin(Builtin::Str)), - Symbol::LIST_LIST => list_layout_from_elem(arena, subs, args[0]), + Symbol::LIST_LIST => list_layout_from_elem(env, args[0]), Symbol::ATTR_ATTR => { debug_assert_eq!(args.len(), 2); @@ -332,7 +370,7 @@ fn layout_from_flat_type<'a>( let wrapped_var = args[1]; // correct the memory mode of unique lists - match Layout::from_var(arena, wrapped_var, subs)? { + match Layout::from_var(env, wrapped_var)? { Layout::Builtin(Builtin::List(_, elem_layout)) => { let uniqueness_var = args[0]; let uniqueness_content = @@ -358,13 +396,10 @@ fn layout_from_flat_type<'a>( let mut fn_args = Vec::with_capacity_in(args.len(), arena); for arg_var in args { - let arg_content = subs.get_without_compacting(arg_var).content; - - fn_args.push(Layout::new(arena, arg_content, subs)?); + fn_args.push(Layout::from_var(env, arg_var)?); } - let ret_content = subs.get_without_compacting(ret_var).content; - let ret = Layout::new(arena, ret_content, subs)?; + let ret = Layout::from_var(env, ret_var)?; Ok(Layout::FunctionPointer( fn_args.into_bump_slice(), @@ -400,9 +435,8 @@ fn layout_from_flat_type<'a>( Demanded(var) => var, } }; - let field_content = subs.get_without_compacting(field_var).content; - match Layout::new(arena, field_content, subs) { + match Layout::from_var(env, field_var) { Ok(layout) => { // Drop any zero-sized fields like {}. if !layout.is_zero_sized() { @@ -441,6 +475,7 @@ fn layout_from_flat_type<'a>( // That means none of the optimizations for enums or single tag tag unions apply let rec_var = subs.get_root_key_without_compacting(rec_var); + env.insert_seen(rec_var); let mut tag_layouts = Vec::with_capacity_in(tags.len(), arena); // tags: MutMap>, @@ -457,15 +492,12 @@ fn layout_from_flat_type<'a>( continue; } - let var_content = subs.get_without_compacting(var).content; - - tag_layout.push(Layout::new(arena, var_content, subs)?); + tag_layout.push(Layout::from_var(env, var)?); } tag_layouts.push(tag_layout.into_bump_slice()); } - dbg!(&tag_layouts); Ok(Layout::RecursiveUnion(tag_layouts.into_bump_slice())) } EmptyTagUnion => { @@ -486,6 +518,12 @@ pub fn sort_record_fields<'a>( ) -> Vec<'a, (Lowercase, Result, Layout<'a>>)> { let mut fields_map = MutMap::default(); + let mut env = Env { + arena, + subs, + seen: MutSet::default(), + }; + match roc_types::pretty_print::chase_ext_record(subs, var, &mut fields_map) { Ok(()) | Err((_, Content::FlexVar(_))) => { // Sort the fields by label @@ -498,13 +536,13 @@ pub fn sort_record_fields<'a>( RecordField::Required(v) => v, RecordField::Optional(v) => { let layout = - Layout::from_var(arena, v, subs).expect("invalid layout from var"); + Layout::from_var(&mut env, v).expect("invalid layout from var"); sorted_fields.push((label, Err(layout))); continue; } }; - let layout = Layout::from_var(arena, var, subs).expect("invalid layout from var"); + let layout = Layout::from_var(&mut env, var).expect("invalid layout from var"); // Drop any zero-sized fields like {} if !layout.is_zero_sized() { @@ -532,20 +570,44 @@ pub enum UnionVariant<'a> { pub fn union_sorted_tags<'a>(arena: &'a Bump, var: Variable, subs: &Subs) -> UnionVariant<'a> { let mut tags_vec = std::vec::Vec::new(); - match roc_types::pretty_print::chase_ext_tag_union(subs, var, &mut tags_vec) { - Ok(()) | Err((_, Content::FlexVar(_))) => union_sorted_tags_help(arena, tags_vec, subs), + let result = match roc_types::pretty_print::chase_ext_tag_union(subs, var, &mut tags_vec) { + Ok(()) | Err((_, Content::FlexVar(_))) => { + let opt_rec_var = get_recursion_var(subs, var); + union_sorted_tags_help(arena, tags_vec, opt_rec_var, subs) + } Err(other) => panic!("invalid content in tag union variable: {:?}", other), + }; + + result +} + +fn get_recursion_var(subs: &Subs, var: Variable) -> Option { + match subs.get_without_compacting(var).content { + Content::Structure(FlatType::RecursiveTagUnion(rec_var, _, _)) => Some(rec_var), + Content::Alias(_, _, actual) => get_recursion_var(subs, actual), + _ => None, } } fn union_sorted_tags_help<'a>( arena: &'a Bump, mut tags_vec: std::vec::Vec<(TagName, std::vec::Vec)>, + opt_rec_var: Option, subs: &Subs, ) -> UnionVariant<'a> { // sort up front; make sure the ordering stays intact! tags_vec.sort(); + let mut env = Env { + arena, + subs, + seen: MutSet::default(), + }; + + if let Some(rec_var) = opt_rec_var { + env.insert_seen(rec_var); + } + match tags_vec.len() { 0 => { // trying to instantiate a type with no values @@ -564,7 +626,7 @@ fn union_sorted_tags_help<'a>( } _ => { for var in arguments { - match Layout::from_var(arena, var, subs) { + match Layout::from_var(&mut env, var) { Ok(layout) => { // Drop any zero-sized arguments like {} if !layout.is_zero_sized() { @@ -603,10 +665,8 @@ fn union_sorted_tags_help<'a>( arg_layouts.push(Layout::Builtin(Builtin::Int64)); for var in arguments { - dbg!(&var); - match dbg!(Layout::from_var(arena, var, subs)) { + match Layout::from_var(&mut env, var) { Ok(layout) => { - dbg!(&layout); // Drop any zero-sized arguments like {} if !layout.is_zero_sized() { has_any_arguments = true; @@ -665,7 +725,8 @@ pub fn layout_from_tag_union<'a>( let tags_vec: std::vec::Vec<_> = tags.into_iter().collect(); if tags_vec[0].0 != TagName::Private(Symbol::NUM_AT_NUM) { - let variant = union_sorted_tags_help(arena, tags_vec, subs); + let opt_rec_var = None; + let variant = union_sorted_tags_help(arena, tags_vec, opt_rec_var, subs); match variant { Never => panic!("TODO gracefully handle trying to instantiate Never"), @@ -796,29 +857,28 @@ fn unwrap_num_tag<'a>(subs: &Subs, var: Variable) -> Result, LayoutPr } pub fn list_layout_from_elem<'a>( - arena: &'a Bump, - subs: &Subs, + env: &mut Env<'a, '_>, elem_var: Variable, ) -> Result, LayoutProblem> { - match subs.get_without_compacting(elem_var).content { + match env.subs.get_without_compacting(elem_var).content { Content::Structure(FlatType::Apply(Symbol::ATTR_ATTR, args)) => { debug_assert_eq!(args.len(), 2); let var = *args.get(1).unwrap(); - list_layout_from_elem(arena, subs, var) + list_layout_from_elem(env, var) } Content::FlexVar(_) | Content::RigidVar(_) => { // If this was still a (List *) then it must have been an empty list Ok(Layout::Builtin(Builtin::EmptyList)) } content => { - let elem_layout = Layout::new(arena, content, subs)?; + let elem_layout = Layout::new_help(env, content)?; // This is a normal list. Ok(Layout::Builtin(Builtin::List( MemoryMode::Refcounted, - arena.alloc(elem_layout), + env.arena.alloc(elem_layout), ))) } } diff --git a/compiler/mono/tests/test_mono.rs b/compiler/mono/tests/test_mono.rs index 59bd4452eb..ff966d66e2 100644 --- a/compiler/mono/tests/test_mono.rs +++ b/compiler/mono/tests/test_mono.rs @@ -82,6 +82,7 @@ mod test_mono { .map(|proc| proc.to_pretty(200)) .collect::>(); + procs_string.sort(); procs_string.push(ir_expr.to_pretty(200)); let result = procs_string.join("\n"); @@ -536,15 +537,15 @@ mod test_mono { "#, indoc!( r#" + procedure List.5 (#Attr.2, #Attr.3): + let Test.9 = lowlevel ListAppend #Attr.2 #Attr.3; + ret Test.9; + procedure Test.0 (Test.2): let Test.8 = 42i64; let Test.7 = CallByName List.5 Test.2 Test.8; ret Test.7; - procedure List.5 (#Attr.2, #Attr.3): - let Test.9 = lowlevel ListAppend #Attr.2 #Attr.3; - ret Test.9; - let Test.5 = 1i64; let Test.6 = 2i64; let Test.4 = Array [Test.5, Test.6]; @@ -590,17 +591,17 @@ mod test_mono { "#, indoc!( r#" - procedure Num.14 (#Attr.2, #Attr.3): - let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.11; + procedure List.7 (#Attr.2): + let Test.10 = lowlevel ListLen #Attr.2; + ret Test.10; procedure List.7 (#Attr.2): let Test.9 = lowlevel ListLen #Attr.2; ret Test.9; - procedure List.7 (#Attr.2): - let Test.10 = lowlevel ListLen #Attr.2; - ret Test.10; + procedure Num.14 (#Attr.2, #Attr.3): + let Test.11 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.11; let Test.8 = 1f64; let Test.1 = Array [Test.8]; @@ -832,6 +833,10 @@ mod test_mono { ), indoc!( r#" + procedure Bool.5 (#Attr.2, #Attr.3): + let Test.11 = lowlevel Eq #Attr.2 #Attr.3; + ret Test.11; + procedure Test.0 (Test.3): let Test.6 = 10i64; let Test.14 = true; @@ -847,10 +852,6 @@ mod test_mono { let Test.12 = 42i64; ret Test.12; - procedure Bool.5 (#Attr.2, #Attr.3): - let Test.11 = lowlevel Eq #Attr.2 #Attr.3; - ret Test.11; - let Test.5 = Struct {}; let Test.4 = CallByName Test.0 Test.5; ret Test.4; @@ -1257,6 +1258,10 @@ mod test_mono { let Test.13 = lowlevel NumSub #Attr.2 #Attr.3; ret Test.13; + procedure Num.16 (#Attr.2, #Attr.3): + let Test.11 = lowlevel NumMul #Attr.2 #Attr.3; + ret Test.11; + procedure Test.0 (Test.2, Test.3): jump Test.18 Test.2 Test.3; joinpoint Test.18 Test.2 Test.3: @@ -1272,10 +1277,6 @@ mod test_mono { let Test.10 = CallByName Num.16 Test.2 Test.3; jump Test.18 Test.9 Test.10; - procedure Num.16 (#Attr.2, #Attr.3): - let Test.11 = lowlevel NumMul #Attr.2 #Attr.3; - ret Test.11; - let Test.5 = 10i64; let Test.6 = 1i64; let Test.4 = CallByName Test.0 Test.5 Test.6; @@ -1444,16 +1445,6 @@ mod test_mono { ), indoc!( r#" - procedure Num.14 (#Attr.2, #Attr.3): - let Test.19 = lowlevel NumAdd #Attr.2 #Attr.3; - ret Test.19; - - procedure Test.1 (Test.3): - let Test.13 = 0i64; - let Test.14 = 0i64; - let Test.12 = CallByName List.4 Test.3 Test.13 Test.14; - ret Test.12; - procedure List.4 (#Attr.2, #Attr.3, #Attr.4): let Test.18 = lowlevel ListLen #Attr.2; let Test.16 = lowlevel NumLt #Attr.3 Test.18; @@ -1467,6 +1458,16 @@ mod test_mono { let Test.11 = lowlevel ListLen #Attr.2; ret Test.11; + procedure Num.14 (#Attr.2, #Attr.3): + let Test.19 = lowlevel NumAdd #Attr.2 #Attr.3; + ret Test.19; + + procedure Test.1 (Test.3): + let Test.13 = 0i64; + let Test.14 = 0i64; + let Test.12 = CallByName List.4 Test.3 Test.13 Test.14; + ret Test.12; + let Test.8 = 1i64; let Test.9 = 2i64; let Test.10 = 3i64; @@ -1497,16 +1498,6 @@ mod test_mono { ), indoc!( r#" - procedure Test.0 (Test.2): - let Test.16 = 1i64; - let Test.17 = 2i64; - let Test.18 = 3i64; - let Test.6 = Array [Test.16, Test.17, Test.18]; - let Test.7 = 0i64; - let Test.5 = CallByName List.3 Test.6 Test.7; - dec Test.6; - ret Test.5; - procedure List.3 (#Attr.2, #Attr.3): let Test.15 = lowlevel ListLen #Attr.2; let Test.11 = lowlevel NumLt #Attr.3 Test.15; @@ -1521,6 +1512,16 @@ mod test_mono { let Test.8 = Err Test.9 Test.10; ret Test.8; + procedure Test.0 (Test.2): + let Test.16 = 1i64; + let Test.17 = 2i64; + let Test.18 = 3i64; + let Test.6 = Array [Test.16, Test.17, Test.18]; + let Test.7 = 0i64; + let Test.5 = CallByName List.3 Test.6 Test.7; + dec Test.6; + ret Test.5; + let Test.4 = Struct {}; let Test.3 = CallByName Test.0 Test.4; ret Test.3; @@ -1540,11 +1541,71 @@ mod test_mono { three = S (S (S Z)) three - "# ), indoc!( r#" + let Test.3 = 0i64; + let Test.5 = 0i64; + let Test.7 = 0i64; + let Test.9 = 1i64; + let Test.8 = Z Test.9; + let Test.6 = S Test.7 Test.8; + let Test.4 = S Test.5 Test.6; + let Test.1 = S Test.3 Test.4; + ret Test.1; + "# + ), + ) + } + + #[test] + fn peano2() { + compiles_to_ir( + indoc!( + r#" + Peano : [ S Peano, Z ] + + three : Peano + three = S (S (S Z)) + + when three is + S (S _) -> 1 + S (_) -> 0 + Z -> 0 + "# + ), + indoc!( + r#" + let Test.16 = 0i64; + let Test.18 = 0i64; + let Test.20 = 0i64; + let Test.22 = 1i64; + let Test.21 = Z Test.22; + let Test.19 = S Test.20 Test.21; + let Test.17 = S Test.18 Test.19; + let Test.1 = S Test.16 Test.17; + let Test.12 = true; + let Test.14 = Index 0 Test.1; + let Test.13 = 0i64; + let Test.15 = lowlevel Eq Test.13 Test.14; + let Test.11 = lowlevel And Test.15 Test.12; + if Test.11 then + let Test.7 = true; + let Test.9 = Index 0 Test.1; + let Test.8 = 0i64; + let Test.10 = lowlevel Eq Test.8 Test.9; + let Test.6 = lowlevel And Test.10 Test.7; + if Test.6 then + let Test.3 = Index 1 Test.1; + let Test.2 = 1i64; + ret Test.2; + else + let Test.4 = 0i64; + ret Test.4; + else + let Test.5 = 0i64; + ret Test.5; "# ), ) diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index fbe5ab6f28..3fff70340d 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -1295,6 +1295,8 @@ fn deep_copy_var_help( RecursiveTagUnion(rec_var, tags, ext_var) => { let mut new_tags = MutMap::default(); + let new_rec_var = deep_copy_var_help(subs, max_rank, pools, rec_var); + for (tag, vars) in tags { let new_vars: Vec = vars .into_iter() @@ -1304,7 +1306,7 @@ fn deep_copy_var_help( } RecursiveTagUnion( - deep_copy_var_help(subs, max_rank, pools, rec_var), + new_rec_var, new_tags, deep_copy_var_help(subs, max_rank, pools, ext_var), ) diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index 72aa41c6d2..f917bdc998 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -563,9 +563,12 @@ fn unify_shared_tags( // > RecursiveTagUnion(rvar, [ Cons a [ Cons a rvar, Nil ], Nil ], ext) // // and so on until the whole non-recursive tag union can be unified with it. - let problems = if let Some(rvar) = recursion_var { + let mut problems = Vec::new(); + + if let Some(rvar) = recursion_var { if expected == rvar { - unify_pool(subs, pool, actual, ctx.second) + problems.extend(unify_pool(subs, pool, actual, ctx.second)); + println!("A"); } else if is_structure(actual, subs) { // the recursion variable is hidden behind some structure (commonly an Attr // with uniqueness inference). Thus we must expand the recursive tag union to @@ -578,19 +581,29 @@ fn unify_shared_tags( // when `actual` is just a flex/rigid variable, the substitution would expand a // recursive tag union infinitely! - unify_pool(subs, pool, actual, expected) + problems.extend(unify_pool(subs, pool, actual, expected)); + println!("B"); } else { // unification with a non-structure is trivial - unify_pool(subs, pool, actual, expected) + problems.extend(unify_pool(subs, pool, actual, expected)); + println!("C"); } } else { // we always unify NonRecursive with Recursive, so this should never happen debug_assert_ne!(Some(actual), recursion_var); - unify_pool(subs, pool, actual, expected) + problems.extend(unify_pool(subs, pool, actual, expected)); + println!("D"); }; + // TODO this changes some error messages + // but is important for the inference of recursive types if problems.is_empty() { + problems.extend(unify_pool(subs, pool, expected, actual)); + } + + if problems.is_empty() { + // debug_assert_eq!(subs.get_root_key(actual), subs.get_root_key(expected)); matching_vars.push(actual); } }