Implement List.find

`List.find : List elem, (elem -> Bool) -> Result elem [ NotFound ]*`
behaves as follows:

```
>>> List.find [1, 2, 3] (\n -> n > 2)
Ok 2
>>> List.find [1, 2, 3] (\n -> n > 4)
Err NotFound
```

We implement this as builtin in two phases. First, we call out to a
pure-llvm-lowlevel `ListFindUnsafe` that returns a record indicating
whether a satisfying element was found, and the value of that element
(the value is all null bytes if the element wasn't found). Then, we lift
that record to a `Result` via a standard construction of the can AST.

Closes #1909
This commit is contained in:
ayazhafiz 2021-11-07 20:56:46 -05:00
parent 35df58c18f
commit f65b174ab5
16 changed files with 417 additions and 8 deletions

View File

@ -1333,3 +1333,39 @@ inline fn listSetImmutable(
//return list; //return list;
return new_bytes; return new_bytes;
} }
pub fn listFindUnsafe(
list: RocList,
caller: Caller1,
data: Opaque,
inc_n_data: IncN,
data_is_owned: bool,
alignment: u32,
element_width: usize,
inc: Inc,
dec: Dec,
) callconv(.C) extern struct { value: Opaque, found: bool } {
if (list.bytes) |source_ptr| {
const size = list.len();
if (data_is_owned) {
inc_n_data(data, size);
}
var i: usize = 0;
while (i < size) : (i += 1) {
var theOne = false;
const element = source_ptr + (i * element_width);
inc(element);
caller(data, element, @ptrCast(?[*]u8, &theOne));
if (theOne) {
return .{ .value = element, .found = true };
} else {
dec(element);
}
}
return .{ .value = null, .found = false };
} else {
return .{ .value = null, .found = false };
}
}

View File

@ -52,6 +52,7 @@ comptime {
exportListFn(list.listSetInPlace, "set_in_place"); exportListFn(list.listSetInPlace, "set_in_place");
exportListFn(list.listSwap, "swap"); exportListFn(list.listSwap, "swap");
exportListFn(list.listAny, "any"); exportListFn(list.listAny, "any");
exportListFn(list.listFindUnsafe, "find_unsafe");
} }
// Dict Module // Dict Module

View File

@ -690,3 +690,7 @@ all : List elem, (elem -> Bool) -> Bool
## Run the given predicate on each element of the list, returning `True` if ## Run the given predicate on each element of the list, returning `True` if
## any of the elements satisfy it. ## any of the elements satisfy it.
any : List elem, (elem -> Bool) -> Bool any : List elem, (elem -> Bool) -> Bool
## Returns the first element of the list satisfying a predicate function.
## If no satisfying element is found, an `Err NotFound` is returned.
find : List elem, (elem -> Bool) -> Result elem [ NotFound ]*

View File

@ -190,6 +190,7 @@ pub const LIST_CONCAT: &str = "roc_builtins.list.concat";
pub const LIST_SET: &str = "roc_builtins.list.set"; pub const LIST_SET: &str = "roc_builtins.list.set";
pub const LIST_SET_IN_PLACE: &str = "roc_builtins.list.set_in_place"; pub const LIST_SET_IN_PLACE: &str = "roc_builtins.list.set_in_place";
pub const LIST_ANY: &str = "roc_builtins.list.any"; pub const LIST_ANY: &str = "roc_builtins.list.any";
pub const LIST_FIND_UNSAFE: &str = "roc_builtins.list.find_unsafe";
pub const DEC_FROM_F64: &str = "roc_builtins.dec.from_f64"; pub const DEC_FROM_F64: &str = "roc_builtins.dec.from_f64";
pub const DEC_EQ: &str = "roc_builtins.dec.eq"; pub const DEC_EQ: &str = "roc_builtins.dec.eq";

View File

@ -1086,6 +1086,23 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
Box::new(list_type(flex(TVAR1))), Box::new(list_type(flex(TVAR1))),
); );
// find : List elem, (elem -> Bool) -> Result elem [ NotFound ]*
{
let not_found = SolvedType::TagUnion(
vec![(TagName::Global("NotFound".into()), vec![])],
Box::new(SolvedType::Wildcard),
);
let (elem, cvar) = (TVAR1, TVAR2);
add_top_level_function_type!(
Symbol::LIST_FIND,
vec![
list_type(flex(elem)),
closure(vec![flex(elem)], cvar, Box::new(bool_type())),
],
Box::new(result_type(flex(elem), not_found)),
)
}
// Dict module // Dict module
// len : Dict * * -> Nat // len : Dict * * -> Nat

View File

@ -107,6 +107,8 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
LIST_WALK_UNTIL => list_walk_until, LIST_WALK_UNTIL => list_walk_until,
LIST_SORT_WITH => list_sort_with, LIST_SORT_WITH => list_sort_with,
LIST_ANY => list_any, LIST_ANY => list_any,
LIST_FIND => list_find,
DICT_TEST_HASH => dict_hash_test_only,
DICT_LEN => dict_len, DICT_LEN => dict_len,
DICT_EMPTY => dict_empty, DICT_EMPTY => dict_empty,
DICT_SINGLE => dict_single, DICT_SINGLE => dict_single,
@ -2724,6 +2726,92 @@ fn list_any(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_2(symbol, LowLevel::ListAny, var_store) lowlevel_2(symbol, LowLevel::ListAny, var_store)
} }
/// List.find : List elem, (elem -> Bool) -> Result elem [ NotFound ]*
fn list_find(symbol: Symbol, var_store: &mut VarStore) -> Def {
let list = Symbol::ARG_1;
let find_predicate = Symbol::ARG_2;
let find_result = Symbol::LIST_FIND_RESULT;
let t_list = var_store.fresh();
let t_pred_fn = var_store.fresh();
let t_bool = var_store.fresh();
let t_found = var_store.fresh();
let t_value = var_store.fresh();
let t_ret = var_store.fresh();
let t_find_result = var_store.fresh();
let t_ext_var1 = var_store.fresh();
let t_ext_var2 = var_store.fresh();
// ListFindUnsafe returns { value: elem, found: Bool }.
// When `found` is true, the value was found. Otherwise `List.find` should return `Err ...`
let find_result_def = Def {
annotation: None,
expr_var: t_find_result,
loc_expr: no_region(RunLowLevel {
op: LowLevel::ListFindUnsafe,
args: vec![(t_list, Var(list)), (t_pred_fn, Var(find_predicate))],
ret_var: t_find_result,
}),
loc_pattern: no_region(Pattern::Identifier(find_result)),
pattern_vars: Default::default(),
};
let get_value = Access {
record_var: t_find_result,
ext_var: t_ext_var1,
field_var: t_value,
loc_expr: Box::new(no_region(Var(find_result))),
field: "value".into(),
};
let get_found = Access {
record_var: t_find_result,
ext_var: t_ext_var2,
field_var: t_found,
loc_expr: Box::new(no_region(Var(find_result))),
field: "found".into(),
};
let make_ok = tag("Ok", vec![get_value], var_store);
let make_err = tag(
"Err",
vec![tag("NotFound", Vec::new(), var_store)],
var_store,
);
let inspect = If {
cond_var: t_bool,
branch_var: t_ret,
branches: vec![(
// if-condition
no_region(get_found),
no_region(make_ok),
)],
final_else: Box::new(no_region(make_err)),
};
let body = LetNonRec(
Box::new(find_result_def),
Box::new(no_region(inspect)),
t_ret,
);
defn(
symbol,
vec![(t_list, Symbol::ARG_1), (t_pred_fn, Symbol::ARG_2)],
var_store,
body,
t_ret,
)
}
/// Dict.hashTestOnly : k, v -> Nat
fn dict_hash_test_only(symbol: Symbol, var_store: &mut VarStore) -> Def {
lowlevel_2(symbol, LowLevel::Hash, var_store)
}
/// Dict.len : Dict * * -> Nat /// Dict.len : Dict * * -> Nat
fn dict_len(symbol: Symbol, var_store: &mut VarStore) -> Def { fn dict_len(symbol: Symbol, var_store: &mut VarStore) -> Def {
let arg1_var = var_store.fresh(); let arg1_var = var_store.fresh();

View File

@ -9,10 +9,10 @@ use crate::llvm::build_dict::{
use crate::llvm::build_hash::generic_hash; use crate::llvm::build_hash::generic_hash;
use crate::llvm::build_list::{ use crate::llvm::build_list::{
self, allocate_list, empty_list, empty_polymorphic_list, list_any, list_append, list_concat, self, allocate_list, empty_list, empty_polymorphic_list, list_any, list_append, list_concat,
list_contains, list_drop, list_drop_at, list_get_unsafe, list_join, list_keep_errs, list_contains, list_drop, list_drop_at, list_find_trivial_not_found, list_find_unsafe,
list_keep_if, list_keep_oks, list_len, list_map, list_map2, list_map3, list_map4, list_get_unsafe, list_join, list_keep_errs, list_keep_if, list_keep_oks, list_len, list_map,
list_map_with_index, list_prepend, list_range, list_repeat, list_reverse, list_set, list_map2, list_map3, list_map4, list_map_with_index, list_prepend, list_range, list_repeat,
list_single, list_sort_with, list_swap, list_take_first, list_reverse, list_set, list_single, list_sort_with, list_swap, list_take_first,
}; };
use crate::llvm::build_str::{ use crate::llvm::build_str::{
empty_str, str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int, empty_str, str_concat, str_count_graphemes, str_ends_with, str_from_float, str_from_int,
@ -4887,6 +4887,37 @@ fn run_higher_order_low_level<'a, 'ctx, 'env>(
_ => unreachable!("invalid list layout"), _ => unreachable!("invalid list layout"),
} }
} }
ListFindUnsafe { xs } => {
let (list, list_layout) = load_symbol_and_layout(scope, &xs);
let (function, closure, closure_layout) = function_details!();
match list_layout {
Layout::Builtin(Builtin::EmptyList) => {
// Returns { found: False, elem: \empty }, where the `elem` field is zero-sized.
// NB: currently we never hit this case, since the only caller of this
// lowlevel, namely List.find, will fail during monomorphization when there is no
// concrete list element type. This is because List.find returns a
// `Result elem [ NotFound ]*`, and we can't figure out the size of that if
// `elem` is not concrete.
list_find_trivial_not_found(env)
}
Layout::Builtin(Builtin::List(element_layout)) => {
let argument_layouts = &[**element_layout];
let roc_function_call = roc_function_call(
env,
layout_ids,
function,
closure,
closure_layout,
function_owns_closure_data,
argument_layouts,
);
list_find_unsafe(env, layout_ids, roc_function_call, list, element_layout)
}
_ => unreachable!("invalid list layout"),
}
}
DictWalk { xs, state } => { DictWalk { xs, state } => {
let (dict, dict_layout) = load_symbol_and_layout(scope, &xs); let (dict, dict_layout) = load_symbol_and_layout(scope, &xs);
let (default, default_layout) = load_symbol_and_layout(scope, &state); let (default, default_layout) = load_symbol_and_layout(scope, &state);
@ -5757,7 +5788,9 @@ fn run_low_level<'a, 'ctx, 'env>(
ListMap | ListMap2 | ListMap3 | ListMap4 | ListMapWithIndex | ListKeepIf | ListWalk ListMap | ListMap2 | ListMap3 | ListMap4 | ListMapWithIndex | ListKeepIf | ListWalk
| ListWalkUntil | ListWalkBackwards | ListKeepOks | ListKeepErrs | ListSortWith | ListWalkUntil | ListWalkBackwards | ListKeepOks | ListKeepErrs | ListSortWith
| ListAny | DictWalk => unreachable!("these are higher order, and are handled elsewhere"), | ListAny | ListFindUnsafe | DictWalk => {
unreachable!("these are higher order, and are handled elsewhere")
}
} }
} }

View File

@ -936,6 +936,123 @@ pub fn list_any<'a, 'ctx, 'env>(
) )
} }
/// List.findUnsafe : List elem, (elem -> Bool) -> { value: elem, found: bool }
pub fn list_find_unsafe<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
layout_ids: &mut LayoutIds<'a>,
roc_function_call: RocFunctionCall<'ctx>,
list: BasicValueEnum<'ctx>,
element_layout: &Layout<'a>,
) -> BasicValueEnum<'ctx> {
let inc_element_fn = build_inc_wrapper(env, layout_ids, element_layout);
let dec_element_fn = build_dec_wrapper(env, layout_ids, element_layout);
// { value: *const u8, found: bool }
let result = call_bitcode_fn(
env,
&[
pass_list_cc(env, list),
roc_function_call.caller.into(),
pass_as_opaque(env, roc_function_call.data),
roc_function_call.inc_n_data.into(),
roc_function_call.data_is_owned.into(),
env.alignment_intvalue(element_layout),
layout_width(env, element_layout),
inc_element_fn.as_global_value().as_pointer_value().into(),
dec_element_fn.as_global_value().as_pointer_value().into(),
],
bitcode::LIST_FIND_UNSAFE,
)
.into_struct_value();
// We promised the caller we'd give them back a struct containing the element
// loaded on the stack, so we do that now. The element can't be loaded directly
// in the Zig definition called above, because we don't know the size of the
// element until user compile time, which is later than the compile time of bitcode defs.
let value_u8_ptr = env
.builder
.build_extract_value(result, 0, "get_value_ptr")
.unwrap()
.into_pointer_value();
let found = env
.builder
.build_extract_value(result, 1, "get_found")
.unwrap()
.into_int_value();
let start_block = env.builder.get_insert_block().unwrap();
let parent = start_block.get_parent().unwrap();
let if_not_null = env.context.append_basic_block(parent, "if_not_null");
let done_block = env.context.append_basic_block(parent, "done");
let value_bt = basic_type_from_layout(env, element_layout);
let default = value_bt.const_zero();
env.builder
.build_conditional_branch(found, if_not_null, done_block);
env.builder.position_at_end(if_not_null);
let value_ptr = env
.builder
.build_bitcast(
value_u8_ptr,
value_bt.ptr_type(AddressSpace::Generic),
"from_opaque",
)
.into_pointer_value();
let loaded = env.builder.build_load(value_ptr, "load_value");
env.builder.build_unconditional_branch(done_block);
env.builder.position_at_end(done_block);
let result_phi = env.builder.build_phi(value_bt, "result");
result_phi.add_incoming(&[(&default, start_block), (&loaded, if_not_null)]);
let value = result_phi.as_basic_value();
let result = env
.context
.struct_type(&[value_bt, env.context.bool_type().into()], false)
.const_zero();
let result = env
.builder
.build_insert_value(result, value, 0, "insert_value")
.unwrap();
env.builder
.build_insert_value(result, found, 1, "insert_found")
.unwrap()
.into_struct_value()
.into()
}
/// Returns { value: \empty, found: False }, representing that no element was found in a call
/// to List.find when the layout of the element is also unknown.
pub fn list_find_trivial_not_found<'a, 'ctx, 'env>(
env: &Env<'a, 'ctx, 'env>,
) -> BasicValueEnum<'ctx> {
let empty_type = env.context.custom_width_int_type(0);
let result = env
.context
.struct_type(&[empty_type.into(), env.context.bool_type().into()], false)
.const_zero();
env.builder
.build_insert_value(
result,
env.context.bool_type().const_zero(),
1,
"insert_found",
)
.unwrap()
.into_struct_value()
.into()
}
pub fn decrementing_elem_loop<'ctx, LoopFn>( pub fn decrementing_elem_loop<'ctx, LoopFn>(
builder: &Builder<'ctx>, builder: &Builder<'ctx>,
ctx: &'ctx Context, ctx: &'ctx Context,

View File

@ -47,6 +47,7 @@ pub enum LowLevel {
ListDropAt, ListDropAt,
ListSwap, ListSwap,
ListAny, ListAny,
ListFindUnsafe,
DictSize, DictSize,
DictEmpty, DictEmpty,
DictInsert, DictInsert,
@ -225,6 +226,7 @@ macro_rules! higher_order {
| ListKeepErrs | ListKeepErrs
| ListSortWith | ListSortWith
| ListAny | ListAny
| ListFindUnsafe
| DictWalk | DictWalk
}; };
} }
@ -259,6 +261,7 @@ impl LowLevel {
ListKeepErrs => 1, ListKeepErrs => 1,
ListSortWith => 1, ListSortWith => 1,
ListAny => 1, ListAny => 1,
ListFindUnsafe => 1,
DictWalk => 2, DictWalk => 2,
} }
} }

View File

@ -1066,6 +1066,8 @@ define_builtins! {
43 LIST_JOIN_MAP_CONCAT: "#joinMapConcat" 43 LIST_JOIN_MAP_CONCAT: "#joinMapConcat"
44 LIST_ANY: "any" 44 LIST_ANY: "any"
45 LIST_TAKE_FIRST: "takeFirst" 45 LIST_TAKE_FIRST: "takeFirst"
46 LIST_FIND: "find"
47 LIST_FIND_RESULT: "#find_result" // symbol used in the definition of List.find
} }
5 RESULT: "Result" => { 5 RESULT: "Result" => {
0 RESULT_RESULT: "Result" imported // the Result.Result type alias 0 RESULT_RESULT: "Result" imported // the Result.Result type alias

View File

@ -1093,6 +1093,42 @@ fn call_spec(
add_loop(builder, block, state_type, init_state, loop_body) add_loop(builder, block, state_type, init_state, loop_body)
} }
ListFindUnsafe { xs } => {
let list = env.symbols[xs];
// Mark the list as being used by the "find" predicate function.
// It may be the case that all elements in the list are used by the predicate.
// Since `bag_get` assumes items are picked non-deterministically, this is
// (probably?) enough to express that usage.
let bag = builder.add_get_tuple_field(block, list, LIST_BAG_INDEX)?;
let cell = builder.add_get_tuple_field(block, list, LIST_CELL_INDEX)?;
let element = builder.add_bag_get(block, bag)?;
let _bool = call_function!(builder, block, [element]);
// ListFindUnsafe returns { value: v, found: Bool=Int1 }
let output_layouts =
vec![arg_layouts[0].clone(), Layout::Builtin(Builtin::Int1)];
let output_layout = Layout::Struct(&output_layouts);
let output_type = layout_spec(builder, &output_layout)?;
// We may or may not use the element we got from the list in the output struct,
// depending on whether we found the element to satisfy the "find" predicate.
let found_branch = builder.add_block();
let output_with_element =
builder.add_unknown_with(found_branch, &[element], output_type)?;
let not_found_branch = builder.add_block();
let output_without_element =
builder.add_unknown_with(not_found_branch, &[], output_type)?;
builder.add_choice(
block,
&[
BlockExpr(found_branch, output_with_element),
BlockExpr(not_found_branch, output_without_element),
],
)
}
} }
} }
} }

View File

@ -618,7 +618,8 @@ impl<'a> BorrowInfState<'a> {
| ListKeepIf { xs } | ListKeepIf { xs }
| ListKeepOks { xs } | ListKeepOks { xs }
| ListKeepErrs { xs } | ListKeepErrs { xs }
| ListAny { xs } => { | ListAny { xs }
| ListFindUnsafe { xs } => {
// own the list if the function wants to own the element // own the list if the function wants to own the element
if !function_ps[0].borrow { if !function_ps[0].borrow {
self.own_var(*xs); self.own_var(*xs);
@ -959,6 +960,7 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
arena.alloc_slice_copy(&[owned, owned, function, closure_data]) arena.alloc_slice_copy(&[owned, owned, function, closure_data])
} }
ListSortWith => arena.alloc_slice_copy(&[owned, function, closure_data]), ListSortWith => arena.alloc_slice_copy(&[owned, function, closure_data]),
ListFindUnsafe => arena.alloc_slice_copy(&[owned, function, closure_data]),
// TODO when we have lists with capacity (if ever) // TODO when we have lists with capacity (if ever)
// List.append should own its first argument // List.append should own its first argument

View File

@ -531,7 +531,8 @@ impl<'a> Context<'a> {
| ListKeepIf { xs } | ListKeepIf { xs }
| ListKeepOks { xs } | ListKeepOks { xs }
| ListKeepErrs { xs } | ListKeepErrs { xs }
| ListAny { xs } => { | ListAny { xs }
| ListFindUnsafe { xs } => {
let borrows = [function_ps[0].borrow, FUNCTION, CLOSURE_DATA]; let borrows = [function_ps[0].borrow, FUNCTION, CLOSURE_DATA];
let b = self.add_dec_after_lowlevel(arguments, &borrows, b, b_live_vars); let b = self.add_dec_after_lowlevel(arguments, &borrows, b, b_live_vars);

View File

@ -4164,6 +4164,11 @@ pub fn with_hole<'a>(
match_on_closure_argument!(ListMap4, [xs, ys, zs, ws]) match_on_closure_argument!(ListMap4, [xs, ys, zs, ws])
} }
ListFindUnsafe => {
debug_assert_eq!(arg_symbols.len(), 2);
let xs = arg_symbols[0];
match_on_closure_argument!(ListFindUnsafe, [xs])
}
_ => { _ => {
let call = self::Call { let call = self::Call {
call_type: CallType::LowLevel { call_type: CallType::LowLevel {

View File

@ -50,6 +50,9 @@ pub enum HigherOrder {
ListAny { ListAny {
xs: Symbol, xs: Symbol,
}, },
ListFindUnsafe {
xs: Symbol,
},
DictWalk { DictWalk {
xs: Symbol, xs: Symbol,
state: Symbol, state: Symbol,
@ -71,6 +74,7 @@ impl HigherOrder {
HigherOrder::ListKeepOks { .. } => 1, HigherOrder::ListKeepOks { .. } => 1,
HigherOrder::ListKeepErrs { .. } => 1, HigherOrder::ListKeepErrs { .. } => 1,
HigherOrder::ListSortWith { .. } => 2, HigherOrder::ListSortWith { .. } => 2,
HigherOrder::ListFindUnsafe { .. } => 1,
HigherOrder::DictWalk { .. } => 2, HigherOrder::DictWalk { .. } => 2,
HigherOrder::ListAny { .. } => 1, HigherOrder::ListAny { .. } => 1,
} }

View File

@ -2281,7 +2281,7 @@ fn list_join_map() {
RocStr::from_slice("cyrus".as_bytes()), RocStr::from_slice("cyrus".as_bytes()),
]), ]),
RocList<RocStr> RocList<RocStr>
); )
} }
#[test] #[test]
@ -2294,5 +2294,64 @@ fn list_join_map_empty() {
), ),
RocList::from_slice(&[]), RocList::from_slice(&[]),
RocList<RocStr> RocList<RocStr>
)
}
#[test]
fn list_find() {
assert_evals_to!(
indoc!(
r#"
when List.find ["a", "bc", "def"] (\s -> Str.countGraphemes s > 1) is
Ok v -> v
Err _ -> "not found"
"#
),
RocStr::from_slice(b"bc"),
RocStr
);
}
#[test]
fn list_find_not_found() {
assert_evals_to!(
indoc!(
r#"
when List.find ["a", "bc", "def"] (\s -> Str.countGraphemes s > 5) is
Ok v -> v
Err _ -> "not found"
"#
),
RocStr::from_slice(b"not found"),
RocStr
);
}
#[test]
fn list_find_empty_typed_list() {
assert_evals_to!(
indoc!(
r#"
when List.find [] (\s -> Str.countGraphemes s > 5) is
Ok v -> v
Err _ -> "not found"
"#
),
RocStr::from_slice(b"not found"),
RocStr
);
}
#[test]
#[ignore = "Fails because monomorphization can't be done if we don't have a concrete element type!"]
fn list_find_empty_layout() {
assert_evals_to!(
indoc!(
r#"
List.find [] (\_ -> True)
"#
),
0,
i64
); );
} }