Merge pull request #5215 from roc-lang/fix-non-nullable-unwrapped-recursive-lset

Fix compilation problems with recursive lambda sets
This commit is contained in:
Folkert de Vries 2023-03-27 20:00:16 +02:00 committed by GitHub
commit 53851f5738
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 382 additions and 23 deletions

View File

@ -35,7 +35,10 @@ pub fn pretty_print_declarations(c: &Ctx, declarations: &Declarations) -> String
DeclarationTag::Expectation => todo!(),
DeclarationTag::ExpectationFx => todo!(),
DeclarationTag::Destructure(_) => todo!(),
DeclarationTag::MutualRecursion { .. } => todo!(),
DeclarationTag::MutualRecursion { .. } => {
// the defs will be printed next
continue;
}
};
defs.push(def);
@ -124,8 +127,9 @@ fn toplevel_function<'a>(
.append(f.text("\\"))
.append(f.intersperse(args, f.text(", ")))
.append(f.text(" ->"))
.group()
.append(f.line())
.append(expr(c, EPrec::Free, f, body))
.append(expr(c, EPrec::Free, f, body).group())
.nest(2)
.group()
}

View File

@ -3506,6 +3506,7 @@ fn specialize_proc_help<'a>(
UnionLayout::NonRecursive(_)
| UnionLayout::Recursive(_)
| UnionLayout::NullableUnwrapped { .. }
| UnionLayout::NullableWrapped { .. }
));
debug_assert_eq!(field_layouts.len(), captured.len());

View File

@ -1621,11 +1621,27 @@ impl<'a> LambdaSet<'a> {
union_layout: union,
}
}
UnionLayout::NonNullableUnwrapped(_) => todo!("recursive closures"),
UnionLayout::NullableWrapped {
nullable_id: _,
other_tags: _,
} => todo!("recursive closures"),
} => {
let (index, (name, fields)) = self
.set
.iter()
.enumerate()
.find(|(_, (s, layouts))| comparator(*s, layouts))
.unwrap();
let closure_name = *name;
ClosureRepresentation::Union {
tag_id: index as TagIdIntType,
alphabetic_order_fields: fields,
closure_name,
union_layout: union,
}
}
UnionLayout::NonNullableUnwrapped(_) => internal_error!("I thought a non-nullable-unwrapped variant for a lambda set was impossible: how could such a lambda set be created without a base case?"),
}
}
Layout::Struct { .. } => {

View File

@ -8612,8 +8612,7 @@ mod solve_expr {
map = \simpleParser, transform -> apply \{} -[12]-> transform simpleParser
parseInput =
\{}->
parseInput = \{} ->
when [
map v1 \{} -[13]-> "",
map v2 \s -[14]-> s,

View File

@ -3911,6 +3911,34 @@ fn compose_recursive_lambda_set_productive_inferred() {
);
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn compose_recursive_lambda_set_productive_nullable_wrapped() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
compose = \forward -> \f, g ->
if forward
then \x -> g (f x)
else \x -> f (g x)
identity = \x -> x
exclame = \s -> "\(s)!"
whisper = \s -> "(\(s))"
main =
res: Str -> Str
res = List.walk [ exclame, whisper ] identity (compose Bool.false)
res "hello"
"#
),
RocStr::from("(hello)!"),
RocStr
)
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn local_binding_aliases_function() {
@ -4335,3 +4363,26 @@ fn when_guard_appears_multiple_times_in_compiled_decision_tree_issue_5176() {
u8
)
}
#[test]
#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))]
fn recursive_lambda_set_resolved_only_upon_specialization() {
assert_evals_to!(
indoc!(
r#"
app "test" provides [main] to "./platform"
factCPS = \n, cont ->
if n == 0 then
cont 1
else
factCPS (n - 1) \value -> cont (n * value)
main =
factCPS 5u64 \x -> x
"#
),
120,
u64
);
}

View File

@ -0,0 +1,176 @@
procedure Bool.2 ():
let Bool.23 : Int1 = true;
ret Bool.23;
procedure List.139 (List.140, List.141, List.138):
let List.513 : [<rnw><null>, C *self Int1, C *self Int1] = CallByName Test.6 List.140 List.141 List.138;
ret List.513;
procedure List.18 (List.136, List.137, List.138):
let List.494 : [<rnw><null>, C *self Int1, C *self Int1] = CallByName List.92 List.136 List.137 List.138;
ret List.494;
procedure List.6 (#Attr.2):
let List.511 : U64 = lowlevel ListLen #Attr.2;
ret List.511;
procedure List.66 (#Attr.2, #Attr.3):
let List.510 : Int1 = lowlevel ListGetUnsafe #Attr.2 #Attr.3;
ret List.510;
procedure List.80 (List.517, List.518, List.519, List.520, List.521):
joinpoint List.500 List.433 List.434 List.435 List.436 List.437:
let List.502 : Int1 = CallByName Num.22 List.436 List.437;
if List.502 then
let List.509 : Int1 = CallByName List.66 List.433 List.436;
let List.503 : [<rnw><null>, C *self Int1, C *self Int1] = CallByName List.139 List.434 List.509 List.435;
let List.506 : U64 = 1i64;
let List.505 : U64 = CallByName Num.19 List.436 List.506;
jump List.500 List.433 List.503 List.435 List.505 List.437;
else
ret List.434;
in
jump List.500 List.517 List.518 List.519 List.520 List.521;
procedure List.92 (List.430, List.431, List.432):
let List.498 : U64 = 0i64;
let List.499 : U64 = CallByName List.6 List.430;
let List.497 : [<rnw><null>, C *self Int1, C *self Int1] = CallByName List.80 List.430 List.431 List.432 List.498 List.499;
ret List.497;
procedure Num.19 (#Attr.2, #Attr.3):
let Num.275 : U64 = lowlevel NumAdd #Attr.2 #Attr.3;
ret Num.275;
procedure Num.22 (#Attr.2, #Attr.3):
let Num.276 : Int1 = lowlevel NumLt #Attr.2 #Attr.3;
ret Num.276;
procedure Str.3 (#Attr.2, #Attr.3):
let Str.268 : Str = lowlevel StrConcat #Attr.2 #Attr.3;
ret Str.268;
procedure Test.1 (Test.5):
ret Test.5;
procedure Test.11 (Test.53, Test.54):
joinpoint Test.27 Test.12 #Attr.12:
let Test.8 : Int1 = UnionAtIndex (Id 2) (Index 1) #Attr.12;
let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 2) (Index 0) #Attr.12;
inc Test.7;
dec #Attr.12;
joinpoint Test.31 Test.29:
let Test.30 : U8 = GetTagId Test.7;
switch Test.30:
case 0:
dec Test.7;
let Test.28 : Str = CallByName Test.2 Test.29;
dec Test.29;
ret Test.28;
case 1:
let Test.28 : Str = CallByName Test.9 Test.29 Test.7;
ret Test.28;
default:
jump Test.27 Test.29 Test.7;
in
switch Test.8:
case 0:
let Test.32 : Str = CallByName Test.3 Test.12;
jump Test.31 Test.32;
default:
let Test.32 : Str = CallByName Test.4 Test.12;
jump Test.31 Test.32;
in
jump Test.27 Test.53 Test.54;
procedure Test.2 (Test.13):
inc Test.13;
ret Test.13;
procedure Test.3 (Test.14):
let Test.48 : Str = "!";
let Test.47 : Str = CallByName Str.3 Test.14 Test.48;
dec Test.48;
ret Test.47;
procedure Test.4 (Test.15):
let Test.44 : Str = "(";
let Test.46 : Str = ")";
let Test.45 : Str = CallByName Str.3 Test.15 Test.46;
dec Test.46;
let Test.43 : Str = CallByName Str.3 Test.44 Test.45;
dec Test.45;
ret Test.43;
procedure Test.6 (Test.7, Test.8, Test.5):
if Test.5 then
let Test.33 : [<rnw><null>, C *self Int1, C *self Int1] = TagId(1) Test.7 Test.8;
ret Test.33;
else
let Test.26 : [<rnw><null>, C *self Int1, C *self Int1] = TagId(2) Test.7 Test.8;
ret Test.26;
procedure Test.9 (Test.10, #Attr.12):
let Test.8 : Int1 = UnionAtIndex (Id 1) (Index 1) #Attr.12;
let Test.7 : [<rnw><null>, C *self Int1, C *self Int1] = UnionAtIndex (Id 1) (Index 0) #Attr.12;
inc Test.7;
dec #Attr.12;
let Test.37 : U8 = GetTagId Test.7;
joinpoint Test.38 Test.36:
switch Test.8:
case 0:
let Test.35 : Str = CallByName Test.3 Test.36;
ret Test.35;
default:
let Test.35 : Str = CallByName Test.4 Test.36;
ret Test.35;
in
switch Test.37:
case 0:
dec Test.7;
let Test.39 : Str = CallByName Test.2 Test.10;
dec Test.10;
jump Test.38 Test.39;
case 1:
let Test.39 : Str = CallByName Test.9 Test.10 Test.7;
jump Test.38 Test.39;
default:
let Test.39 : Str = CallByName Test.11 Test.10 Test.7;
jump Test.38 Test.39;
procedure Test.0 ():
let Test.41 : Int1 = false;
let Test.42 : Int1 = true;
let Test.20 : List Int1 = Array [Test.41, Test.42];
let Test.21 : [<rnw><null>, C *self Int1, C *self Int1] = TagId(0) ;
let Test.23 : Int1 = CallByName Bool.2;
let Test.22 : Int1 = CallByName Test.1 Test.23;
let Test.16 : [<rnw><null>, C *self Int1, C *self Int1] = CallByName List.18 Test.20 Test.21 Test.22;
dec Test.20;
let Test.18 : Str = "hello";
let Test.19 : U8 = GetTagId Test.16;
switch Test.19:
case 0:
dec Test.16;
let Test.17 : Str = CallByName Test.2 Test.18;
dec Test.18;
ret Test.17;
case 1:
let Test.17 : Str = CallByName Test.9 Test.18 Test.16;
ret Test.17;
default:
let Test.17 : Str = CallByName Test.11 Test.18 Test.16;
ret Test.17;

View File

@ -0,0 +1,65 @@
procedure Bool.11 (#Attr.2, #Attr.3):
let Bool.23 : Int1 = lowlevel Eq #Attr.2 #Attr.3;
ret Bool.23;
procedure Num.20 (#Attr.2, #Attr.3):
let Num.276 : U8 = lowlevel NumSub #Attr.2 #Attr.3;
ret Num.276;
procedure Num.21 (#Attr.2, #Attr.3):
let Num.275 : U8 = lowlevel NumMul #Attr.2 #Attr.3;
ret Num.275;
procedure Test.1 (Test.26, Test.27):
joinpoint Test.11 Test.2 Test.3:
let Test.24 : U8 = 0i64;
let Test.20 : Int1 = CallByName Bool.11 Test.2 Test.24;
if Test.20 then
let Test.22 : U8 = 1i64;
let Test.23 : U8 = GetTagId Test.3;
switch Test.23:
case 0:
let Test.21 : U8 = CallByName Test.4 Test.22 Test.3;
ret Test.21;
default:
dec Test.3;
let Test.21 : U8 = CallByName Test.6 Test.22;
ret Test.21;
else
let Test.19 : U8 = 1i64;
let Test.13 : U8 = CallByName Num.20 Test.2 Test.19;
let Test.14 : [<rnu><null>, C *self U8] = TagId(0) Test.3 Test.2;
jump Test.11 Test.13 Test.14;
in
jump Test.11 Test.26 Test.27;
procedure Test.4 (Test.28, Test.29):
joinpoint Test.15 Test.5 #Attr.12:
let Test.2 : U8 = UnionAtIndex (Id 0) (Index 1) #Attr.12;
let Test.3 : [<rnu><null>, C *self U8] = UnionAtIndex (Id 0) (Index 0) #Attr.12;
inc Test.3;
dec #Attr.12;
let Test.17 : U8 = CallByName Num.21 Test.2 Test.5;
let Test.18 : U8 = GetTagId Test.3;
switch Test.18:
case 0:
jump Test.15 Test.17 Test.3;
default:
dec Test.3;
let Test.16 : U8 = CallByName Test.6 Test.17;
ret Test.16;
in
jump Test.15 Test.28 Test.29;
procedure Test.6 (Test.7):
ret Test.7;
procedure Test.0 ():
let Test.9 : U8 = 5i64;
let Test.10 : [<rnu><null>, C *self U8] = TagId(1) ;
let Test.8 : U8 = CallByName Test.1 Test.9 Test.10;
ret Test.8;

View File

@ -2804,3 +2804,44 @@ fn when_guard_appears_multiple_times_in_compiled_decision_tree_issue_5176() {
"#
)
}
#[mono_test]
fn recursive_lambda_set_resolved_only_upon_specialization() {
indoc!(
r#"
app "test" provides [main] to "./platform"
factCPS = \n, cont ->
if n == 0u8 then
cont 1u8
else
factCPS (n - 1) \value -> cont (n * value)
main =
factCPS 5 \x -> x
"#
)
}
#[mono_test]
fn compose_recursive_lambda_set_productive_nullable_wrapped() {
indoc!(
r#"
app "test" provides [main] to "./platform"
compose = \forward -> \f, g ->
if forward
then \x -> g (f x)
else \x -> f (g x)
identity = \x -> x
exclame = \s -> "\(s)!"
whisper = \s -> "(\(s))"
main =
res: Str -> Str
res = List.walk [ exclame, whisper ] identity (compose Bool.true)
res "hello"
"#
)
}

View File

@ -3778,30 +3778,38 @@ fn unify_recursion<M: MetaCollector>(
structure: Variable,
other: &Content,
) -> Outcome<M> {
if !matches!(other, RecursionVar { .. }) {
if env.seen_recursion_pair(ctx.first, ctx.second) {
return Default::default();
}
env.add_recursion_pair(ctx.first, ctx.second);
}
let outcome = match other {
RecursionVar {
opt_name: other_opt_name,
structure: _other_structure,
structure: other_structure,
} => {
// NOTE: structure and other_structure may not be unified yet, but will be
// we should not do that here, it would create an infinite loop!
// We haven't seen these two recursion vars yet, so go and unify their structures.
// We need to do this before we merge the two recursion vars, since the unification of
// the structures may be material.
let mut outcome = unify_pool(env, pool, structure, *other_structure, ctx.mode);
if !outcome.mismatches.is_empty() {
return outcome;
}
let name = (*opt_name).or(*other_opt_name);
merge(
let merge_outcome = merge(
env,
ctx,
RecursionVar {
opt_name: name,
structure,
},
)
);
outcome.union(merge_outcome);
outcome
}
Structure(_) => {
@ -3863,9 +3871,7 @@ fn unify_recursion<M: MetaCollector>(
Error => merge(env, ctx, Error),
};
if !matches!(other, RecursionVar { .. }) {
env.remove_recursion_pair(ctx.first, ctx.second);
}
outcome
}