Ensure that when jumping to a branch, all pattern symbols are loaded

If we are jumping to a target branch, it is necessary that the target
branch has all required pattern symbols loaded in it. Usually this is
already the case, but there is an exception with guarded patterns.
Guarded patterns have their pattern symbols loaded only right before the
guard is evaluated, which happens at some point further along the decision
tree. As such, when a guarded pattern jumps to its target destination,
it should pass the loaded pattern symbols as arguments to the target
joinpoint, which in turn declares one parameter for each of them.
This commit is contained in:
Ayaz Hafiz 2023-03-24 18:49:46 -05:00
parent 393250db92
commit ecad660e7f
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -1,7 +1,7 @@
use crate::borrow::Ownership; use crate::borrow::Ownership;
use crate::ir::{ use crate::ir::{
build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType, DestructType,
ListIndex, Literal, Param, Pattern, Procs, Stmt, Env, Expr, JoinPointId, ListIndex, Literal, Param, Pattern, Procs, Stmt,
}; };
use crate::layout::{ use crate::layout::{
Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType, Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType,
@ -57,6 +57,10 @@ impl<'a> Guard<'a> {
fn is_none(&self) -> bool { fn is_none(&self) -> bool {
self == &Guard::NoGuard self == &Guard::NoGuard
} }
fn is_some(&self) -> bool {
!self.is_none()
}
} }
type Edge<'a> = (GuardedTest<'a>, DecisionTree<'a>); type Edge<'a> = (GuardedTest<'a>, DecisionTree<'a>);
@ -82,10 +86,12 @@ enum GuardedTest<'a> {
/// body /// body
stmt: Stmt<'a>, stmt: Stmt<'a>,
}, },
// e.g. `<pattern> -> ...`
TestNotGuarded { TestNotGuarded {
test: Test<'a>, test: Test<'a>,
}, },
Placeholder, // e.g. `_ -> ...` or `x -> ...`
PlaceholderWithGuard,
} }
#[derive(Clone, Copy, Debug, PartialEq, Hash)] #[derive(Clone, Copy, Debug, PartialEq, Hash)]
@ -196,7 +202,7 @@ impl<'a> Hash for GuardedTest<'a> {
state.write_u8(0); state.write_u8(0);
test.hash(state); test.hash(state);
} }
GuardedTest::Placeholder => { GuardedTest::PlaceholderWithGuard => {
state.write_u8(2); state.write_u8(2);
} }
} }
@ -264,6 +270,7 @@ fn to_decision_tree<'a>(
let path = pick_path(&branches).clone(); let path = pick_path(&branches).clone();
let bs = branches.clone(); let bs = branches.clone();
let (edges, fallback) = gather_edges(interner, branches, &path); let (edges, fallback) = gather_edges(interner, branches, &path);
let mut decision_edges: Vec<_> = edges let mut decision_edges: Vec<_> = edges
@ -308,7 +315,7 @@ fn break_out_guard<'a>(
) -> DecisionTree<'a> { ) -> DecisionTree<'a> {
match edges match edges
.iter() .iter()
.position(|(t, _)| matches!(t, GuardedTest::Placeholder)) .position(|(t, _)| matches!(t, GuardedTest::PlaceholderWithGuard))
{ {
None => DecisionTree::Decision { None => DecisionTree::Decision {
path, path,
@ -347,7 +354,7 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
.all(|t| matches!(t, GuardedTest::TestNotGuarded { .. })); .all(|t| matches!(t, GuardedTest::TestNotGuarded { .. }));
match tests.last().unwrap() { match tests.last().unwrap() {
GuardedTest::Placeholder => false, GuardedTest::PlaceholderWithGuard => false,
GuardedTest::GuardedNoTest { .. } => false, GuardedTest::GuardedNoTest { .. } => false,
GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length), GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length),
} }
@ -687,7 +694,7 @@ fn test_at_path<'a>(
if let Guard::Guard { .. } = &branch.guard { if let Guard::Guard { .. } = &branch.guard {
// no tests for this pattern remain, but we cannot discard it yet // no tests for this pattern remain, but we cannot discard it yet
// because it has a guard! // because it has a guard!
Some(GuardedTest::Placeholder) Some(GuardedTest::PlaceholderWithGuard)
} else { } else {
None None
} }
@ -709,10 +716,33 @@ fn edges_for<'a>(
// if we test for a guard, skip all branches until one that has a guard // if we test for a guard, skip all branches until one that has a guard
let it = match test { let it = match test {
GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => { GuardedTest::GuardedNoTest { .. } => {
let index = branches let index = branches
.iter() .iter()
.position(|b| !b.guard.is_none()) .position(|b| b.guard.is_some())
.expect("if testing for a guard, one branch must have a guard");
branches[index..].iter()
}
GuardedTest::PlaceholderWithGuard => {
// Skip all branches until we hit the one with a placeholder and a guard.
let index = branches
.iter()
.position(|b| {
if b.guard.is_none() {
return false;
}
let (_, pattern) = b
.patterns
.iter()
.find(|(branch_path, _)| branch_path == path)
.expect(
"if testing for a placeholder with guard, must find a branch matching the path",
);
test_for_pattern(pattern).is_none()
})
.expect("if testing for a guard, one branch must have a guard"); .expect("if testing for a guard, one branch must have a guard");
branches[index..].iter() branches[index..].iter()
@ -741,7 +771,7 @@ fn to_relevant_branch<'a>(
found_pattern: pattern, found_pattern: pattern,
end, end,
} => match guarded_test { } => match guarded_test {
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => {
// if there is no test, the pattern should not require any // if there is no test, the pattern should not require any
debug_assert!( debug_assert!(
matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,), matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,),
@ -1332,7 +1362,7 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
relevant_tests.len() + (if !fallbacks { 0 } else { 1 }) relevant_tests.len() + (if !fallbacks { 0 } else { 1 })
} }
#[derive(Clone, Debug, PartialEq)] #[derive(Debug, PartialEq)]
enum Decider<'a, T> { enum Decider<'a, T> {
Leaf(T), Leaf(T),
Guarded { Guarded {
@ -1364,6 +1394,17 @@ enum Choice<'a> {
type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, InLayout<'a>, Expr<'a>)>; type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, InLayout<'a>, Expr<'a>)>;
struct JumpSpec<'a> {
target_index: u64,
id: JoinPointId,
/// Symbols, from the unpacked pattern, to add on when jumping to the target.
jump_pattern_param_symbols: &'a [Symbol],
// Used to construct the joinpoint
join_params: &'a [Param<'a>],
join_body: Stmt<'a>,
}
pub fn optimize_when<'a>( pub fn optimize_when<'a>(
env: &mut Env<'a, '_>, env: &mut Env<'a, '_>,
procs: &mut Procs<'a>, procs: &mut Procs<'a>,
@ -1373,11 +1414,11 @@ pub fn optimize_when<'a>(
ret_layout: InLayout<'a>, ret_layout: InLayout<'a>,
opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>, opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>,
) -> Stmt<'a> { ) -> Stmt<'a> {
let (patterns, _indexed_branches) = opt_branches let (patterns, indexed_branches): (_, Vec<_>) = opt_branches
.into_iter() .into_iter()
.enumerate() .enumerate()
.map(|(index, (pattern, guard, branch))| { .map(|(index, (pattern, guard, branch))| {
let has_guard = !guard.is_none(); let has_guard = guard.is_some();
( (
(guard, pattern.clone(), index as u64), (guard, pattern.clone(), index as u64),
(index as u64, branch, pattern, has_guard), (index as u64, branch, pattern, has_guard),
@ -1385,8 +1426,6 @@ pub fn optimize_when<'a>(
}) })
.unzip(); .unzip();
let indexed_branches: Vec<_> = _indexed_branches;
let decision_tree = compile(&layout_cache.interner, patterns); let decision_tree = compile(&layout_cache.interner, patterns);
let decider = tree_to_decider(decision_tree); let decider = tree_to_decider(decision_tree);
@ -1397,19 +1436,95 @@ pub fn optimize_when<'a>(
let mut choices = MutMap::default(); let mut choices = MutMap::default();
let mut jumps = Vec::new(); let mut jumps = Vec::new();
for (index, mut branch, pattern, has_guard) in indexed_branches.into_iter() { for (target, mut branch, pattern, has_guard) in indexed_branches.into_iter() {
// bind the fields referenced in the pattern. For guards this happens separately, so let should_inline = {
// the pattern variables are defined when evaluating the guard. let target_counts = &target_counts;
if !has_guard { match target_counts.get(target as usize) {
branch = None => unreachable!(
crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, branch); "this should never happen: {:?} not in {:?}",
target, target_counts
),
Some(count) => *count == 1,
}
};
let join_params: &'a [Param<'a>];
let jump_pattern_param_symbols: &'a [Symbol];
match (has_guard, should_inline) {
(false, _) => {
// Bind the fields referenced in the pattern.
branch = crate::ir::store_pattern(
env,
procs,
layout_cache,
&pattern,
cond_symbol,
branch,
);
join_params = env.arena.alloc([]);
jump_pattern_param_symbols = env.arena.alloc([]);
}
(true, true) => {
// Nothing more to do - the patterns will be bound when the guard is evaluated in
// `decide_to_branching`.
join_params = env.arena.alloc([]);
jump_pattern_param_symbols = env.arena.alloc([]);
}
(true, false) => {
// The patterns will be bound when the guard is evaluated, and then we need to get
// them back into the joinpoint here.
//
// So, figure out what symbols the pattern binds, and update the joinpoint
// parameter to take each symbol. Then, when the joinpoint is called, the unpacked
// symbols will be filled in.
//
// Since the joinpoint's parameters will be fresh symbols, the join body also needs
// updating.
let pattern_bindings = pattern.collect_symbols(cond_layout);
let mut parameters_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
let mut pattern_symbols_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
for &(pattern_symbol, layout) in pattern_bindings.iter() {
let param_symbol = env.unique_symbol();
parameters_buf.push(Param {
symbol: param_symbol,
layout,
ownership: Ownership::Owned,
});
pattern_symbols_buf.push(pattern_symbol);
}
join_params = parameters_buf.into_bump_slice();
jump_pattern_param_symbols = pattern_symbols_buf.into_bump_slice();
let substitutions = pattern_bindings
.iter()
.zip(join_params.iter())
.map(|((pat, _), param)| (*pat, param.symbol))
.collect();
substitute_in_exprs_many(env.arena, &mut branch, substitutions);
}
} }
let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch); let ((branch_index, choice), opt_jump) = if should_inline {
((target, Choice::Inline(branch)), None)
} else {
((target, Choice::Jump(target)), Some((target, branch)))
};
if let Some((index, body)) = opt_jump { if let Some((target_index, body)) = opt_jump {
let id = JoinPointId(env.unique_symbol()); let id = JoinPointId(env.unique_symbol());
jumps.push((index, id, body)); jumps.push(JumpSpec {
target_index,
id,
jump_pattern_param_symbols,
join_params,
join_body: body,
});
} }
choices.insert(branch_index, choice); choices.insert(branch_index, choice);
@ -1428,11 +1543,18 @@ pub fn optimize_when<'a>(
&jumps, &jumps,
); );
for (_, id, body) in jumps.into_iter() { for JumpSpec {
target_index: _,
id,
jump_pattern_param_symbols: _,
join_params,
join_body,
} in jumps.into_iter()
{
stmt = Stmt::Join { stmt = Stmt::Join {
id, id,
parameters: &[], parameters: join_params,
body: env.arena.alloc(body), body: env.arena.alloc(join_body),
remainder: env.arena.alloc(stmt), remainder: env.arena.alloc(stmt),
}; };
} }
@ -1929,7 +2051,7 @@ fn decide_to_branching<'a>(
cond_layout: InLayout<'a>, cond_layout: InLayout<'a>,
ret_layout: InLayout<'a>, ret_layout: InLayout<'a>,
decider: Decider<'a, Choice<'a>>, decider: Decider<'a, Choice<'a>>,
jumps: &[(u64, JoinPointId, Stmt<'a>)], jumps: &[JumpSpec<'a>],
) -> Stmt<'a> { ) -> Stmt<'a> {
use Choice::*; use Choice::*;
use Decider::*; use Decider::*;
@ -1939,10 +2061,10 @@ fn decide_to_branching<'a>(
match decider { match decider {
Leaf(Jump(label)) => { Leaf(Jump(label)) => {
let index = jumps let index = jumps
.binary_search_by_key(&label, |r| r.0) .binary_search_by_key(&label, |r| r.target_index)
.expect("jump not in list of jumps"); .expect("jump not in list of jumps");
Stmt::Jump(jumps[index].1, &[]) Stmt::Jump(jumps[index].id, jumps[index].jump_pattern_param_symbols)
} }
Leaf(Inline(expr)) => expr, Leaf(Inline(expr)) => expr,
Guarded { Guarded {
@ -1997,8 +2119,8 @@ fn decide_to_branching<'a>(
let join = Stmt::Join { let join = Stmt::Join {
id, id,
parameters: arena.alloc([param]), parameters: arena.alloc([param]),
remainder: arena.alloc(stmt),
body: arena.alloc(decide), body: arena.alloc(decide),
remainder: arena.alloc(stmt),
}; };
crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, join) crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, join)
@ -2282,15 +2404,17 @@ fn sort_edge_tests_by_priority(edges: &mut [Edge<'_>]) {
edges.sort_by(|(t1, _), (t2, _)| match (t1, t2) { edges.sort_by(|(t1, _), (t2, _)| match (t1, t2) {
// Guarded takes priority // Guarded takes priority
(GuardedNoTest { .. }, GuardedNoTest { .. }) => Equal, (GuardedNoTest { .. }, GuardedNoTest { .. }) => Equal,
(GuardedNoTest { .. }, TestNotGuarded { .. }) | (GuardedNoTest { .. }, Placeholder) => Less, (GuardedNoTest { .. }, TestNotGuarded { .. })
| (GuardedNoTest { .. }, PlaceholderWithGuard) => Less,
// Interesting case: what test do we pick? // Interesting case: what test do we pick?
(TestNotGuarded { test: t1 }, TestNotGuarded { test: t2 }) => order_tests(t1, t2), (TestNotGuarded { test: t1 }, TestNotGuarded { test: t2 }) => order_tests(t1, t2),
// Otherwise we are between guarded and fall-backs // Otherwise we are between guarded and fall-backs
(TestNotGuarded { .. }, GuardedNoTest { .. }) => Greater, (TestNotGuarded { .. }, GuardedNoTest { .. }) => Greater,
(TestNotGuarded { .. }, Placeholder) => Less, (TestNotGuarded { .. }, PlaceholderWithGuard) => Less,
// Placeholder is always last // Placeholder is always last
(Placeholder, Placeholder) => Equal, (PlaceholderWithGuard, PlaceholderWithGuard) => Equal,
(Placeholder, GuardedNoTest { .. }) | (Placeholder, TestNotGuarded { .. }) => Greater, (PlaceholderWithGuard, GuardedNoTest { .. })
| (PlaceholderWithGuard, TestNotGuarded { .. }) => Greater,
}); });
fn order_tests(t1: &Test, t2: &Test) -> Ordering { fn order_tests(t1: &Test, t2: &Test) -> Ordering {
@ -2452,7 +2576,7 @@ fn fanout_decider_help<'a>(
guarded_test: GuardedTest<'a>, guarded_test: GuardedTest<'a>,
) -> (Test<'a>, Decider<'a, u64>) { ) -> (Test<'a>, Decider<'a, u64>) {
match guarded_test { match guarded_test {
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => {
unreachable!("this would not end up in a switch") unreachable!("this would not end up in a switch")
} }
GuardedTest::TestNotGuarded { test } => { GuardedTest::TestNotGuarded { test } => {
@ -2478,7 +2602,7 @@ fn chain_decider<'a>(
stmt, stmt,
pattern, pattern,
success, success,
failure: failure.clone(), failure,
} }
} }
GuardedTest::TestNotGuarded { test } => { GuardedTest::TestNotGuarded { test } => {
@ -2489,7 +2613,7 @@ fn chain_decider<'a>(
} }
} }
GuardedTest::Placeholder => { GuardedTest::PlaceholderWithGuard => {
// ? // ?
tree_to_decider(success_tree) tree_to_decider(success_tree)
} }
@ -2572,22 +2696,6 @@ fn count_targets(targets: &mut bumpalo::collections::Vec<u64>, initial: &Decider
} }
} }
#[allow(clippy::type_complexity)]
fn create_choices<'a>(
target_counts: &bumpalo::collections::Vec<'a, u64>,
target: u64,
branch: Stmt<'a>,
) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) {
match target_counts.get(target as usize) {
None => unreachable!(
"this should never happen: {:?} not in {:?}",
target, target_counts
),
Some(1) => ((target, Choice::Inline(branch)), None),
Some(_) => ((target, Choice::Jump(target)), Some((target, branch))),
}
}
fn insert_choices<'a>( fn insert_choices<'a>(
choice_dict: &MutMap<u64, Choice<'a>>, choice_dict: &MutMap<u64, Choice<'a>>,
decider: Decider<'a, u64>, decider: Decider<'a, u64>,