diff --git a/crates/compiler/mono/src/decision_tree.rs b/crates/compiler/mono/src/decision_tree.rs index 444565e5c3..7dcf9bb984 100644 --- a/crates/compiler/mono/src/decision_tree.rs +++ b/crates/compiler/mono/src/decision_tree.rs @@ -1,7 +1,7 @@ use crate::borrow::Ownership; use crate::ir::{ - build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, - ListIndex, Literal, Param, Pattern, Procs, Stmt, + build_list_index_probe, substitute_in_exprs_many, BranchInfo, Call, CallType, DestructType, + Env, Expr, JoinPointId, ListIndex, Literal, Param, Pattern, Procs, Stmt, }; use crate::layout::{ Builtin, InLayout, Layout, LayoutCache, LayoutInterner, TLLayoutInterner, TagIdIntType, @@ -57,6 +57,10 @@ impl<'a> Guard<'a> { fn is_none(&self) -> bool { self == &Guard::NoGuard } + + fn is_some(&self) -> bool { + !self.is_none() + } } type Edge<'a> = (GuardedTest<'a>, DecisionTree<'a>); @@ -82,10 +86,12 @@ enum GuardedTest<'a> { /// body stmt: Stmt<'a>, }, + // e.g. ` -> ...` TestNotGuarded { test: Test<'a>, }, - Placeholder, + // e.g. `_ -> ...` or `x -> ...` + PlaceholderWithGuard, } #[derive(Clone, Copy, Debug, PartialEq, Hash)] @@ -196,7 +202,7 @@ impl<'a> Hash for GuardedTest<'a> { state.write_u8(0); test.hash(state); } - GuardedTest::Placeholder => { + GuardedTest::PlaceholderWithGuard => { state.write_u8(2); } } @@ -264,6 +270,7 @@ fn to_decision_tree<'a>( let path = pick_path(&branches).clone(); let bs = branches.clone(); + let (edges, fallback) = gather_edges(interner, branches, &path); let mut decision_edges: Vec<_> = edges @@ -308,7 +315,7 @@ fn break_out_guard<'a>( ) -> DecisionTree<'a> { match edges .iter() - .position(|(t, _)| matches!(t, GuardedTest::Placeholder)) + .position(|(t, _)| matches!(t, GuardedTest::PlaceholderWithGuard)) { None => DecisionTree::Decision { path, @@ -347,7 +354,7 @@ fn guarded_tests_are_complete(tests: &[GuardedTest]) -> bool { .all(|t| matches!(t, GuardedTest::TestNotGuarded { .. })); match tests.last().unwrap() { - GuardedTest::Placeholder => false, + GuardedTest::PlaceholderWithGuard => false, GuardedTest::GuardedNoTest { .. } => false, GuardedTest::TestNotGuarded { test } => no_guard && tests_are_complete_help(test, length), } @@ -687,7 +694,7 @@ fn test_at_path<'a>( if let Guard::Guard { .. } = &branch.guard { // no tests for this pattern remain, but we cannot discard it yet // because it has a guard! - Some(GuardedTest::Placeholder) + Some(GuardedTest::PlaceholderWithGuard) } else { None } @@ -709,10 +716,33 @@ fn edges_for<'a>( // if we test for a guard, skip all branches until one that has a guard let it = match test { - GuardedTest::GuardedNoTest { .. } | GuardedTest::Placeholder => { + GuardedTest::GuardedNoTest { .. } => { let index = branches .iter() - .position(|b| !b.guard.is_none()) + .position(|b| b.guard.is_some()) + .expect("if testing for a guard, one branch must have a guard"); + + branches[index..].iter() + } + GuardedTest::PlaceholderWithGuard => { + // Skip all branches until we hit the one with a placeholder and a guard. + let index = branches + .iter() + .position(|b| { + if b.guard.is_none() { + return false; + } + + let (_, pattern) = b + .patterns + .iter() + .find(|(branch_path, _)| branch_path == path) + .expect( + "if testing for a placeholder with guard, must find a branch matching the path", + ); + + test_for_pattern(pattern).is_none() + }) .expect("if testing for a guard, one branch must have a guard"); branches[index..].iter() @@ -741,7 +771,7 @@ fn to_relevant_branch<'a>( found_pattern: pattern, end, } => match guarded_test { - GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { + GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => { // if there is no test, the pattern should not require any debug_assert!( matches!(pattern, Pattern::Identifier(_) | Pattern::Underscore,), @@ -1332,7 +1362,7 @@ fn small_branching_factor(branches: &[Branch], path: &[PathInstruction]) -> usiz relevant_tests.len() + (if !fallbacks { 0 } else { 1 }) } -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] enum Decider<'a, T> { Leaf(T), Guarded { @@ -1364,6 +1394,17 @@ enum Choice<'a> { type StoresVec<'a> = bumpalo::collections::Vec<'a, (Symbol, InLayout<'a>, Expr<'a>)>; +struct JumpSpec<'a> { + target_index: u64, + id: JoinPointId, + /// Symbols, from the unpacked pattern, to add on when jumping to the target. + jump_pattern_param_symbols: &'a [Symbol], + + // Used to construct the joinpoint + join_params: &'a [Param<'a>], + join_body: Stmt<'a>, +} + pub fn optimize_when<'a>( env: &mut Env<'a, '_>, procs: &mut Procs<'a>, @@ -1373,11 +1414,11 @@ pub fn optimize_when<'a>( ret_layout: InLayout<'a>, opt_branches: bumpalo::collections::Vec<'a, (Pattern<'a>, Guard<'a>, Stmt<'a>)>, ) -> Stmt<'a> { - let (patterns, _indexed_branches) = opt_branches + let (patterns, indexed_branches): (_, Vec<_>) = opt_branches .into_iter() .enumerate() .map(|(index, (pattern, guard, branch))| { - let has_guard = !guard.is_none(); + let has_guard = guard.is_some(); ( (guard, pattern.clone(), index as u64), (index as u64, branch, pattern, has_guard), @@ -1385,8 +1426,6 @@ pub fn optimize_when<'a>( }) .unzip(); - let indexed_branches: Vec<_> = _indexed_branches; - let decision_tree = compile(&layout_cache.interner, patterns); let decider = tree_to_decider(decision_tree); @@ -1397,19 +1436,95 @@ pub fn optimize_when<'a>( let mut choices = MutMap::default(); let mut jumps = Vec::new(); - for (index, mut branch, pattern, has_guard) in indexed_branches.into_iter() { - // bind the fields referenced in the pattern. For guards this happens separately, so - // the pattern variables are defined when evaluating the guard. - if !has_guard { - branch = - crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, branch); + for (target, mut branch, pattern, has_guard) in indexed_branches.into_iter() { + let should_inline = { + let target_counts = &target_counts; + match target_counts.get(target as usize) { + None => unreachable!( + "this should never happen: {:?} not in {:?}", + target, target_counts + ), + Some(count) => *count == 1, + } + }; + + let join_params: &'a [Param<'a>]; + let jump_pattern_param_symbols: &'a [Symbol]; + match (has_guard, should_inline) { + (false, _) => { + // Bind the fields referenced in the pattern. + branch = crate::ir::store_pattern( + env, + procs, + layout_cache, + &pattern, + cond_symbol, + branch, + ); + + join_params = env.arena.alloc([]); + jump_pattern_param_symbols = env.arena.alloc([]); + } + (true, true) => { + // Nothing more to do - the patterns will be bound when the guard is evaluated in + // `decide_to_branching`. + join_params = env.arena.alloc([]); + jump_pattern_param_symbols = env.arena.alloc([]); + } + (true, false) => { + // The patterns will be bound when the guard is evaluated, and then we need to get + // them back into the joinpoint here. + // + // So, figure out what symbols the pattern binds, and update the joinpoint + // parameter to take each symbol. Then, when the joinpoint is called, the unpacked + // symbols will be filled in. + // + // Since the joinpoint's parameters will be fresh symbols, the join body also needs + // updating. + let pattern_bindings = pattern.collect_symbols(cond_layout); + + let mut parameters_buf = + bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena); + let mut pattern_symbols_buf = + bumpalo::collections::Vec::with_capacity_in(pattern_bindings.len(), env.arena); + + for &(pattern_symbol, layout) in pattern_bindings.iter() { + let param_symbol = env.unique_symbol(); + parameters_buf.push(Param { + symbol: param_symbol, + layout, + ownership: Ownership::Owned, + }); + pattern_symbols_buf.push(pattern_symbol); + } + + join_params = parameters_buf.into_bump_slice(); + jump_pattern_param_symbols = pattern_symbols_buf.into_bump_slice(); + + let substitutions = pattern_bindings + .iter() + .zip(join_params.iter()) + .map(|((pat, _), param)| (*pat, param.symbol)) + .collect(); + substitute_in_exprs_many(env.arena, &mut branch, substitutions); + } } - let ((branch_index, choice), opt_jump) = create_choices(&target_counts, index, branch); + let ((branch_index, choice), opt_jump) = if should_inline { + ((target, Choice::Inline(branch)), None) + } else { + ((target, Choice::Jump(target)), Some((target, branch))) + }; - if let Some((index, body)) = opt_jump { + if let Some((target_index, body)) = opt_jump { let id = JoinPointId(env.unique_symbol()); - jumps.push((index, id, body)); + jumps.push(JumpSpec { + target_index, + id, + jump_pattern_param_symbols, + join_params, + join_body: body, + }); } choices.insert(branch_index, choice); @@ -1428,11 +1543,18 @@ pub fn optimize_when<'a>( &jumps, ); - for (_, id, body) in jumps.into_iter() { + for JumpSpec { + target_index: _, + id, + jump_pattern_param_symbols: _, + join_params, + join_body, + } in jumps.into_iter() + { stmt = Stmt::Join { id, - parameters: &[], - body: env.arena.alloc(body), + parameters: join_params, + body: env.arena.alloc(join_body), remainder: env.arena.alloc(stmt), }; } @@ -1929,7 +2051,7 @@ fn decide_to_branching<'a>( cond_layout: InLayout<'a>, ret_layout: InLayout<'a>, decider: Decider<'a, Choice<'a>>, - jumps: &[(u64, JoinPointId, Stmt<'a>)], + jumps: &[JumpSpec<'a>], ) -> Stmt<'a> { use Choice::*; use Decider::*; @@ -1939,10 +2061,10 @@ fn decide_to_branching<'a>( match decider { Leaf(Jump(label)) => { let index = jumps - .binary_search_by_key(&label, |r| r.0) + .binary_search_by_key(&label, |r| r.target_index) .expect("jump not in list of jumps"); - Stmt::Jump(jumps[index].1, &[]) + Stmt::Jump(jumps[index].id, jumps[index].jump_pattern_param_symbols) } Leaf(Inline(expr)) => expr, Guarded { @@ -1997,8 +2119,8 @@ fn decide_to_branching<'a>( let join = Stmt::Join { id, parameters: arena.alloc([param]), - remainder: arena.alloc(stmt), body: arena.alloc(decide), + remainder: arena.alloc(stmt), }; crate::ir::store_pattern(env, procs, layout_cache, &pattern, cond_symbol, join) @@ -2282,15 +2404,17 @@ fn sort_edge_tests_by_priority(edges: &mut [Edge<'_>]) { edges.sort_by(|(t1, _), (t2, _)| match (t1, t2) { // Guarded takes priority (GuardedNoTest { .. }, GuardedNoTest { .. }) => Equal, - (GuardedNoTest { .. }, TestNotGuarded { .. }) | (GuardedNoTest { .. }, Placeholder) => Less, + (GuardedNoTest { .. }, TestNotGuarded { .. }) + | (GuardedNoTest { .. }, PlaceholderWithGuard) => Less, // Interesting case: what test do we pick? (TestNotGuarded { test: t1 }, TestNotGuarded { test: t2 }) => order_tests(t1, t2), // Otherwise we are between guarded and fall-backs (TestNotGuarded { .. }, GuardedNoTest { .. }) => Greater, - (TestNotGuarded { .. }, Placeholder) => Less, + (TestNotGuarded { .. }, PlaceholderWithGuard) => Less, // Placeholder is always last - (Placeholder, Placeholder) => Equal, - (Placeholder, GuardedNoTest { .. }) | (Placeholder, TestNotGuarded { .. }) => Greater, + (PlaceholderWithGuard, PlaceholderWithGuard) => Equal, + (PlaceholderWithGuard, GuardedNoTest { .. }) + | (PlaceholderWithGuard, TestNotGuarded { .. }) => Greater, }); fn order_tests(t1: &Test, t2: &Test) -> Ordering { @@ -2452,7 +2576,7 @@ fn fanout_decider_help<'a>( guarded_test: GuardedTest<'a>, ) -> (Test<'a>, Decider<'a, u64>) { match guarded_test { - GuardedTest::Placeholder | GuardedTest::GuardedNoTest { .. } => { + GuardedTest::PlaceholderWithGuard | GuardedTest::GuardedNoTest { .. } => { unreachable!("this would not end up in a switch") } GuardedTest::TestNotGuarded { test } => { @@ -2478,7 +2602,7 @@ fn chain_decider<'a>( stmt, pattern, success, - failure: failure.clone(), + failure, } } GuardedTest::TestNotGuarded { test } => { @@ -2489,7 +2613,7 @@ fn chain_decider<'a>( } } - GuardedTest::Placeholder => { + GuardedTest::PlaceholderWithGuard => { // ? tree_to_decider(success_tree) } @@ -2572,22 +2696,6 @@ fn count_targets(targets: &mut bumpalo::collections::Vec, initial: &Decider } } -#[allow(clippy::type_complexity)] -fn create_choices<'a>( - target_counts: &bumpalo::collections::Vec<'a, u64>, - target: u64, - branch: Stmt<'a>, -) -> ((u64, Choice<'a>), Option<(u64, Stmt<'a>)>) { - match target_counts.get(target as usize) { - None => unreachable!( - "this should never happen: {:?} not in {:?}", - target, target_counts - ), - Some(1) => ((target, Choice::Inline(branch)), None), - Some(_) => ((target, Choice::Jump(target)), Some((target, branch))), - } -} - fn insert_choices<'a>( choice_dict: &MutMap>, decider: Decider<'a, u64>,