Decision tree compilation of suffixed list patterns

This commit is contained in:
Ayaz Hafiz 2022-11-01 15:22:00 -05:00
parent b0edcc9af4
commit 0706615d29
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58
3 changed files with 152 additions and 43 deletions

View File

@ -1,6 +1,6 @@
use crate::ir::{
BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId, Literal, Param, Pattern,
Procs, Stmt,
build_list_index_probe, BranchInfo, Call, CallType, DestructType, Env, Expr, JoinPointId,
ListIndex, Literal, Param, Pattern, Procs, Stmt,
};
use crate::layout::{Builtin, Layout, LayoutCache, TagIdIntType, UnionLayout};
use roc_builtins::bitcode::{FloatWidth, IntWidth};
@ -808,15 +808,14 @@ fn to_relevant_branch_help<'a>(
}
}) =>
{
if matches!(my_arity, ListArity::Slice(_, n) if n > 0) {
todo!();
}
let sub_positions = elements.into_iter().enumerate().map(|(index, elem_pat)| {
let mut new_path = path.to_vec();
let probe_index = ListIndex::from_pattern_index(index, my_arity);
let next_instr = PathInstruction::ListIndex {
// TODO index into back as well
index: index as _,
index: probe_index as _,
};
new_path.push(next_instr);
@ -1348,15 +1347,8 @@ pub fn optimize_when<'a>(
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PathInstruction {
NewType,
TagIndex {
index: u64,
tag_id: TagIdIntType,
},
ListIndex {
// Positive if it should be indexed from the front, negative otherwise
// (-1 means the last index)
index: i64,
},
TagIndex { index: u64, tag_id: TagIdIntType },
ListIndex { index: ListIndex },
}
fn path_to_expr_help<'a>(
@ -1429,17 +1421,12 @@ fn path_to_expr_help<'a>(
PathInstruction::ListIndex { index } => {
let list_sym = symbol;
let usize_layout = Layout::usize(env.target_info);
if index < &0 {
todo!();
}
match layout {
Layout::Builtin(Builtin::List(elem_layout)) => {
let index_sym = env.unique_symbol();
let index_expr =
Expr::Literal(Literal::Int((*index as i128).to_ne_bytes()));
let (index_sym, new_stores) = build_list_index_probe(env, list_sym, index);
stores.extend(new_stores);
let load_sym = env.unique_symbol();
let load_expr = Expr::Call(Call {
@ -1450,7 +1437,6 @@ fn path_to_expr_help<'a>(
arguments: env.arena.alloc([list_sym, index_sym]),
});
stores.push((index_sym, usize_layout, index_expr));
stores.push((load_sym, *elem_layout, load_expr));
layout = *elem_layout;

View File

@ -7379,6 +7379,100 @@ fn store_pattern_help<'a>(
StorePattern::Productive(stmt)
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct ListIndex(
/// Positive if we should index from the head, negative if we should index from the tail
/// 0 is lst[0]
/// -1 is lst[List.len lst - 1]
i64,
);
impl ListIndex {
pub fn from_pattern_index(index: usize, arity: ListArity) -> Self {
match arity {
ListArity::Exact(_) => ListIndex::nth_head(index as _),
ListArity::Slice(head, tail) => {
if index < head {
ListIndex::nth_head(index as _)
} else {
// Slice(2, 6)
//
// s t ... w y z x q
// 0 1 2 3 4 5 6 index
// 0 1 2 3 4 (index - head)
// 4 3 2 1 0 tail - (index - head)
ListIndex::nth_tail((tail - (index - head)) as _)
}
}
}
}
fn nth_head(offset: u64) -> Self {
Self(offset as _)
}
fn nth_tail(offset: u64) -> Self {
let offset = offset as i64;
Self(-1 - offset)
}
}
pub(crate) type Store<'a> = (Symbol, Layout<'a>, Expr<'a>);
/// Builds the list index we should index into
#[must_use]
pub(crate) fn build_list_index_probe<'a>(
env: &mut Env<'a, '_>,
list_sym: Symbol,
list_index: &ListIndex,
) -> (Symbol, impl DoubleEndedIterator<Item = Store<'a>>) {
let usize_layout = Layout::usize(env.target_info);
let list_index = list_index.0;
let index_sym = env.unique_symbol();
let (opt_len_store, opt_offset_store, index_store) = if list_index >= 0 {
let index_expr = Expr::Literal(Literal::Int((list_index as i128).to_ne_bytes()));
let index_store = (index_sym, usize_layout, index_expr);
(None, None, index_store)
} else {
let len_sym = env.unique_symbol();
let len_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::ListLen,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([list_sym]),
});
let offset = (list_index + 1).abs();
let offset_sym = env.unique_symbol();
let offset_expr = Expr::Literal(Literal::Int((offset as i128).to_ne_bytes()));
let index_expr = Expr::Call(Call {
call_type: CallType::LowLevel {
op: LowLevel::NumSub,
update_mode: env.next_update_mode_id(),
},
arguments: env.arena.alloc([len_sym, offset_sym]),
});
let len_store = (len_sym, usize_layout, len_expr);
let offset_store = (offset_sym, usize_layout, offset_expr);
let index_store = (index_sym, usize_layout, index_expr);
(Some(len_store), Some(offset_store), index_store)
};
let stores = (opt_len_store.into_iter())
.chain(opt_offset_store)
.chain([index_store]);
(index_sym, stores)
}
#[allow(clippy::too_many_arguments)]
fn store_list_pattern<'a>(
env: &mut Env<'a, '_>,
@ -7392,16 +7486,13 @@ fn store_list_pattern<'a>(
) -> StorePattern<'a> {
use Pattern::*;
if matches!(list_arity, ListArity::Slice(_, n) if n > 0) {
todo!();
}
let mut is_productive = false;
let usize_layout = Layout::usize(env.target_info);
for (index, element) in elements.iter().enumerate().rev() {
let index_lit = Expr::Literal(Literal::Int((index as i128).to_ne_bytes()));
let index_sym = env.unique_symbol();
let list_index = ListIndex::from_pattern_index(index, list_arity);
// TODO do this only lazily
let (index_sym, needed_stores) = build_list_index_probe(env, list_sym, &list_index);
let load = Expr::Call(Call {
call_type: CallType::LowLevel {
@ -7454,12 +7545,11 @@ fn store_list_pattern<'a>(
};
is_productive = true;
stmt = Stmt::Let(
index_sym,
index_lit,
usize_layout,
env.arena.alloc(store_loaded),
);
stmt = store_loaded;
for (sym, lay, expr) in needed_stores.rev() {
stmt = Stmt::Let(sym, expr, lay, env.arena.alloc(stmt));
}
}
if is_productive {

View File

@ -3652,28 +3652,61 @@ mod pattern_match {
)
}
#[test]
fn ranged_matches_tail() {
assert_evals_to!(
r#"
helper = \l -> when l is
[] -> 1u8
[A] -> 2u8
[.., A, A] -> 3u8
[.., B, A] -> 4u8
[.., B] -> 5u8
[
helper [],
helper [A],
helper [A, A], helper [A, A, A], helper [B, A, A], helper [A, B, A, A],
helper [B, A], helper [A, B, A], helper [B, B, A], helper [B, A, B, A],
helper [B], helper [A, B], helper [B, B], helper [B, A, B, B],
]
"#,
RocList::from_slice(&[
1, //
2, //
3, 3, 3, 3, //
4, 4, 4, 4, //
5, 5, 5, 5, //
]),
RocList<u8>
)
}
#[test]
fn bind_variables() {
assert_evals_to!(
r#"
helper : List U8 -> U8
helper : List U16 -> U16
helper = \l -> when l is
[] -> 1u8
[] -> 1
[x] -> x
[.., w, x, y, z] -> w * x * y * z
[x, y, ..] -> x * y
[
helper [],
helper [5],
helper [3, 5], helper [3, 5, 7], helper [3, 5, 7, 11],
helper [3, 5], helper [3, 5, 7],
helper [2, 3, 5, 7], helper [11, 2, 3, 5, 7], helper [13, 11, 2, 3, 5, 7],
]
"#,
RocList::from_slice(&[
1, //
5, //
15, 15, 15 //
15, 15, //
210, 210, 210, //
]),
RocList<u8>
RocList<u16>
)
}
}