diff --git a/src/term/transform/linearize_matches.rs b/src/term/transform/linearize_matches.rs index df83d148..7e40664c 100644 --- a/src/term/transform/linearize_matches.rs +++ b/src/term/transform/linearize_matches.rs @@ -209,7 +209,9 @@ fn fixed_and_linearized_terms(used_in_arg: HashSet, bind_terms: Vec) /// Get which binds are fixed because they are in the dependency graph /// of a free var or of a var used in the match arg. -fn binds_fixed_by_dependency(mut fixed_binds: HashSet, bind_terms: &[Term]) -> HashSet { +fn binds_fixed_by_dependency(used_in_arg: HashSet, bind_terms: &[Term]) -> HashSet { + let mut fixed_binds = used_in_arg; + // Find the use dependencies of each bind let mut binds = vec![]; let mut dependency_digraph = HashMap::new(); @@ -275,7 +277,42 @@ fn binds_fixed_by_dependency(mut fixed_binds: HashSet, bind_terms: &[Term] to_visit.extend(deps); } } - used_component + + // Mark lambdas that come before a fixed lambda as also fixed + let mut fixed_start = false; + let mut fixed_lams = HashSet::new(); + for term in bind_terms.iter().rev() { + if let Term::Lam { pat, .. } = term { + if pat.binds().flatten().any(|p| used_component.contains(p)) { + fixed_start = true; + } + if fixed_start { + for bind in pat.binds().flatten() { + fixed_lams.insert(bind.clone()); + } + } + } + } + + let mut fixed_binds = used_component; + + // Mark binds that depend on fixed lambdas as also fixed. + let mut visited = HashSet::new(); + let mut to_visit = fixed_lams.iter().collect::>(); + while let Some(node) = to_visit.pop() { + if visited.contains(node) { + continue; + } + fixed_binds.insert(node.clone()); + visited.insert(node); + + // Add these dependencies to be checked (if it's not a free var in the match arg) + if let Some(deps) = dependency_graph.get(node) { + to_visit.extend(deps); + } + } + + fixed_binds } /* Linearize all used vars */ diff --git a/tests/golden_tests/simplify_matches/complex_with_case.hvm b/tests/golden_tests/simplify_matches/complex_with_case.hvm new file mode 100644 index 00000000..b926852f --- /dev/null +++ b/tests/golden_tests/simplify_matches/complex_with_case.hvm @@ -0,0 +1,11 @@ +data Tree = (Node lt rt rd ld) | (Leaf val) + +(map) = + λarg1 λarg2 use tree = arg2; + use f = arg1; + match tree with f { + Node: (Node (map f tree.lt) (map f tree.rt) (map f tree.rd) (map f tree.ld)); + Leaf: (Leaf (f tree.val)); + } + +main = map \ No newline at end of file diff --git a/tests/snapshots/simplify_matches__complex_with_case.hvm.snap b/tests/snapshots/simplify_matches__complex_with_case.hvm.snap new file mode 100644 index 00000000..ca64e130 --- /dev/null +++ b/tests/snapshots/simplify_matches__complex_with_case.hvm.snap @@ -0,0 +1,11 @@ +--- +source: tests/golden_tests.rs +input_file: tests/golden_tests/simplify_matches/complex_with_case.hvm +--- +(map) = λa λb (match b { Node c d e f: λg (Node (map g c) (map g d) (map g e) (map g f)); Leaf h: λi (Leaf (i h)); } a) + +(main) = map + +(Node) = λa λb λc λd λe λf (e a b c d) + +(Leaf) = λa λb λc (c a)