diff --git a/compiler/mono/src/decision_tree2.rs b/compiler/mono/src/decision_tree2.rs new file mode 100644 index 0000000000..8e9a14c993 --- /dev/null +++ b/compiler/mono/src/decision_tree2.rs @@ -0,0 +1,1575 @@ +use crate::experiment::{Expr, Literal, Stmt}; +use crate::expr::{DestructType, Env, Pattern}; +use crate::layout::{Builtin, Layout}; +use crate::pattern::{Ctor, RenderAs, TagId, Union}; +use bumpalo::Bump; +use roc_collections::all::{MutMap, MutSet}; +use roc_module::ident::TagName; +use roc_module::low_level::LowLevel; +use roc_module::symbol::Symbol; + +/// COMPILE CASES + +type Label = u64; + +/// Users of this module will mainly interact with this function. It takes +/// some normal branches and gives out a decision tree that has "labels" at all +/// the leafs and a dictionary that maps these "labels" to the code that should +/// run. +pub fn compile<'a>(raw_branches: Vec<(Guard<'a>, Pattern<'a>, u64)>) -> DecisionTree<'a> { + let formatted = raw_branches + .into_iter() + .map(|(guard, pattern, index)| Branch { + goal: index, + patterns: vec![(Path::Empty, guard, pattern)], + }) + .collect(); + + to_decision_tree(formatted) +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Guard<'a> { + NoGuard, + Guard { + stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], + expr: Stmt<'a>, + }, +} + +impl<'a> Guard<'a> { + fn is_none(&self) -> bool { + self == &Guard::NoGuard + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum DecisionTree<'a> { + Match(Label), + Decision { + path: Path, + edges: Vec<(Test<'a>, DecisionTree<'a>)>, + default: Option>>, + }, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Test<'a> { + IsCtor { + tag_id: u8, + tag_name: TagName, + union: crate::pattern::Union, + arguments: Vec<(Pattern<'a>, Layout<'a>)>, + }, + IsInt(i64), + // float patterns are stored as u64 so they are comparable/hashable + IsFloat(u64), + IsStr(Box), + IsBit(bool), + IsByte { + tag_id: u8, + num_alts: usize, + }, + // A pattern that always succeeds (like `_`) can still have a guard + Guarded { + opt_test: Option>>, + stores: &'a [(Symbol, Layout<'a>, Expr<'a>)], + expr: Stmt<'a>, + }, +} +use std::hash::{Hash, Hasher}; +impl<'a> Hash for Test<'a> { + fn hash(&self, state: &mut H) { + use Test::*; + + match self { + IsCtor { tag_id, .. } => { + state.write_u8(0); + tag_id.hash(state); + // The point of this custom implementation is to not hash the tag arguments + } + IsInt(v) => { + state.write_u8(1); + v.hash(state); + } + IsFloat(v) => { + state.write_u8(2); + v.hash(state); + } + IsStr(v) => { + state.write_u8(3); + v.hash(state); + } + IsBit(v) => { + state.write_u8(4); + v.hash(state); + } + IsByte { tag_id, num_alts } => { + state.write_u8(5); + tag_id.hash(state); + num_alts.hash(state); + } + Guarded { opt_test: None, .. } => { + state.write_u8(6); + } + Guarded { + opt_test: Some(nested), + .. + } => { + state.write_u8(7); + nested.hash(state); + } + } + } +} + +#[derive(Clone, Debug, PartialEq)] +pub enum Path { + Index { + index: u64, + tag_id: u8, + path: Box, + }, + Unbox(Box), + Empty, +} + +// ACTUALLY BUILD DECISION TREES + +#[derive(Clone, Debug, PartialEq)] +struct Branch<'a> { + goal: Label, + patterns: Vec<(Path, Guard<'a>, Pattern<'a>)>, +} + +fn to_decision_tree(raw_branches: Vec) -> DecisionTree { + let branches: Vec<_> = raw_branches.into_iter().map(flatten_patterns).collect(); + + match check_for_match(&branches) { + Some(goal) => DecisionTree::Match(goal), + None => { + // TODO remove clone + let path = pick_path(branches.clone()); + + let (edges, fallback) = gather_edges(branches, &path); + + let mut decision_edges: Vec<_> = edges + .into_iter() + .map(|(a, b)| (a, to_decision_tree(b))) + .collect(); + + match (decision_edges.split_last_mut(), fallback.split_last()) { + (Some(((_tag, decision_tree), rest)), None) if rest.is_empty() => { + // TODO remove clone + decision_tree.clone() + } + (_, None) => DecisionTree::Decision { + path, + edges: decision_edges, + default: None, + }, + (None, Some(_)) => to_decision_tree(fallback), + _ => DecisionTree::Decision { + path, + edges: decision_edges, + default: Some(Box::new(to_decision_tree(fallback))), + }, + } + } + } +} + +fn is_complete(tests: &[Test]) -> bool { + let length = tests.len(); + debug_assert!(length > 0); + match tests.get(length - 1) { + None => unreachable!("should never happen"), + Some(v) => match v { + Test::IsCtor { union, .. } => length == union.alternatives.len(), + Test::IsByte { num_alts, .. } => length == *num_alts, + Test::IsBit(_) => length == 2, + Test::IsInt(_) => false, + Test::IsFloat(_) => false, + Test::IsStr(_) => false, + Test::Guarded { .. } => false, + }, + } +} + +fn flatten_patterns(branch: Branch) -> Branch { + let mut result = Vec::with_capacity(branch.patterns.len()); + + for path_pattern in branch.patterns { + flatten(path_pattern, &mut result); + } + Branch { + goal: branch.goal, + patterns: result, + } +} + +fn flatten<'a>( + path_pattern: (Path, Guard<'a>, Pattern<'a>), + path_patterns: &mut Vec<(Path, Guard<'a>, Pattern<'a>)>, +) { + match &path_pattern.2 { + Pattern::AppliedTag { + union, + arguments, + tag_id, + .. + } => { + // TODO do we need to check that guard.is_none() here? + if union.alternatives.len() == 1 { + let path = path_pattern.0; + // Theory: unbox doesn't have any value for us, because one-element tag unions + // don't store the tag anyway. + if arguments.len() == 1 { + path_patterns.push(( + Path::Unbox(Box::new(path)), + path_pattern.1.clone(), + path_pattern.2.clone(), + )); + } else { + for (index, (arg_pattern, _)) in arguments.iter().enumerate() { + flatten( + ( + Path::Index { + index: index as u64, + tag_id: *tag_id, + path: Box::new(path.clone()), + }, + // same guard here? + path_pattern.1.clone(), + arg_pattern.clone(), + ), + path_patterns, + ); + } + } + } else { + path_patterns.push(path_pattern); + } + } + + _ => { + path_patterns.push(path_pattern); + } + } +} + +/// SUCCESSFULLY MATCH + +/// If the first branch has no more "decision points" we can finally take that +/// path. If that is the case we give the resulting label and a mapping from free +/// variables to "how to get their value". So a pattern like (Just (x,_)) will give +/// us something like ("x" => value.0.0) +fn check_for_match<'a>(branches: &Vec>) -> Option