diff --git a/cli/src/repl/eval.rs b/cli/src/repl/eval.rs index 8401ee3511..4238432a86 100644 --- a/cli/src/repl/eval.rs +++ b/cli/src/repl/eval.rs @@ -324,9 +324,9 @@ fn jit_to_ast_help<'a>( todo!("print recursive tag unions in the REPL") } Content::Alias(_, _, _, actual) => { - let content = env.subs.get_without_compacting(*actual).content; + let content = env.subs.get_content_without_compacting(*actual); - jit_to_ast_help(env, lib, main_fn_name, layout, &content) + jit_to_ast_help(env, lib, main_fn_name, layout, content) } other => unreachable!("Weird content for Union layout: {:?}", other), } @@ -468,7 +468,7 @@ fn list_to_ast<'a>( let elem_var = *vars.first().unwrap(); - env.subs.get_without_compacting(elem_var).content + env.subs.get_content_without_compacting(elem_var) } other => { unreachable!( @@ -486,7 +486,7 @@ fn list_to_ast<'a>( let offset_bytes = index * elem_size; let elem_ptr = unsafe { ptr.add(offset_bytes) }; let loc_expr = &*arena.alloc(Located { - value: ptr_to_ast(env, elem_ptr, elem_layout, &elem_content), + value: ptr_to_ast(env, elem_ptr, elem_layout, elem_content), region: Region::zero(), }); @@ -539,8 +539,8 @@ where let mut field_ptr = ptr as *const u8; for (var, layout) in sequence { - let content = subs.get_without_compacting(var).content; - let expr = ptr_to_ast(env, field_ptr, layout, &content); + let content = subs.get_content_without_compacting(var); + let expr = ptr_to_ast(env, field_ptr, layout, content); let loc_expr = Located::at_zero(expr); output.push(&*arena.alloc(loc_expr)); @@ -577,10 +577,10 @@ fn struct_to_ast<'a>( // this is a 1-field wrapper record around another record or 1-tag tag union let (label, field) = sorted_fields.pop().unwrap(); - let inner_content = env.subs.get_without_compacting(field.into_inner()).content; + let inner_content = env.subs.get_content_without_compacting(field.into_inner()); let loc_expr = &*arena.alloc(Located { - value: ptr_to_ast(env, ptr, &Layout::Struct(field_layouts), &inner_content), + value: ptr_to_ast(env, ptr, &Layout::Struct(field_layouts), inner_content), region: Region::zero(), }); @@ -606,9 +606,9 @@ fn struct_to_ast<'a>( let mut field_ptr = ptr; for ((label, field), field_layout) in sorted_fields.iter().zip(field_layouts.iter()) { - let content = subs.get_without_compacting(*field.as_inner()).content; + let content = subs.get_content_without_compacting(*field.as_inner()); let loc_expr = &*arena.alloc(Located { - value: ptr_to_ast(env, field_ptr, field_layout, &content), + value: ptr_to_ast(env, field_ptr, field_layout, content), region: Region::zero(), }); @@ -659,9 +659,9 @@ fn bool_to_ast<'a>(env: &Env<'a, '_>, value: bool, content: &Content) -> Expr<'a // and/or records (e.g. { a: { b: { c: True } } }), // so we need to do this recursively on the field type. let field_var = *field.as_inner(); - let field_content = env.subs.get_without_compacting(field_var).content; + let field_content = env.subs.get_content_without_compacting(field_var); let loc_expr = Located { - value: bool_to_ast(env, value, &field_content), + value: bool_to_ast(env, value, field_content), region: Region::zero(), }; @@ -701,10 +701,10 @@ fn bool_to_ast<'a>(env: &Env<'a, '_>, value: bool, content: &Content) -> Expr<'a debug_assert_eq!(payload_vars.len(), 1); let var = *payload_vars.iter().next().unwrap(); - let content = env.subs.get_without_compacting(var).content; + let content = env.subs.get_content_without_compacting(var); let loc_payload = &*arena.alloc(Located { - value: bool_to_ast(env, value, &content), + value: bool_to_ast(env, value, content), region: Region::zero(), }); @@ -739,9 +739,9 @@ fn bool_to_ast<'a>(env: &Env<'a, '_>, value: bool, content: &Content) -> Expr<'a } } Alias(_, _, _, var) => { - let content = env.subs.get_without_compacting(*var).content; + let content = env.subs.get_content_without_compacting(*var); - bool_to_ast(env, value, &content) + bool_to_ast(env, value, content) } other => { unreachable!("Unexpected FlatType {:?} in bool_to_ast", other); @@ -771,9 +771,9 @@ fn byte_to_ast<'a>(env: &Env<'a, '_>, value: u8, content: &Content) -> Expr<'a> // and/or records (e.g. { a: { b: { c: True } } }), // so we need to do this recursively on the field type. let field_var = *field.as_inner(); - let field_content = env.subs.get_without_compacting(field_var).content; + let field_content = env.subs.get_content_without_compacting(field_var); let loc_expr = Located { - value: byte_to_ast(env, value, &field_content), + value: byte_to_ast(env, value, field_content), region: Region::zero(), }; @@ -813,10 +813,10 @@ fn byte_to_ast<'a>(env: &Env<'a, '_>, value: u8, content: &Content) -> Expr<'a> debug_assert_eq!(payload_vars.len(), 1); let var = *payload_vars.iter().next().unwrap(); - let content = env.subs.get_without_compacting(var).content; + let content = env.subs.get_content_without_compacting(var); let loc_payload = &*arena.alloc(Located { - value: byte_to_ast(env, value, &content), + value: byte_to_ast(env, value, content), region: Region::zero(), }); @@ -850,9 +850,9 @@ fn byte_to_ast<'a>(env: &Env<'a, '_>, value: u8, content: &Content) -> Expr<'a> } } Alias(_, _, _, var) => { - let content = env.subs.get_without_compacting(*var).content; + let content = env.subs.get_content_without_compacting(*var); - byte_to_ast(env, value, &content) + byte_to_ast(env, value, content) } other => { unreachable!("Unexpected FlatType {:?} in bool_to_ast", other); @@ -887,9 +887,9 @@ fn num_to_ast<'a>(env: &Env<'a, '_>, num_expr: Expr<'a>, content: &Content) -> E // and/or records (e.g. { a: { b: { c: 5 } } }), // so we need to do this recursively on the field type. let field_var = *field.as_inner(); - let field_content = env.subs.get_without_compacting(field_var).content; + let field_content = env.subs.get_content_without_compacting(field_var); let loc_expr = Located { - value: num_to_ast(env, num_expr, &field_content), + value: num_to_ast(env, num_expr, field_content), region: Region::zero(), }; @@ -937,10 +937,10 @@ fn num_to_ast<'a>(env: &Env<'a, '_>, num_expr: Expr<'a>, content: &Content) -> E debug_assert_eq!(payload_vars.len(), 1); let var = *payload_vars.iter().next().unwrap(); - let content = env.subs.get_without_compacting(var).content; + let content = env.subs.get_content_without_compacting(var); let loc_payload = &*arena.alloc(Located { - value: num_to_ast(env, num_expr, &content), + value: num_to_ast(env, num_expr, content), region: Region::zero(), }); @@ -955,9 +955,9 @@ fn num_to_ast<'a>(env: &Env<'a, '_>, num_expr: Expr<'a>, content: &Content) -> E } } Alias(_, _, _, var) => { - let content = env.subs.get_without_compacting(*var).content; + let content = env.subs.get_content_without_compacting(*var); - num_to_ast(env, num_expr, &content) + num_to_ast(env, num_expr, content) } other => { panic!("Unexpected FlatType {:?} in num_to_ast", other); diff --git a/cli/src/repl/gen.rs b/cli/src/repl/gen.rs index f10ec9f704..a1dd4d3372 100644 --- a/cli/src/repl/gen.rs +++ b/cli/src/repl/gen.rs @@ -154,8 +154,8 @@ pub fn gen_and_eval<'a>( // pretty-print the expr type string for later. name_all_type_vars(main_fn_var, &mut subs); - let content = subs.get(main_fn_var).content; - let expr_type_str = content_to_string(content.clone(), &subs, home, &interns); + let content = subs.get_content_without_compacting(main_fn_var); + let expr_type_str = content_to_string(content, &subs, home, &interns); let (_, main_fn_layout) = match procedures.keys().find(|(s, _)| *s == main_fn_symbol) { Some(layout) => *layout, @@ -227,7 +227,7 @@ pub fn gen_and_eval<'a>( lib, main_fn_name, main_fn_layout, - &content, + content, &env.interns, home, &subs, diff --git a/compiler/load/tests/test_load.rs b/compiler/load/tests/test_load.rs index 6716acda0c..6630bcede1 100644 --- a/compiler/load/tests/test_load.rs +++ b/compiler/load/tests/test_load.rs @@ -184,10 +184,9 @@ mod test_load { expected_types: &mut HashMap<&str, &str>, ) { for (symbol, expr_var) in &def.pattern_vars { - let content = subs.get(*expr_var).content; - name_all_type_vars(*expr_var, subs); + let content = subs.get_content_without_compacting(*expr_var); let actual_str = content_to_string(content, subs, home, interns); let fully_qualified = symbol.fully_qualified(interns, home).to_string(); let expected_type = expected_types diff --git a/compiler/mono/src/borrow.rs b/compiler/mono/src/borrow.rs index 4968fe7f06..5757341f9e 100644 --- a/compiler/mono/src/borrow.rs +++ b/compiler/mono/src/borrow.rs @@ -21,20 +21,14 @@ pub fn infer_borrow<'a>( procs: &MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>, ) -> ParamMap<'a> { // intern the layouts - let mut declaration_to_index = MutMap::with_capacity_and_hasher(procs.len(), default_hasher()); let mut param_map = { - let mut i = 0; - for key in procs.keys() { - declaration_to_index.insert(*key, ParamOffset(i)); - - i += key.1.arguments.len(); - } + let (declaration_to_index, total_number_of_params) = DeclarationToIndex::new(arena, procs); ParamMap { declaration_to_index, join_points: MutMap::default(), - declarations: bumpalo::vec![in arena; Param::EMPTY; i], + declarations: bumpalo::vec![in arena; Param::EMPTY; total_number_of_params], } }; @@ -118,7 +112,7 @@ pub fn infer_borrow<'a>( for (key, proc) in procs { let symbol = key.0; - let offset = param_map.declaration_to_index[key]; + let offset = param_map.get_param_offset(key.0, key.1); // the component this symbol is a part of let component = symbol_to_component[&symbol]; @@ -169,10 +163,71 @@ impl From for usize { } } -#[derive(Debug, Clone)] +#[derive(Debug)] +struct DeclarationToIndex<'a> { + elements: Vec<'a, ((Symbol, ProcLayout<'a>), ParamOffset)>, +} + +impl<'a> DeclarationToIndex<'a> { + fn new(arena: &'a Bump, procs: &MutMap<(Symbol, ProcLayout<'a>), Proc<'a>>) -> (Self, usize) { + let mut declaration_to_index = Vec::with_capacity_in(procs.len(), arena); + + let mut i = 0; + for key in procs.keys().copied() { + declaration_to_index.push((key, ParamOffset(i))); + + i += key.1.arguments.len(); + } + + declaration_to_index.sort_unstable_by_key(|t| t.0 .0); + + ( + DeclarationToIndex { + elements: declaration_to_index, + }, + i, + ) + } + + fn get_param_offset( + &self, + needle_symbol: Symbol, + needle_layout: ProcLayout<'a>, + ) -> ParamOffset { + if let Ok(middle_index) = self + .elements + .binary_search_by_key(&needle_symbol, |t| t.0 .0) + { + // first, iterate backward until we hit a different symbol + let backward = self.elements[..middle_index].iter().rev(); + + for ((symbol, proc_layout), param_offset) in backward { + if *symbol != needle_symbol { + break; + } else if *proc_layout == needle_layout { + return *param_offset; + } + } + + // if not found, iterate forward until we find our combo + let forward = self.elements[middle_index..].iter(); + + for ((symbol, proc_layout), param_offset) in forward { + if *symbol != needle_symbol { + break; + } else if *proc_layout == needle_layout { + return *param_offset; + } + } + } + unreachable!("symbol/layout combo must be in DeclarationToIndex") + } +} + +#[derive(Debug)] pub struct ParamMap<'a> { /// Map a (Symbol, ProcLayout) pair to the starting index in the `declarations` array - declaration_to_index: MutMap<(Symbol, ProcLayout<'a>), ParamOffset>, + declaration_to_index: DeclarationToIndex<'a>, /// the parameters of all functions in a single flat array. /// /// - the map above gives the index of the first parameter for the function @@ -184,11 +239,17 @@ pub struct ParamMap<'a> { } impl<'a> ParamMap<'a> { + pub fn get_param_offset(&self, symbol: Symbol, layout: ProcLayout<'a>) -> ParamOffset { + self.declaration_to_index.get_param_offset(symbol, layout) + } + pub fn get_symbol(&self, symbol: Symbol, layout: ProcLayout<'a>) -> Option<&[Param<'a>]> { - let index: usize = self.declaration_to_index[&(symbol, layout)].into(); + // let index: usize = self.declaration_to_index[&(symbol, layout)].into(); + let index: usize = self.get_param_offset(symbol, layout).into(); self.declarations.get(index..index + layout.arguments.len()) } + pub fn get_join_point(&self, id: JoinPointId) -> &'a [Param<'a>] { match self.join_points.get(&id) { Some(slice) => slice, @@ -197,7 +258,7 @@ impl<'a> ParamMap<'a> { } pub fn iter_symbols(&'a self) -> impl Iterator { - self.declaration_to_index.iter().map(|t| &t.0 .0) + self.declaration_to_index.elements.iter().map(|t| &t.0 .0) } } @@ -247,7 +308,7 @@ impl<'a> ParamMap<'a> { return; } - let index: usize = self.declaration_to_index[&key].into(); + let index: usize = self.get_param_offset(key.0, key.1).into(); for (i, param) in Self::init_borrow_args(arena, proc.args) .iter() @@ -266,7 +327,7 @@ impl<'a> ParamMap<'a> { proc: &Proc<'a>, key: (Symbol, ProcLayout<'a>), ) { - let index: usize = self.declaration_to_index[&key].into(); + let index: usize = self.get_param_offset(key.0, key.1).into(); for (i, param) in Self::init_borrow_args_always_owned(arena, proc.args) .iter() diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 2db89e8cf1..01250eec10 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -3010,7 +3010,7 @@ pub fn with_hole<'a>( let arena = env.arena; debug_assert!(!matches!( - env.subs.get_without_compacting(variant_var).content, + env.subs.get_content_without_compacting(variant_var), Content::Structure(FlatType::Func(_, _, _)) )); convert_tag_union( @@ -3035,12 +3035,15 @@ pub fn with_hole<'a>( } => { let arena = env.arena; - let desc = env.subs.get_without_compacting(variant_var); + let content = env.subs.get_content_without_compacting(variant_var); + + if let Content::Structure(FlatType::Func(arg_vars, _, ret_var)) = content { + let ret_var = *ret_var; + let arg_vars = arg_vars.clone(); - if let Content::Structure(FlatType::Func(arg_vars, _, ret_var)) = desc.content { tag_union_to_function( env, - arg_vars, + &arg_vars, ret_var, tag_name, closure_name, @@ -4343,7 +4346,7 @@ fn convert_tag_union<'a>( #[allow(clippy::too_many_arguments)] fn tag_union_to_function<'a>( env: &mut Env<'a, '_>, - argument_variables: std::vec::Vec, + argument_variables: &[Variable], return_variable: Variable, tag_name: TagName, proc_symbol: Symbol, @@ -4364,8 +4367,8 @@ fn tag_union_to_function<'a>( let loc_expr = Located::at_zero(roc_can::expr::Expr::Var(arg_symbol)); - loc_pattern_args.push((arg_var, loc_pattern)); - loc_expr_args.push((arg_var, loc_expr)); + loc_pattern_args.push((*arg_var, loc_pattern)); + loc_expr_args.push((*arg_var, loc_expr)); } let loc_body = Located::at_zero(roc_can::expr::Expr::Tag { @@ -7395,8 +7398,8 @@ fn from_can_pattern_help<'a>( // TODO these don't match up in the uniqueness inference; when we remove // that, reinstate this assert! // - // dbg!(&env.subs.get_without_compacting(*field_var).content); - // dbg!(&env.subs.get_without_compacting(destruct.value.var).content); + // dbg!(&env.subs.get_content_without_compacting(*field_var)); + // dbg!(&env.subs.get_content_without_compacting(destruct.var).content); // debug_assert_eq!( // env.subs.get_root_key_without_compacting(*field_var), // env.subs.get_root_key_without_compacting(destruct.value.var) @@ -7490,7 +7493,7 @@ pub fn num_argument_to_int_or_float( var: Variable, known_to_be_float: bool, ) -> IntOrFloat { - match subs.get_without_compacting(var).content { + match subs.get_content_without_compacting(var){ Content::FlexVar(_) | Content::RigidVar(_) if known_to_be_float => IntOrFloat::BinaryFloatType(FloatPrecision::F64), Content::FlexVar(_) | Content::RigidVar(_) => IntOrFloat::SignedIntType(IntPrecision::I64), // We default (Num *) to I64 diff --git a/compiler/reporting/tests/helpers/mod.rs b/compiler/reporting/tests/helpers/mod.rs index c301471f9f..07f1745dec 100644 --- a/compiler/reporting/tests/helpers/mod.rs +++ b/compiler/reporting/tests/helpers/mod.rs @@ -36,7 +36,10 @@ pub fn infer_expr( }; let (solved, _) = solve::run(&env, problems, subs, constraint); - let content = solved.inner().get_without_compacting(expr_var).content; + let content = solved + .inner() + .get_content_without_compacting(expr_var) + .clone(); (content, solved.into_inner()) } diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 4bcea3bbf7..170b016ebf 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -528,11 +528,10 @@ fn solve( .get(next_rank) .iter() .filter(|var| { - let current = subs.get_without_compacting( - roc_types::subs::Variable::clone(var), - ); + let current_rank = + subs.get_rank(roc_types::subs::Variable::clone(var)); - current.rank.into_usize() > next_rank.into_usize() + current_rank.into_usize() > next_rank.into_usize() }) .collect::>(); @@ -561,8 +560,7 @@ fn solve( let failing: Vec<_> = rigid_vars .iter() .filter(|&var| { - !subs.redundant(*var) - && subs.get_without_compacting(*var).rank != Rank::NONE + !subs.redundant(*var) && subs.get_rank(*var) != Rank::NONE }) .collect(); diff --git a/compiler/solve/tests/solve_expr.rs b/compiler/solve/tests/solve_expr.rs index 319c8a3dad..4f93fe449e 100644 --- a/compiler/solve/tests/solve_expr.rs +++ b/compiler/solve/tests/solve_expr.rs @@ -115,7 +115,7 @@ mod solve_expr { let content = { debug_assert!(exposed_to_host.len() == 1); let (_symbol, variable) = exposed_to_host.into_iter().next().unwrap(); - subs.get(variable).content + subs.get_content_without_compacting(variable) }; let actual_str = content_to_string(content, subs, home, &interns); diff --git a/compiler/types/src/pretty_print.rs b/compiler/types/src/pretty_print.rs index 89decf198c..6efa9bcaca 100644 --- a/compiler/types/src/pretty_print.rs +++ b/compiler/types/src/pretty_print.rs @@ -77,11 +77,11 @@ fn find_names_needed( use crate::subs::FlatType::*; while let Some((recursive, _chain)) = subs.occurs(variable) { - let content = subs.get_without_compacting(recursive).content; + let rec_var = subs.fresh_unnamed_flex_var(); + let content = subs.get_content_without_compacting(recursive); + match content { Content::Structure(FlatType::TagUnion(tags, ext_var)) => { - let rec_var = subs.fresh_unnamed_flex_var(); - let mut new_tags = MutMap::default(); for (label, args) in tags { @@ -94,7 +94,7 @@ fn find_names_needed( new_tags.insert(label.clone(), new_args); } - let flat_type = FlatType::RecursiveTagUnion(rec_var, new_tags, ext_var); + let flat_type = FlatType::RecursiveTagUnion(rec_var, new_tags, *ext_var); subs.set_content(recursive, Content::Structure(flat_type)); } _ => panic!( @@ -104,7 +104,7 @@ fn find_names_needed( } } - match subs.get_without_compacting(variable).content { + match &subs.get_content_without_compacting(variable).clone() { RecursionVar { opt_name: None, .. } | FlexVar(None) => { let root = subs.get_root_key_without_compacting(variable); @@ -133,24 +133,24 @@ fn find_names_needed( // User-defined names are already taken. // We must not accidentally generate names that collide with them! - names_taken.insert(name); + names_taken.insert(name.clone()); } RigidVar(name) => { // User-defined names are already taken. // We must not accidentally generate names that collide with them! - names_taken.insert(name); + names_taken.insert(name.clone()); } Structure(Apply(_, args)) => { for var in args { - find_names_needed(var, subs, roots, root_appearances, names_taken); + find_names_needed(*var, subs, roots, root_appearances, names_taken); } } Structure(Func(arg_vars, _closure_var, ret_var)) => { for var in arg_vars { - find_names_needed(var, subs, roots, root_appearances, names_taken); + find_names_needed(*var, subs, roots, root_appearances, names_taken); } - find_names_needed(ret_var, subs, roots, root_appearances, names_taken); + find_names_needed(*ret_var, subs, roots, root_appearances, names_taken); } Structure(Record(fields, ext_var)) => { let mut sorted_fields: Vec<_> = fields.iter().collect(); @@ -167,7 +167,7 @@ fn find_names_needed( ); } - find_names_needed(ext_var, subs, roots, root_appearances, names_taken); + find_names_needed(*ext_var, subs, roots, root_appearances, names_taken); } Structure(TagUnion(tags, ext_var)) => { let mut sorted_tags: Vec<_> = tags.iter().collect(); @@ -177,10 +177,10 @@ fn find_names_needed( find_names_needed(*var, subs, roots, root_appearances, names_taken); } - find_names_needed(ext_var, subs, roots, root_appearances, names_taken); + find_names_needed(*ext_var, subs, roots, root_appearances, names_taken); } Structure(FunctionOrTagUnion(_, _, ext_var)) => { - find_names_needed(ext_var, subs, roots, root_appearances, names_taken); + find_names_needed(*ext_var, subs, roots, root_appearances, names_taken); } Structure(RecursiveTagUnion(rec_var, tags, ext_var)) => { let mut sorted_tags: Vec<_> = tags.iter().collect(); @@ -190,12 +190,12 @@ fn find_names_needed( find_names_needed(*var, subs, roots, root_appearances, names_taken); } - find_names_needed(ext_var, subs, roots, root_appearances, names_taken); - find_names_needed(rec_var, subs, roots, root_appearances, names_taken); + find_names_needed(*ext_var, subs, roots, root_appearances, names_taken); + find_names_needed(*rec_var, subs, roots, root_appearances, names_taken); } Alias(_symbol, args, _, _actual) => { for (_, var) in args { - find_names_needed(var, subs, roots, root_appearances, names_taken); + find_names_needed(*var, subs, roots, root_appearances, names_taken); } // TODO should we also look in the actual variable? // find_names_needed(_actual, subs, roots, root_appearances, names_taken); @@ -240,22 +240,22 @@ fn name_root( fn set_root_name(root: Variable, name: Lowercase, subs: &mut Subs) { use crate::subs::Content::*; - let mut descriptor = subs.get_without_compacting(root); + let old_content = subs.get_content_without_compacting(root); - match descriptor.content { + match old_content { FlexVar(None) => { - descriptor.content = FlexVar(Some(name)); - subs.set(root, descriptor); + let content = FlexVar(Some(name)); + subs.set_content(root, content); } RecursionVar { opt_name: None, structure, } => { - descriptor.content = RecursionVar { - structure, + let content = RecursionVar { + structure: *structure, opt_name: Some(name), }; - subs.set(root, descriptor); + subs.set_content(root, content); } RecursionVar { opt_name: Some(_existing), @@ -270,7 +270,7 @@ fn set_root_name(root: Variable, name: Lowercase, subs: &mut Subs) { } pub fn content_to_string( - content: Content, + content: &Content, subs: &Subs, home: ModuleId, interns: &Interns, @@ -283,7 +283,7 @@ pub fn content_to_string( buf } -fn write_content(env: &Env, content: Content, subs: &Subs, buf: &mut String, parens: Parens) { +fn write_content(env: &Env, content: &Content, subs: &Subs, buf: &mut String, parens: Parens) { use crate::subs::Content::*; match content { @@ -298,14 +298,14 @@ fn write_content(env: &Env, content: Content, subs: &Subs, buf: &mut String, par Alias(symbol, args, _, _actual) => { let write_parens = parens == Parens::InTypeParam && !args.is_empty(); - match symbol { + match *symbol { Symbol::NUM_NUM => { debug_assert_eq!(args.len(), 1); let (_, arg_var) = args .get(0) .expect("Num was not applied to a type argument!"); - let content = subs.get_without_compacting(*arg_var).content; + let content = subs.get_content_without_compacting(*arg_var); match &content { Alias(nested, _, _, _) => match *nested { @@ -326,13 +326,13 @@ fn write_content(env: &Env, content: Content, subs: &Subs, buf: &mut String, par } _ => write_parens!(write_parens, buf, { - write_symbol(env, symbol, buf); + write_symbol(env, *symbol, buf); for (_, var) in args { buf.push(' '); write_content( env, - subs.get_without_compacting(var).content, + subs.get_content_without_compacting(*var), subs, buf, Parens::InTypeParam, @@ -342,7 +342,7 @@ fn write_content(env: &Env, content: Content, subs: &Subs, buf: &mut String, par // useful for debugging if false { buf.push_str("[[ but really "); - let content = subs.get_without_compacting(_actual).content; + let content = subs.get_content_without_compacting(*_actual); write_content(env, content, subs, buf, parens); buf.push_str("]]"); } @@ -353,19 +353,74 @@ fn write_content(env: &Env, content: Content, subs: &Subs, buf: &mut String, par } } -fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String, parens: Parens) { +fn write_sorted_tags<'a>( + env: &Env, + subs: &'a Subs, + buf: &mut String, + tags: &MutMap>, + ext_var: Variable, +) -> Result<(), (Variable, &'a Content)> { + // Sort the fields so they always end up in the same order. + let mut sorted_fields = Vec::with_capacity(tags.len()); + + for (label, vars) in tags { + sorted_fields.push((label, vars)); + } + + // If the `ext` contains tags, merge them into the list of tags. + // this can occur when inferring mutually recursive tags + let mut from_ext = Default::default(); + let ext_content = chase_ext_tag_union(subs, ext_var, &mut from_ext); + + for (tag_name, arguments) in from_ext.iter() { + sorted_fields.push((tag_name, arguments)); + } + + let interns = &env.interns; + let home = env.home; + + sorted_fields + .sort_by(|(a, _), (b, _)| a.as_string(interns, home).cmp(&b.as_string(interns, home))); + + let mut any_written_yet = false; + + for (label, vars) in sorted_fields { + if any_written_yet { + buf.push_str(", "); + } else { + any_written_yet = true; + } + + buf.push_str(&label.as_string(interns, home)); + + for var in vars { + buf.push(' '); + write_content( + env, + subs.get_content_without_compacting(*var), + subs, + buf, + Parens::InTypeParam, + ); + } + } + + ext_content +} + +fn write_flat_type(env: &Env, flat_type: &FlatType, subs: &Subs, buf: &mut String, parens: Parens) { use crate::subs::FlatType::*; match flat_type { - Apply(symbol, args) => write_apply(env, symbol, args, subs, buf, parens), + Apply(symbol, args) => write_apply(env, *symbol, args, subs, buf, parens), EmptyRecord => buf.push_str(EMPTY_RECORD), EmptyTagUnion => buf.push_str(EMPTY_TAG_UNION), - Func(args, _closure, ret) => write_fn(env, args, ret, subs, buf, parens), + Func(args, _closure, ret) => write_fn(env, args, *ret, subs, buf, parens), Record(fields, ext_var) => { use crate::types::{gather_fields, RecordStructure}; // If the `ext` has concrete fields (e.g. { foo : I64}{ bar : Bool }), merge them - let RecordStructure { fields, ext } = gather_fields(subs, fields, ext_var); + let RecordStructure { fields, ext } = gather_fields(subs, fields, *ext_var); let ext_var = ext; if fields.is_empty() { @@ -408,7 +463,7 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String write_content( env, - subs.get_without_compacting(var).content, + subs.get_content_without_compacting(var), subs, buf, Parens::Unnecessary, @@ -418,7 +473,7 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String buf.push_str(" }"); } - match subs.get_without_compacting(ext_var).content { + match subs.get_content_without_compacting(ext_var) { Content::Structure(EmptyRecord) => { // This is a closed record. We're done! } @@ -433,50 +488,9 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String } } TagUnion(tags, ext_var) => { - let interns = &env.interns; - let home = env.home; - buf.push_str("[ "); - // Sort the fields so they always end up in the same order. - let mut sorted_fields = Vec::with_capacity(tags.len()); - - for (label, vars) in tags { - sorted_fields.push((label.clone(), vars)); - } - - // If the `ext` contains tags, merge them into the list of tags. - // this can occur when inferring mutually recursive tags - let ext_content = chase_ext_tag_union(subs, ext_var, &mut sorted_fields); - - sorted_fields.sort_by(|(a, _), (b, _)| { - a.clone() - .as_string(interns, home) - .cmp(&b.as_string(interns, home)) - }); - - let mut any_written_yet = false; - - for (label, vars) in sorted_fields { - if any_written_yet { - buf.push_str(", "); - } else { - any_written_yet = true; - } - - buf.push_str(&label.as_string(interns, home)); - - for var in vars { - buf.push(' '); - write_content( - env, - subs.get_without_compacting(var).content, - subs, - buf, - Parens::InTypeParam, - ); - } - } + let ext_content = write_sorted_tags(env, subs, buf, tags, *ext_var); buf.push_str(" ]"); @@ -491,17 +505,14 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String } FunctionOrTagUnion(tag_name, _, ext_var) => { - let interns = &env.interns; - let home = env.home; - buf.push_str("[ "); - buf.push_str(&tag_name.as_string(interns, home)); + let mut tags = MutMap::default(); + tags.insert(tag_name.clone(), vec![]); + let ext_content = write_sorted_tags(env, subs, buf, &tags, *ext_var); buf.push_str(" ]"); - let mut sorted_fields = vec![(tag_name, vec![])]; - let ext_content = chase_ext_tag_union(subs, ext_var, &mut sorted_fields); if let Err((_, content)) = ext_content { // This is an open tag union, so print the variable // right after the ']' @@ -513,45 +524,9 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String } RecursiveTagUnion(rec_var, tags, ext_var) => { - let interns = &env.interns; - let home = env.home; - buf.push_str("[ "); - // Sort the fields so they always end up in the same order. - let mut sorted_fields = Vec::with_capacity(tags.len()); - - for (label, vars) in tags { - sorted_fields.push((label.clone(), vars)); - } - - // If the `ext` contains tags, merge them into the list of tags. - // this can occur when inferring mutually recursive tags - let ext_content = chase_ext_tag_union(subs, ext_var, &mut sorted_fields); - - sorted_fields.sort_by(|(a, _), (b, _)| a.cmp(b)); - - let mut any_written_yet = false; - - for (label, vars) in sorted_fields { - if any_written_yet { - buf.push_str(", "); - } else { - any_written_yet = true; - } - buf.push_str(&label.as_string(interns, home)); - - for var in vars { - buf.push(' '); - write_content( - env, - subs.get_without_compacting(var).content, - subs, - buf, - Parens::InTypeParam, - ); - } - } + let ext_content = write_sorted_tags(env, subs, buf, tags, *ext_var); buf.push_str(" ]"); @@ -567,7 +542,7 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String buf.push_str(" as "); write_content( env, - subs.get_without_compacting(rec_var).content, + subs.get_content_without_compacting(*rec_var), subs, buf, parens, @@ -579,11 +554,11 @@ fn write_flat_type(env: &Env, flat_type: FlatType, subs: &Subs, buf: &mut String } } -pub fn chase_ext_tag_union( - subs: &Subs, +pub fn chase_ext_tag_union<'a>( + subs: &'a Subs, var: Variable, fields: &mut Vec<(TagName, Vec)>, -) -> Result<(), (Variable, Content)> { +) -> Result<(), (Variable, &'a Content)> { use FlatType::*; match subs.get_content_without_compacting(var) { Content::Structure(EmptyTagUnion) => Ok(()), @@ -602,7 +577,7 @@ pub fn chase_ext_tag_union( } Content::Alias(_, _, _, var) => chase_ext_tag_union(subs, *var, fields), - content => Err((var, content.clone())), + content => Err((var, content)), } } @@ -614,25 +589,27 @@ pub fn chase_ext_record( use crate::subs::Content::*; use crate::subs::FlatType::*; - match subs.get_without_compacting(var).content { + match subs.get_content_without_compacting(var) { Structure(Record(sub_fields, sub_ext)) => { - fields.extend(sub_fields.into_iter()); + for (field_name, record_field) in sub_fields { + fields.insert(field_name.clone(), *record_field); + } - chase_ext_record(subs, sub_ext, fields) + chase_ext_record(subs, *sub_ext, fields) } Structure(EmptyRecord) => Ok(()), - Alias(_, _, _, var) => chase_ext_record(subs, var, fields), + Alias(_, _, _, var) => chase_ext_record(subs, *var, fields), - content => Err((var, content)), + content => Err((var, content.clone())), } } fn write_apply( env: &Env, symbol: Symbol, - args: Vec, + args: &[Variable], subs: &Subs, buf: &mut String, parens: Parens, @@ -646,10 +623,10 @@ fn write_apply( } Symbol::NUM_NUM => { let arg = args - .into_iter() + .iter() .next() .unwrap_or_else(|| panic!("Num did not have any type parameters somehow.")); - let arg_content = subs.get_without_compacting(arg).content; + let arg_content = subs.get_content_without_compacting(*arg); let mut arg_param = String::new(); let mut default_case = |subs, content| { @@ -690,7 +667,7 @@ fn write_apply( buf.push(' '); write_content( env, - subs.get_without_compacting(arg).content, + subs.get_content_without_compacting(*arg), subs, buf, Parens::InTypeParam, @@ -706,7 +683,7 @@ fn write_apply( fn write_fn( env: &Env, - args: Vec, + args: &[Variable], ret: Variable, subs: &Subs, buf: &mut String, @@ -728,7 +705,7 @@ fn write_fn( write_content( env, - subs.get_without_compacting(arg).content, + subs.get_content_without_compacting(*arg), subs, buf, Parens::InFn, @@ -738,7 +715,7 @@ fn write_fn( buf.push_str(" -> "); write_content( env, - subs.get_without_compacting(ret).content, + subs.get_content_without_compacting(ret), subs, buf, Parens::InFn, diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index 098329b826..a9d752e32c 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -990,7 +990,7 @@ impl Content { eprintln!( "{}", - crate::pretty_print::content_to_string(self.clone(), subs, home, &interns) + crate::pretty_print::content_to_string(&self, subs, home, &interns) ); self @@ -1031,7 +1031,7 @@ fn occurs( if seen.contains(&root_var) { Some((root_var, vec![])) } else { - match subs.get_without_compacting(root_var).content { + match subs.get_content_without_compacting(root_var) { FlexVar(_) | RigidVar(_) | RecursionVar { .. } | Error => None, Structure(flat_type) => { @@ -1042,14 +1042,14 @@ fn occurs( match flat_type { Apply(_, args) => short_circuit(subs, root_var, &new_seen, args.iter()), Func(arg_vars, closure_var, ret_var) => { - let it = once(&ret_var) - .chain(once(&closure_var)) + let it = once(ret_var) + .chain(once(closure_var)) .chain(arg_vars.iter()); short_circuit(subs, root_var, &new_seen, it) } Record(vars_by_field, ext_var) => { let it = - once(&ext_var).chain(vars_by_field.values().map(|field| match field { + once(ext_var).chain(vars_by_field.values().map(|field| match field { RecordField::Optional(var) => var, RecordField::Required(var) => var, RecordField::Demanded(var) => var, @@ -1057,22 +1057,22 @@ fn occurs( short_circuit(subs, root_var, &new_seen, it) } TagUnion(tags, ext_var) => { - let it = once(&ext_var).chain(tags.values().flatten()); + let it = once(ext_var).chain(tags.values().flatten()); short_circuit(subs, root_var, &new_seen, it) } FunctionOrTagUnion(_, _, ext_var) => { - let it = once(&ext_var); + let it = once(ext_var); short_circuit(subs, root_var, &new_seen, it) } RecursiveTagUnion(_rec_var, tags, ext_var) => { // TODO rec_var is excluded here, verify that this is correct - let it = once(&ext_var).chain(tags.values().flatten()); + let it = once(ext_var).chain(tags.values().flatten()); short_circuit(subs, root_var, &new_seen, it) } EmptyRecord | EmptyTagUnion | Erroneous(_) => None, } } - Alias(symbol, type_arguments, lambda_set_variables, _) => { + Alias(_, type_arguments, lambda_set_variables, _) => { let mut new_seen = seen.clone(); new_seen.insert(root_var); diff --git a/compiler/types/src/types.rs b/compiler/types/src/types.rs index d46e142684..eb6202db28 100644 --- a/compiler/types/src/types.rs +++ b/compiler/types/src/types.rs @@ -1,7 +1,7 @@ use crate::pretty_print::Parens; use crate::subs::{Subs, VarStore, Variable}; use inlinable_string::InlinableString; -use roc_collections::all::{union, ImMap, ImSet, Index, MutMap, MutSet, SendMap}; +use roc_collections::all::{ImMap, ImSet, Index, MutMap, MutSet, SendMap}; use roc_module::ident::{ForeignSymbol, Ident, Lowercase, TagName}; use roc_module::low_level::LowLevel; use roc_module::symbol::{Interns, ModuleId, Symbol}; @@ -1595,22 +1595,35 @@ pub fn name_type_var(letters_used: u32, taken: &mut MutSet) -> (Lower pub fn gather_fields( subs: &Subs, - fields: MutMap>, - var: Variable, + other_fields: &MutMap>, + mut var: Variable, ) -> RecordStructure { use crate::subs::Content::*; use crate::subs::FlatType::*; - match subs.get_without_compacting(var).content { - Structure(Record(sub_fields, sub_ext)) => { - gather_fields(subs, union(fields, &sub_fields), sub_ext) - } + let mut result = other_fields.clone(); - Alias(_, _, _, var) => { - // TODO according to elm/compiler: "TODO may be dropping useful alias info here" - gather_fields(subs, fields, var) - } + loop { + match subs.get_content_without_compacting(var) { + Structure(Record(sub_fields, sub_ext)) => { + for (lowercase, record_field) in sub_fields { + result.insert(lowercase.clone(), *record_field); + } - _ => RecordStructure { fields, ext: var }, + var = *sub_ext; + } + + Alias(_, _, _, actual_var) => { + // TODO according to elm/compiler: "TODO may be dropping useful alias info here" + var = *actual_var; + } + + _ => break, + } + } + + RecordStructure { + fields: result, + ext: var, } } diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index 9520bc5e5e..af17682d13 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -995,8 +995,8 @@ fn unify_flat_type( } (Record(fields1, ext1), Record(fields2, ext2)) => { - let rec1 = gather_fields(subs, fields1.clone(), *ext1); - let rec2 = gather_fields(subs, fields2.clone(), *ext2); + let rec1 = gather_fields(subs, fields1, *ext1); + let rec2 = gather_fields(subs, fields2, *ext2); unify_record(subs, pool, ctx, rec1, rec2) } diff --git a/editor/src/editor/mvc/ed_update.rs b/editor/src/editor/mvc/ed_update.rs index 23bba5bdbe..9fd9caee31 100644 --- a/editor/src/editor/mvc/ed_update.rs +++ b/editor/src/editor/mvc/ed_update.rs @@ -272,7 +272,7 @@ impl<'a> EdModel<'a> { let subs = solved.inner_mut(); - let content = subs.get(var).content; + let content = subs.get_content_without_compacting(var); PoolStr::new( &content_to_string(content, subs, self.module.env.home, self.interns), diff --git a/editor/src/lang/constrain.rs b/editor/src/lang/constrain.rs index 8350a2db28..0dcd47965f 100644 --- a/editor/src/lang/constrain.rs +++ b/editor/src/lang/constrain.rs @@ -57,7 +57,7 @@ pub fn constrain_expr<'a>( Expr2::Blank => True, Expr2::EmptyRecord => constrain_empty_record(expected, region), Expr2::Var(symbol) => Lookup(*symbol, expected, region), - Expr2::SmallInt { var, .. } => { + Expr2::SmallInt { var, .. } | Expr2::I128 { var, .. } | Expr2::U128 { var, .. } => { let mut flex_vars = BumpVec::with_capacity_in(1, arena); flex_vars.push(*var); @@ -910,7 +910,57 @@ pub fn constrain_expr<'a>( exists(arena, flex_vars, And(cons)) } - _ => todo!("implement constraints for {:?}", expr), + + Expr2::RunLowLevel { op, args, ret_var } => { + // This is a modified version of what we do for function calls. + + // The operation's return type + let ret_type = Type2::Variable(*ret_var); + + // This will be used in the occurs check + let mut vars = BumpVec::with_capacity_in(1 + args.len(), arena); + + vars.push(*ret_var); + + let mut arg_types = BumpVec::with_capacity_in(args.len(), arena); + let mut arg_cons = BumpVec::with_capacity_in(args.len(), arena); + + for (index, node_id) in args.iter_node_ids().enumerate() { + let (arg_var, arg_id) = env.pool.get(node_id); + + vars.push(*arg_var); + + let arg_type = Type2::Variable(*arg_var); + + let reason = Reason::LowLevelOpArg { + op: *op, + arg_index: Index::zero_based(index), + }; + let expected_arg = + Expected::ForReason(reason, arg_type.shallow_clone(), Region::zero()); + let arg = env.pool.get(*arg_id); + + let arg_con = constrain_expr(arena, env, arg, expected_arg, Region::zero()); + + arg_types.push(arg_type); + arg_cons.push(arg_con); + } + + let category = Category::LowLevelOpResult(*op); + + let mut and_constraints = BumpVec::with_capacity_in(2, arena); + + and_constraints.push(And(arg_cons)); + and_constraints.push(Eq(ret_type, expected, category, region)); + + exists(arena, vars, And(and_constraints)) + } + Expr2::RuntimeError() => True, + Expr2::Closure { .. } => todo!(), + Expr2::PrivateTag { .. } => todo!(), + Expr2::InvalidLookup(_) => todo!(), + Expr2::LetRec { .. } => todo!(), + Expr2::LetFunction { .. } => todo!(), } } diff --git a/editor/src/lang/solve.rs b/editor/src/lang/solve.rs index bee5d2a07a..a79ed35dfc 100644 --- a/editor/src/lang/solve.rs +++ b/editor/src/lang/solve.rs @@ -565,11 +565,10 @@ fn solve<'a>( .get(next_rank) .iter() .filter(|var| { - let current = subs.get_without_compacting( - roc_types::subs::Variable::clone(var), - ); + let current_rank = + subs.get_rank(roc_types::subs::Variable::clone(var)); - current.rank.into_usize() > next_rank.into_usize() + current_rank.into_usize() > next_rank.into_usize() }) .collect::>(); @@ -598,8 +597,7 @@ fn solve<'a>( let failing: Vec<_> = rigid_vars .iter() .filter(|&var| { - !subs.redundant(*var) - && subs.get_without_compacting(*var).rank != Rank::NONE + !subs.redundant(*var) && subs.get_rank(*var) != Rank::NONE }) .collect(); diff --git a/editor/tests/solve_expr2.rs b/editor/tests/solve_expr2.rs index 0850a6af18..f4f0c4c313 100644 --- a/editor/tests/solve_expr2.rs +++ b/editor/tests/solve_expr2.rs @@ -114,7 +114,7 @@ fn infer_eq(actual: &str, expected_str: &str) { let subs = solved.inner_mut(); - let content = subs.get(var).content; + let content = subs.get_content_without_compacting(var); let interns = Interns { module_ids: env.module_ids.clone(), @@ -328,3 +328,16 @@ fn constrain_update() { "{ name : Str }", ) } + +#[ignore = "TODO: implement builtins in the editor"] +#[test] +fn constrain_run_low_level() { + infer_eq( + indoc!( + r#" + List.map [ { name: "roc" }, { name: "bird" } ] .name + "# + ), + "List Str", + ) +}