diff --git a/compiler/can/src/annotation.rs b/compiler/can/src/annotation.rs index a78ed60cb9..5cd1b22164 100644 --- a/compiler/can/src/annotation.rs +++ b/compiler/can/src/annotation.rs @@ -357,7 +357,7 @@ fn can_annotation_help( actual: Box::new(actual), } } - None => Type::Apply(symbol, args), + None => Type::Apply(symbol, args, region), } } BoundVariable(v) => { @@ -377,7 +377,8 @@ fn can_annotation_help( As( loc_inner, _spaces, - AliasHeader { + alias_header + @ AliasHeader { name, vars: loc_vars, }, @@ -439,20 +440,43 @@ fn can_annotation_help( } } + let alias_args = vars.iter().map(|(_, v)| v.clone()).collect::>(); + let alias_actual = if let Type::TagUnion(tags, ext) = inner_type { let rec_var = var_store.fresh(); let mut new_tags = Vec::with_capacity(tags.len()); + let mut is_nested_datatype = false; for (tag_name, args) in tags { let mut new_args = Vec::with_capacity(args.len()); for arg in args { let mut new_arg = arg.clone(); - new_arg.substitute_alias(symbol, &Type::Variable(rec_var)); + let substitution_result = + new_arg.substitute_alias(symbol, &alias_args, &Type::Variable(rec_var)); + + if let Err(differing_recursion_region) = substitution_result { + env.problems + .push(roc_problem::can::Problem::NestedDatatype { + alias: symbol, + def_region: alias_header.region(), + differing_recursion_region, + }); + is_nested_datatype = true; + } + + // Either way, add the argument; not doing so would only result in more + // confusing error messages later on. new_args.push(new_arg); } new_tags.push((tag_name.clone(), new_args)); } - Type::RecursiveTagUnion(rec_var, new_tags, ext) + if is_nested_datatype { + // We don't have a way to represent nested data types; hence, we don't actually + // use the recursion var in them, and should avoid marking them as such. + Type::TagUnion(new_tags, ext) + } else { + Type::RecursiveTagUnion(rec_var, new_tags, ext) + } } else { inner_type }; diff --git a/compiler/can/src/def.rs b/compiler/can/src/def.rs index 43e7af2e95..42015839ec 100644 --- a/compiler/can/src/def.rs +++ b/compiler/can/src/def.rs @@ -277,7 +277,7 @@ pub fn canonicalize_defs<'a>( let mut can_vars: Vec> = Vec::with_capacity(vars.len()); let mut is_phantom = false; - for loc_lowercase in vars { + for loc_lowercase in vars.iter() { if let Some(var) = can_ann .introduced_variables .var_by_name(&loc_lowercase.value) @@ -303,10 +303,18 @@ pub fn canonicalize_defs<'a>( continue; } + let mut is_nested_datatype = false; if can_ann.typ.contains_symbol(symbol) { - make_tag_union_recursive( + let alias_args = can_vars + .iter() + .map(|l| (l.value.0.clone(), Type::Variable(l.value.1))) + .collect::>(); + let alias_region = + Region::across_all([name.region].iter().chain(vars.iter().map(|l| &l.region))); + + let made_recursive = make_tag_union_recursive( env, - symbol, + Loc::at(alias_region, (symbol, &alias_args)), name.region, vec![], &mut can_ann.typ, @@ -315,6 +323,13 @@ pub fn canonicalize_defs<'a>( // recursion errors after the sorted introductions are complete. &mut false, ); + + is_nested_datatype = made_recursive.is_err(); + } + + if is_nested_datatype { + // Bail out + continue; } scope.add_alias(symbol, name.region, can_vars.clone(), can_ann.typ.clone()); @@ -1624,9 +1639,16 @@ fn correct_mutual_recursive_type_alias<'a>( var_store, &mut ImSet::default(), ); - make_tag_union_recursive( + + let alias_args = &alias + .type_variables + .iter() + .map(|l| (l.value.0.clone(), Type::Variable(l.value.1))) + .collect::>(); + + let _made_recursive = make_tag_union_recursive( env, - *rec, + Loc::at(alias.header_region(), (*rec, &alias_args)), alias.region, others, &mut alias.typ, @@ -1640,25 +1662,71 @@ fn correct_mutual_recursive_type_alias<'a>( } } +/// Attempt to make a tag union recursive at the position of `recursive_alias`; for example, +/// +/// ```roc +/// [ Cons a (ConsList a), Nil ] as ConsList a +/// ``` +/// +/// can be made recursive at the position "ConsList a" with a fresh recursive variable, say r1: +/// +/// ```roc +/// [ Cons a r1, Nil ] as r1 +/// ``` +/// +/// Returns `Err` if the tag union is recursive, but there is no structure-preserving recursion +/// variable for it. This can happen when the type is a nested datatype, for example in either of +/// +/// ```roc +/// Nested a : [ Chain a (Nested (List a)), Term ] +/// DuoList a b : [ Cons a (DuoList b a), Nil ] +/// ``` +/// +/// When `Err` is returned, a problem will be added to `env`. fn make_tag_union_recursive<'a>( env: &mut Env<'a>, - symbol: Symbol, + recursive_alias: Loc<(Symbol, &[(Lowercase, Type)])>, region: Region, others: Vec, typ: &mut Type, var_store: &mut VarStore, can_report_error: &mut bool, -) { +) -> Result<(), ()> { + let Loc { + value: (symbol, args), + region: alias_region, + } = recursive_alias; + let vars = args.iter().map(|(_, t)| t.clone()).collect::>(); match typ { Type::TagUnion(tags, ext) => { let rec_var = var_store.fresh(); - *typ = Type::RecursiveTagUnion(rec_var, tags.to_vec(), ext.clone()); - typ.substitute_alias(symbol, &Type::Variable(rec_var)); + let mut pending_typ = Type::RecursiveTagUnion(rec_var, tags.to_vec(), ext.clone()); + let substitution_result = + pending_typ.substitute_alias(symbol, &vars, &Type::Variable(rec_var)); + match substitution_result { + Ok(()) => { + // We can substitute the alias presence for the variable exactly. + *typ = pending_typ; + Ok(()) + } + Err(differing_recursion_region) => { + env.problems.push(Problem::NestedDatatype { + alias: symbol, + def_region: alias_region, + differing_recursion_region, + }); + Err(()) + } + } } - Type::RecursiveTagUnion(_, _, _) => {} - Type::Alias { actual, .. } => make_tag_union_recursive( + Type::RecursiveTagUnion(_, _, _) => Ok(()), + Type::Alias { + actual, + type_arguments, + .. + } => make_tag_union_recursive( env, - symbol, + Loc::at_zero((symbol, &type_arguments)), region, others, actual, @@ -1676,6 +1744,7 @@ fn make_tag_union_recursive<'a>( let problem = Problem::CyclicAlias(symbol, region, others); env.problems.push(problem); } + Ok(()) } } } diff --git a/compiler/constrain/src/builtins.rs b/compiler/constrain/src/builtins.rs index c47a1e5796..8fe6bfeece 100644 --- a/compiler/constrain/src/builtins.rs +++ b/compiler/constrain/src/builtins.rs @@ -71,7 +71,7 @@ pub fn exists(flex_vars: Vec, constraint: Constraint) -> Constraint { #[inline(always)] pub fn builtin_type(symbol: Symbol, args: Vec) -> Type { - Type::Apply(symbol, args) + Type::Apply(symbol, args, Region::zero()) } #[inline(always)] diff --git a/compiler/parse/src/ast.rs b/compiler/parse/src/ast.rs index 879d62e98d..3fc72c72fd 100644 --- a/compiler/parse/src/ast.rs +++ b/compiler/parse/src/ast.rs @@ -232,6 +232,16 @@ pub struct AliasHeader<'a> { pub vars: &'a [Loc>], } +impl<'a> AliasHeader<'a> { + pub fn region(&self) -> Region { + Region::across_all( + [self.name.region] + .iter() + .chain(self.vars.iter().map(|v| &v.region)), + ) + } +} + #[derive(Debug, Clone, Copy, PartialEq)] pub enum Def<'a> { // TODO in canonicalization, validate the pattern; only certain patterns diff --git a/compiler/problem/src/can.rs b/compiler/problem/src/can.rs index d9e910becb..e2dce07a9a 100644 --- a/compiler/problem/src/can.rs +++ b/compiler/problem/src/can.rs @@ -78,6 +78,11 @@ pub enum Problem { InvalidInterpolation(Region), InvalidHexadecimal(Region), InvalidUnicodeCodePt(Region), + NestedDatatype { + alias: Symbol, + def_region: Region, + differing_recursion_region: Region, + }, } #[derive(Clone, Debug, PartialEq)] diff --git a/compiler/types/src/solved_types.rs b/compiler/types/src/solved_types.rs index 92282b0e55..2174ace2f0 100644 --- a/compiler/types/src/solved_types.rs +++ b/compiler/types/src/solved_types.rs @@ -85,7 +85,7 @@ impl SolvedType { match typ { EmptyRec => SolvedType::EmptyRecord, EmptyTagUnion => SolvedType::EmptyTagUnion, - Apply(symbol, types) => { + Apply(symbol, types, _) => { let mut solved_types = Vec::with_capacity(types.len()); for typ in types { @@ -454,7 +454,7 @@ pub fn to_type( new_args.push(to_type(arg, free_vars, var_store)); } - Type::Apply(*symbol, new_args) + Type::Apply(*symbol, new_args, Region::zero()) } Rigid(lowercase) => { if let Some(var) = free_vars.named_vars.get(lowercase) { diff --git a/compiler/types/src/types.rs b/compiler/types/src/types.rs index 9087ac80d6..648f1089f6 100644 --- a/compiler/types/src/types.rs +++ b/compiler/types/src/types.rs @@ -94,13 +94,18 @@ impl RecordField { } } - pub fn substitute_alias(&mut self, rep_symbol: Symbol, actual: &Type) { + pub fn substitute_alias( + &mut self, + rep_symbol: Symbol, + rep_args: &[Type], + actual: &Type, + ) -> Result<(), Region> { use RecordField::*; match self { - Optional(typ) => typ.substitute_alias(rep_symbol, actual), - Required(typ) => typ.substitute_alias(rep_symbol, actual), - Demanded(typ) => typ.substitute_alias(rep_symbol, actual), + Optional(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), + Required(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), + Demanded(typ) => typ.substitute_alias(rep_symbol, rep_args, actual), } } @@ -189,7 +194,7 @@ pub enum Type { }, RecursiveTagUnion(Variable, Vec<(TagName, Vec)>, Box), /// Applying a type to some arguments (e.g. Dict.Dict String Int) - Apply(Symbol, Vec), + Apply(Symbol, Vec, Region), Variable(Variable), /// A type error, which will code gen to a runtime error Erroneous(Problem), @@ -220,7 +225,7 @@ impl fmt::Debug for Type { } Type::Variable(var) => write!(f, "<{:?}>", var), - Type::Apply(symbol, args) => { + Type::Apply(symbol, args, _) => { write!(f, "({:?}", symbol)?; for arg in args { @@ -539,7 +544,7 @@ impl Type { } actual_type.substitute(substitutions); } - Apply(_, args) => { + Apply(_, args, _) => { for arg in args { arg.substitute(substitutions); } @@ -549,62 +554,69 @@ impl Type { } } - // swap Apply with Alias if their module and tag match - pub fn substitute_alias(&mut self, rep_symbol: Symbol, actual: &Type) { + /// Swap Apply(rep_symbol, rep_args) with `actual`. Returns `Err` if there is an + /// `Apply(rep_symbol, _)`, but the args don't match. + pub fn substitute_alias( + &mut self, + rep_symbol: Symbol, + rep_args: &[Type], + actual: &Type, + ) -> Result<(), Region> { use Type::*; match self { Function(args, closure, ret) => { for arg in args { - arg.substitute_alias(rep_symbol, actual); + arg.substitute_alias(rep_symbol, rep_args, actual)?; } - closure.substitute_alias(rep_symbol, actual); - ret.substitute_alias(rep_symbol, actual); - } - FunctionOrTagUnion(_, _, ext) => { - ext.substitute_alias(rep_symbol, actual); + closure.substitute_alias(rep_symbol, rep_args, actual)?; + ret.substitute_alias(rep_symbol, rep_args, actual) } + FunctionOrTagUnion(_, _, ext) => ext.substitute_alias(rep_symbol, rep_args, actual), RecursiveTagUnion(_, tags, ext) | TagUnion(tags, ext) => { for (_, args) in tags { for x in args { - x.substitute_alias(rep_symbol, actual); + x.substitute_alias(rep_symbol, rep_args, actual)?; } } - ext.substitute_alias(rep_symbol, actual); + ext.substitute_alias(rep_symbol, rep_args, actual) } Record(fields, ext) => { for (_, x) in fields.iter_mut() { - x.substitute_alias(rep_symbol, actual); + x.substitute_alias(rep_symbol, rep_args, actual)?; } - ext.substitute_alias(rep_symbol, actual); + ext.substitute_alias(rep_symbol, rep_args, actual) } Alias { actual: alias_actual, .. - } => { - alias_actual.substitute_alias(rep_symbol, actual); - } + } => alias_actual.substitute_alias(rep_symbol, rep_args, actual), HostExposedAlias { actual: actual_type, .. - } => { - actual_type.substitute_alias(rep_symbol, actual); - } - Apply(symbol, _) if *symbol == rep_symbol => { - *self = actual.clone(); + } => actual_type.substitute_alias(rep_symbol, rep_args, actual), + Apply(symbol, args, region) if *symbol == rep_symbol => { + if args.len() == rep_args.len() + && args.iter().zip(rep_args.iter()).all(|(t1, t2)| t1 == t2) + { + *self = actual.clone(); - if let Apply(_, args) = self { - for arg in args { - arg.substitute_alias(rep_symbol, actual); + if let Apply(_, args, _) = self { + for arg in args { + arg.substitute_alias(rep_symbol, rep_args, actual)?; + } } + return Ok(()); } + Err(*region) } - Apply(_, args) => { + Apply(_, args, _) => { for arg in args { - arg.substitute_alias(rep_symbol, actual); + arg.substitute_alias(rep_symbol, rep_args, actual)?; } + Ok(()) } - EmptyRec | EmptyTagUnion | ClosureTag { .. } | Erroneous(_) | Variable(_) => {} + EmptyRec | EmptyTagUnion | ClosureTag { .. } | Erroneous(_) | Variable(_) => Ok(()), } } @@ -639,8 +651,8 @@ impl Type { HostExposedAlias { name, actual, .. } => { name == &rep_symbol || actual.contains_symbol(rep_symbol) } - Apply(symbol, _) if *symbol == rep_symbol => true, - Apply(_, args) => args.iter().any(|arg| arg.contains_symbol(rep_symbol)), + Apply(symbol, _, _) if *symbol == rep_symbol => true, + Apply(_, args, _) => args.iter().any(|arg| arg.contains_symbol(rep_symbol)), EmptyRec | EmptyTagUnion | ClosureTag { .. } | Erroneous(_) | Variable(_) => false, } } @@ -676,7 +688,7 @@ impl Type { .. } => actual_type.contains_variable(rep_variable), HostExposedAlias { actual, .. } => actual.contains_variable(rep_variable), - Apply(_, args) => args.iter().any(|arg| arg.contains_variable(rep_variable)), + Apply(_, args, _) => args.iter().any(|arg| arg.contains_variable(rep_variable)), EmptyRec | EmptyTagUnion | Erroneous(_) => false, } } @@ -753,7 +765,7 @@ impl Type { actual_type.instantiate_aliases(region, aliases, var_store, introduced); } - Apply(symbol, args) => { + Apply(symbol, args, _) => { if let Some(alias) = aliases.get(symbol) { if args.len() != alias.type_variables.len() { *self = Type::Erroneous(Problem::BadTypeArguments { @@ -882,7 +894,7 @@ fn symbols_help(tipe: &Type, accum: &mut ImSet) { accum.insert(*name); symbols_help(actual, accum); } - Apply(symbol, args) => { + Apply(symbol, args, _) => { accum.insert(*symbol); args.iter().for_each(|arg| symbols_help(arg, accum)); } @@ -967,7 +979,7 @@ fn variables_help(tipe: &Type, accum: &mut ImSet) { } variables_help(actual, accum); } - Apply(_, args) => { + Apply(_, args, _) => { for x in args { variables_help(x, accum); } @@ -1071,7 +1083,7 @@ fn variables_help_detailed(tipe: &Type, accum: &mut VariableDetail) { } variables_help_detailed(actual, accum); } - Apply(_, args) => { + Apply(_, args, _) => { for x in args { variables_help_detailed(x, accum); } @@ -1241,6 +1253,16 @@ pub struct Alias { pub typ: Type, } +impl Alias { + pub fn header_region(&self) -> Region { + Region::across_all( + [self.region] + .iter() + .chain(self.type_variables.iter().map(|tv| &tv.region)), + ) + } +} + #[derive(PartialEq, Eq, Debug, Clone, Hash)] pub enum Problem { CanonicalizationProblem, diff --git a/reporting/src/error/canonicalize.rs b/reporting/src/error/canonicalize.rs index dd64d2e612..ed0c9d97e1 100644 --- a/reporting/src/error/canonicalize.rs +++ b/reporting/src/error/canonicalize.rs @@ -24,6 +24,7 @@ const CIRCULAR_DEF: &str = "CIRCULAR DEFINITION"; const DUPLICATE_NAME: &str = "DUPLICATE NAME"; const VALUE_NOT_EXPOSED: &str = "NOT EXPOSED"; const MODULE_NOT_IMPORTED: &str = "MODULE NOT IMPORTED"; +const NESTED_DATATYPE: &str = "NESTED DATATYPE"; pub fn can_problem<'b>( alloc: &'b RocDocAllocator<'b>, @@ -437,6 +438,34 @@ pub fn can_problem<'b>( title = answer.1.to_string(); severity = Severity::RuntimeError; } + Problem::NestedDatatype { + alias, + def_region, + differing_recursion_region, + } => { + doc = alloc.stack(vec![ + alloc.concat(vec![ + alloc.symbol_unqualified(alias), + alloc.reflow(" is a nested datatype. Here is one recursive usage of it:"), + ]), + alloc.region(lines.convert_region(differing_recursion_region)), + alloc.concat(vec![ + alloc.reflow("But recursive usages of "), + alloc.symbol_unqualified(alias), + alloc.reflow(" must match its definition:"), + ]), + alloc.region(lines.convert_region(def_region)), + alloc.reflow("Nested datatypes are not supported in Roc."), + alloc.concat(vec![ + alloc.hint("Consider rewriting the definition of "), + alloc.symbol_unqualified(alias), + alloc.text(" to use the recursive type with the same arguments."), + ]), + ]); + + title = NESTED_DATATYPE.to_string(); + severity = Severity::RuntimeError; + } }; Report { diff --git a/reporting/tests/test_reporting.rs b/reporting/tests/test_reporting.rs index 4be661a40e..f736655cea 100644 --- a/reporting/tests/test_reporting.rs +++ b/reporting/tests/test_reporting.rs @@ -7238,4 +7238,70 @@ I need all branches in an `if` to have the same type! ), ) } + + #[test] + fn nested_datatype() { + report_problem_as( + indoc!( + r#" + Nested a : [ Chain a (Nested (List a)), Term ] + + s : Nested Str + + s + "# + ), + indoc!( + r#" + ── NESTED DATATYPE ───────────────────────────────────────────────────────────── + + `Nested` is a nested datatype. Here is one recursive usage of it: + + 1│ Nested a : [ Chain a (Nested (List a)), Term ] + ^^^^^^^^^^^^^^^ + + But recursive usages of `Nested` must match its definition: + + 1│ Nested a : [ Chain a (Nested (List a)), Term ] + ^^^^^^^^ + + Nested datatypes are not supported in Roc. + + Hint: Consider rewriting the definition of `Nested` to use the recursive type with the same arguments. + "# + ), + ) + } + + #[test] + fn nested_datatype_inline() { + report_problem_as( + indoc!( + r#" + f : {} -> [ Chain a (Nested (List a)), Term ] as Nested a + + f + "# + ), + indoc!( + r#" + ── NESTED DATATYPE ───────────────────────────────────────────────────────────── + + `Nested` is a nested datatype. Here is one recursive usage of it: + + 1│ f : {} -> [ Chain a (Nested (List a)), Term ] as Nested a + ^^^^^^^^^^^^^^^ + + But recursive usages of `Nested` must match its definition: + + 1│ f : {} -> [ Chain a (Nested (List a)), Term ] as Nested a + ^^^^^^^^ + + Nested datatypes are not supported in Roc. + + Hint: Consider rewriting the definition of `Nested` to use the recursive type with the same arguments. + "# + ), + ) + } }