Ensure that when jumping to a branch, all pattern symbols are loaded

If we are jumping to a target branch, it is necessary that the target
branch has all required pattern symbols loaded in it. Usually this is
already the case, but there is an exception with guarded patterns.
Guarded patterns have their patterns loaded only right before the guard
is evaluated, which happens at some point further along the decision
tree. As such, when a guarded pattern jumps to its target destination,
it should append the loaded patterns as parameters on the target
joinpoint.
This commit is contained in:
Ayaz Hafiz 2023-03-24 18:49:46 -05:00
parent 393250db92
commit ecad660e7f
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -1,7 +1,7 @@
use crate::borrow::Ownership;
use crate::ir::{
build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId,
ListIndex, Literal, Param, Pattern, Procs, Stmt,
build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType, DestructType,
Env, Expr, JoinPointId, ListIndex, Literal, Param, Pattern, Procs, Stmt,
};
use crate::layout::{
Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType,
@ -57,6 +57,10 @@ impl<'a> Guard<'a> {
fn is_none(&self) -> bool {
self == &Guard::NoGuard
}
fn is_some(&self) -> bool {
!self.is_none()
}
}
type Edge<'a> = (GuardedTest<'a>, DecisionTree<'a>);
@ -82,10 +86,12 @@ enum GuardedTest<'a> {
/// body
stmt: Stmt<'a>,
},
// e.g. `<pattern> -> ...`
TestNotGuarded {
test: Test<'a>,
},
Placeholder,
// e.g. `_ -> ...` or `x -> ...`
PlaceholderWithGuard,
}
#[derive(Clone, Copy, Debug, PartialEq, Hash)]
@ -196,7 +202,7 @@ impl<'a> Hash for GuardedTest<'a> {
state.write_u8(0);
test.hash(state);
}
GuardedTest::Placeholder => {
GuardedTest::PlaceholderWithGuard => {
state.write_u8(2);
}
}
@ -264,6 +270,7 @@ fn to_decision_tree<'a>(
let path = pick_path(&branches).clone();
let bs = branches.clone();
let (edges, fallback) = gather_edges(interner, branches, &path);
let mut decision_edges: Vec<_> = edges
@ -308,7 +315,7 @@ fn break_out_guard<'a>(
) -> DecisionTree<'a> {
match edges
.iter()
.position(|(t, _)| matches!(t, GuardedTest::Placeholder))
.position(|(t, _)| matches!(t, GuardedTest::PlaceholderWithGuard))
{
None => DecisionTree::Decision {
path,
@ -347,7 +354,7 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool {
.all(|t| matches!(t, GuardedTest::TestNotGuarded { .. }));
match tests.last().unwrap() {
GuardedTest::Placeholder => false,
GuardedTest::PlaceholderWithGuard => false,
GuardedTest::GuardedNoTest { .. } => false,
GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length),
}
@ -687,7 +694,7 @@ fn test_at_path<'a>(
if let Guard::Guard { .. } = &branch.guard {
// no tests for this pattern remain, but we cannot discard it yet
// because it has a guard!
Some(GuardedTest::Placeholder)
Some(GuardedTest::PlaceholderWithGuard)
} else {
None
}
@ -709,10 +716,33 @@ fn edges_for<'a>(
// if we test for a guard, skip all branches until one that has a guard
let it = match test {
GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => {
GuardedTest::GuardedNoTest { .. } => {
let index = branches
.iter()
.position(|b| !b.guard.is_none())
.position(|b| b.guard.is_some())
.expect("if testing for a guard, one branch must have a guard");
branches[index..].iter()
}
GuardedTest::PlaceholderWithGuard => {
// Skip all branches until we hit the one with a placeholder and a guard.
let index = branches
.iter()
.position(|b| {
if b.guard.is_none() {
return false;
}
let (_, pattern) = b
.patterns
.iter()
.find(|(branch_path, _)| branch_path == path)
.expect(
"if testing for a placeholder with guard, must find a branch matching the path",
);
test_for_pattern(pattern).is_none()
})
.expect("if testing for a guard, one branch must have a guard");
branches[index..].iter()
@ -741,7 +771,7 @@ fn to_relevant_branch<'a>(
found_pattern: pattern,
end,
} => match guarded_test {
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => {
// if there is no test, the pattern should not require any
debug_assert!(
matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,),
@ -1332,7 +1362,7 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz
relevant_tests.len() + (if !fallbacks { 0 } else { 1 })
}
#[derive(Clone, Debug, PartialEq)]
#[derive(Debug, PartialEq)]
enum Decider<'a, T> {
Leaf(T),
Guarded {
@ -1364,6 +1394,17 @@ enum Choice<'a> {
type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, InLayout<'a>, Expr<'a>)>;
struct JumpSpec<'a> {
target_index: u64,
id: JoinPointId,
/// Symbols, from the unpacked pattern, to add on when jumping to the target.
jump_pattern_param_symbols: &'a [Symbol],
// Used to construct the joinpoint
join_params: &'a [Param<'a>],
join_body: Stmt<'a>,
}
pub fn optimize_when<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
@ -1373,11 +1414,11 @@ pub fn optimize_when<'a>(
ret_layout: InLayout<'a>,
opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>,
) -> Stmt<'a> {
let (patterns, _indexed_branches) = opt_branches
let (patterns, indexed_branches): (_, Vec<_>) = opt_branches
.into_iter()
.enumerate()
.map(|(index, (pattern, guard, branch))| {
let has_guard = !guard.is_none();
let has_guard = guard.is_some();
(
(guard, pattern.clone(), index as u64),
(index as u64, branch, pattern, has_guard),
@ -1385,8 +1426,6 @@ pub fn optimize_when<'a>(
})
.unzip();
let indexed_branches: Vec<_> = _indexed_branches;
let decision_tree = compile(&layout_cache.interner, patterns);
let decider = tree_to_decider(decision_tree);
@ -1397,19 +1436,95 @@ pub fn optimize_when<'a>(
let mut choices = MutMap::default();
let mut jumps = Vec::new();
for (index, mut branch, pattern, has_guard) in indexed_branches.into_iter() {
// bind the fields referenced in the pattern. For guards this happens separately, so
// the pattern variables are defined when evaluating the guard.
if !has_guard {
branch =
crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, branch);
for (target, mut branch, pattern, has_guard) in indexed_branches.into_iter() {
let should_inline = {
let target_counts = &target_counts;
match target_counts.get(target as usize) {
None => unreachable!(
"this should never happen: {:?} not in {:?}",
target, target_counts
),
Some(count) => *count == 1,
}
};
let join_params: &'a [Param<'a>];
let jump_pattern_param_symbols: &'a [Symbol];
match (has_guard, should_inline) {
(false, _) => {
// Bind the fields referenced in the pattern.
branch = crate::ir::store_pattern(
env,
procs,
layout_cache,
&pattern,
cond_symbol,
branch,
);
join_params = env.arena.alloc([]);
jump_pattern_param_symbols = env.arena.alloc([]);
}
(true, true) => {
// Nothing more to do - the patterns will be bound when the guard is evaluated in
// `decide_to_branching`.
join_params = env.arena.alloc([]);
jump_pattern_param_symbols = env.arena.alloc([]);
}
(true, false) => {
// The patterns will be bound when the guard is evaluated, and then we need to get
// them back into the joinpoint here.
//
// So, figure out what symbols the pattern binds, and update the joinpoint
// parameter to take each symbol. Then, when the joinpoint is called, the unpacked
// symbols will be filled in.
//
// Since the joinpoint's parameters will be fresh symbols, the join body also needs
// updating.
let pattern_bindings = pattern.collect_symbols(cond_layout);
let mut parameters_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
let mut pattern_symbols_buf =
bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena);
for &(pattern_symbol, layout) in pattern_bindings.iter() {
let param_symbol = env.unique_symbol();
parameters_buf.push(Param {
symbol: param_symbol,
layout,
ownership: Ownership::Owned,
});
pattern_symbols_buf.push(pattern_symbol);
}
join_params = parameters_buf.into_bump_slice();
jump_pattern_param_symbols = pattern_symbols_buf.into_bump_slice();
let substitutions = pattern_bindings
.iter()
.zip(join_params.iter())
.map(|((pat, _), param)| (*pat, param.symbol))
.collect();
substitute_in_exprs_many(env.arena, &mut branch, substitutions);
}
}
let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch);
let ((branch_index, choice), opt_jump) = if should_inline {
((target, Choice::Inline(branch)), None)
} else {
((target, Choice::Jump(target)), Some((target, branch)))
};
if let Some((index, body)) = opt_jump {
if let Some((target_index, body)) = opt_jump {
let id = JoinPointId(env.unique_symbol());
jumps.push((index, id, body));
jumps.push(JumpSpec {
target_index,
id,
jump_pattern_param_symbols,
join_params,
join_body: body,
});
}
choices.insert(branch_index, choice);
@ -1428,11 +1543,18 @@ pub fn optimize_when<'a>(
&jumps,
);
for (_, id, body) in jumps.into_iter() {
for JumpSpec {
target_index: _,
id,
jump_pattern_param_symbols: _,
join_params,
join_body,
} in jumps.into_iter()
{
stmt = Stmt::Join {
id,
parameters: &[],
body: env.arena.alloc(body),
parameters: join_params,
body: env.arena.alloc(join_body),
remainder: env.arena.alloc(stmt),
};
}
@ -1929,7 +2051,7 @@ fn decide_to_branching<'a>(
cond_layout: InLayout<'a>,
ret_layout: InLayout<'a>,
decider: Decider<'a, Choice<'a>>,
jumps: &[(u64, JoinPointId, Stmt<'a>)],
jumps: &[JumpSpec<'a>],
) -> Stmt<'a> {
use Choice::*;
use Decider::*;
@ -1939,10 +2061,10 @@ fn decide_to_branching<'a>(
match decider {
Leaf(Jump(label)) => {
let index = jumps
.binary_search_by_key(&label, |r| r.0)
.binary_search_by_key(&label, |r| r.target_index)
.expect("jump not in list of jumps");
Stmt::Jump(jumps[index].1, &[])
Stmt::Jump(jumps[index].id, jumps[index].jump_pattern_param_symbols)
}
Leaf(Inline(expr)) => expr,
Guarded {
@ -1997,8 +2119,8 @@ fn decide_to_branching<'a>(
let join = Stmt::Join {
id,
parameters: arena.alloc([param]),
remainder: arena.alloc(stmt),
body: arena.alloc(decide),
remainder: arena.alloc(stmt),
};
crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, join)
@ -2282,15 +2404,17 @@ fn sort_edge_tests_by_priority(edges: &mut [Edge<'_>]) {
edges.sort_by(|(t1, _), (t2, _)| match (t1, t2) {
// Guarded takes priority
(GuardedNoTest { .. }, GuardedNoTest { .. }) => Equal,
(GuardedNoTest { .. }, TestNotGuarded { .. }) | (GuardedNoTest { .. }, Placeholder) => Less,
(GuardedNoTest { .. }, TestNotGuarded { .. })
| (GuardedNoTest { .. }, PlaceholderWithGuard) => Less,
// Interesting case: what test do we pick?
(TestNotGuarded { test: t1 }, TestNotGuarded { test: t2 }) => order_tests(t1, t2),
// Otherwise we are between guarded and fall-backs
(TestNotGuarded { .. }, GuardedNoTest { .. }) => Greater,
(TestNotGuarded { .. }, Placeholder) => Less,
(TestNotGuarded { .. }, PlaceholderWithGuard) => Less,
// Placeholder is always last
(Placeholder, Placeholder) => Equal,
(Placeholder, GuardedNoTest { .. }) | (Placeholder, TestNotGuarded { .. }) => Greater,
(PlaceholderWithGuard, PlaceholderWithGuard) => Equal,
(PlaceholderWithGuard, GuardedNoTest { .. })
| (PlaceholderWithGuard, TestNotGuarded { .. }) => Greater,
});
fn order_tests(t1: &Test, t2: &Test) -> Ordering {
@ -2452,7 +2576,7 @@ fn fanout_decider_help<'a>(
guarded_test: GuardedTest<'a>,
) -> (Test<'a>, Decider<'a, u64>) {
match guarded_test {
GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => {
GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => {
unreachable!("this would not end up in a switch")
}
GuardedTest::TestNotGuarded { test } => {
@ -2478,7 +2602,7 @@ fn chain_decider<'a>(
stmt,
pattern,
success,
failure: failure.clone(),
failure,
}
}
GuardedTest::TestNotGuarded { test } => {
@ -2489,7 +2613,7 @@ fn chain_decider<'a>(
}
}
GuardedTest::Placeholder => {
GuardedTest::PlaceholderWithGuard => {
// ?
tree_to_decider(success_tree)
}
@ -2572,22 +2696,6 @@ fn count_targets(targets: &mut bumpalo::collections::Vec<u64>, initial: &Decider
}
}
#[allow(clippy::type_complexity)]
fn create_choices<'a>(
target_counts: &bumpalo::collections::Vec<'a, u64>,
target: u64,
branch: Stmt<'a>,
) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) {
match target_counts.get(target as usize) {
None => unreachable!(
"this should never happen: {:?} not in {:?}",
target, target_counts
),
Some(1) => ((target, Choice::Inline(branch)), None),
Some(_) => ((target, Choice::Jump(target)), Some((target, branch))),
}
}
fn insert_choices<'a>(
choice_dict: &MutMap<u64, Choice<'a>>,
decider: Decider<'a, u64>,