Merge pull request #5260 from roc-lang/subs-occur-cache

Do not revisit variables in an occurs check
This commit is contained in:
Ayaz 2023-04-08 07:12:03 -05:00 committed by GitHub
commit e5d13bfc22
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,7 +25,8 @@ roc_error_macros::assert_sizeof_all!(RecordFields, 2 * 8);
pub struct Mark(i32);
impl Mark {
pub const NONE: Mark = Mark(2);
pub const NONE: Mark = Mark(3);
pub const VISITED_IN_OCCURS_CHECK: Mark = Mark(2);
pub const OCCURS: Mark = Mark(1);
pub const GET_VAR_NAMES: Mark = Mark(0);
@ -1996,9 +1997,15 @@ impl Subs {
///
/// This ignores [Content::RecursionVar]s that occur recursively, because those are
/// already priced in and expected to occur.
pub fn occurs(&self, var: Variable) -> Result<(), (Variable, Vec<Variable>)> {
///
/// Although `self` is taken as a mutable reference, this function returns it in the same
/// state it was given: marks set on visited variables during the check are reset to
/// [`Mark::NONE`] before returning.
pub fn occurs(&mut self, var: Variable) -> Result<(), (Variable, Vec<Variable>)> {
let mut scratchpad = take_occurs_scratchpad();
let result = occurs(self, &mut scratchpad, var);
for v in &scratchpad.all_visited {
self.set_mark_unchecked(*v, Mark::NONE);
}
put_occurs_scratchpad(scratchpad);
result
}
@ -3434,15 +3441,34 @@ impl TupleElems {
}
}
std::thread_local! {
static SCRATCHPAD_FOR_OCCURS: RefCell<Option<Vec<Variable>>> = RefCell::new(Some(Vec::with_capacity(1024)));
/// Reusable buffers for the occurs check.
struct OccursScratchpad {
// Variables pushed when `occurs` starts examining them and popped when it finishes;
// finding a variable already in here is reported as a cycle.
seen: Vec<Variable>,
// Every variable `occurs` has pushed during the whole check (never popped); `Subs::occurs`
// walks this list afterwards to reset each variable's mark to `Mark::NONE`.
all_visited: Vec<Variable>,
}
fn take_occurs_scratchpad() -> Vec<Variable> {
impl OccursScratchpad {
/// Builds a scratchpad with both buffers pre-allocated for 1024 variables; used to
/// initialize the thread-local `SCRATCHPAD_FOR_OCCURS`.
fn new_static() -> Self {
Self {
seen: Vec::with_capacity(1024),
all_visited: Vec::with_capacity(1024),
}
}
/// Empties both buffers. `Vec::clear` drops the elements but keeps the allocated capacity,
/// so the scratchpad can be reused without reallocating.
fn clear(&mut self) {
self.seen.clear();
self.all_visited.clear();
}
}
std::thread_local! {
// Per-thread slot holding the occurs-check scratchpad. It starts as `Some`, is emptied by
// `take_occurs_scratchpad`, and is refilled by `put_occurs_scratchpad`.
static SCRATCHPAD_FOR_OCCURS: RefCell<Option<OccursScratchpad>> = RefCell::new(Some(OccursScratchpad::new_static()));
}
/// Moves the scratchpad out of the thread-local slot, leaving `None` behind.
///
/// # Panics
///
/// Panics if the slot is already `None`, i.e. the scratchpad was taken and has not been put
/// back yet.
fn take_occurs_scratchpad() -> OccursScratchpad {
SCRATCHPAD_FOR_OCCURS.with(|f| f.take().unwrap())
}
fn put_occurs_scratchpad(mut scratchpad: Vec<Variable>) {
fn put_occurs_scratchpad(mut scratchpad: OccursScratchpad) {
SCRATCHPAD_FOR_OCCURS.with(|f| {
scratchpad.clear();
f.replace(Some(scratchpad));
@ -3450,20 +3476,36 @@ fn put_occurs_scratchpad(mut scratchpad: Vec<Variable>) {
}
fn occurs(
subs: &Subs,
seen: &mut Vec<Variable>,
subs: &mut Subs,
ctx: &mut OccursScratchpad,
input_var: Variable,
) -> Result<(), (Variable, Vec<Variable>)> {
// NB(subs-invariant): it is pivotal that subs is not modified in any material way.
// As variables are visited, they are marked as observed so they are not revisited,
// but no other modification should take place.
use self::Content::*;
use self::FlatType::*;
let root_var = subs.get_root_key_without_compacting(input_var);
if seen.contains(&root_var) {
// SAFETY: due to XREF(subs-invariant), only the mark in a variable is modified, and all
// variable (and other content) identities are guaranteed to be preserved during an occurs
// check. As a result, we can freely take references of variables and UnionTags.
macro_rules! safe {
($t:ty, $expr:expr) => {
unsafe { std::mem::transmute::<_, &'static $t>($expr) }
};
}
if ctx.seen.contains(&root_var) {
Err((root_var, Vec::with_capacity(0)))
} else if subs.get_mark_unchecked(root_var) == Mark::VISITED_IN_OCCURS_CHECK {
Ok(())
} else {
seen.push(root_var);
let result = (|| match subs.get_content_without_compacting(root_var) {
ctx.seen.push(root_var);
ctx.all_visited.push(root_var);
let result = (|| match subs.get_content_unchecked(root_var) {
FlexVar(_)
| RigidVar(_)
| FlexAbleVar(_, _)
@ -3472,47 +3514,57 @@ fn occurs(
| Error => Ok(()),
Structure(flat_type) => match flat_type {
Apply(_, args) => {
short_circuit(subs, root_var, seen, subs.get_subs_slice(*args).iter())
}
Apply(_, args) => short_circuit(
subs,
root_var,
ctx,
safe!([Variable], subs.get_subs_slice(*args)).iter(),
),
Func(arg_vars, closure_var, ret_var) => {
let it = once(ret_var)
.chain(once(closure_var))
.chain(subs.get_subs_slice(*arg_vars).iter());
short_circuit(subs, root_var, seen, it)
let it = once(safe!(Variable, ret_var))
.chain(once(safe!(Variable, closure_var)))
.chain(safe!([Variable], subs.get_subs_slice(*arg_vars)).iter());
short_circuit(subs, root_var, ctx, it)
}
Record(vars_by_field, ext) => {
let slice = SubsSlice::new(vars_by_field.variables_start, vars_by_field.length);
let it = once(ext).chain(subs.get_subs_slice(slice).iter());
short_circuit(subs, root_var, seen, it)
let slice =
VariableSubsSlice::new(vars_by_field.variables_start, vars_by_field.length);
let it = once(safe!(Variable, ext))
.chain(safe!([Variable], subs.get_subs_slice(slice)).iter());
short_circuit(subs, root_var, ctx, it)
}
Tuple(vars_by_elem, ext) => {
let slice = SubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length);
let it = once(ext).chain(subs.get_subs_slice(slice).iter());
short_circuit(subs, root_var, seen, it)
let slice =
VariableSubsSlice::new(vars_by_elem.variables_start, vars_by_elem.length);
let it = once(safe!(Variable, ext))
.chain(safe!([Variable], subs.get_subs_slice(slice)).iter());
short_circuit(subs, root_var, ctx, it)
}
TagUnion(tags, ext) => {
occurs_union(subs, root_var, seen, tags)?;
let ext_var = ext.var();
occurs_union(subs, root_var, ctx, safe!(UnionLabels<TagName>, tags))?;
short_circuit_help(subs, root_var, seen, ext.var())
short_circuit_help(subs, root_var, ctx, ext_var)
}
FunctionOrTagUnion(_, _, ext) => {
short_circuit(subs, root_var, seen, once(&ext.var()))
short_circuit(subs, root_var, ctx, once(&ext.var()))
}
RecursiveTagUnion(_, tags, ext) => {
occurs_union(subs, root_var, seen, tags)?;
let ext_var = ext.var();
occurs_union(subs, root_var, ctx, safe!(UnionLabels<TagName>, tags))?;
short_circuit_help(subs, root_var, seen, ext.var())
short_circuit_help(subs, root_var, ctx, ext_var)
}
EmptyRecord | EmptyTuple | EmptyTagUnion => Ok(()),
},
Alias(_, args, real_var, _) => {
let real_var = *real_var;
for var_index in args.into_iter() {
let var = subs[var_index];
if short_circuit_help(subs, root_var, seen, var).is_err() {
if short_circuit_help(subs, root_var, ctx, var).is_err() {
// Pay the cost and figure out what the actual recursion point is
return short_circuit_help(subs, root_var, seen, *real_var);
return short_circuit_help(subs, root_var, ctx, real_var);
}
}
@ -3527,27 +3579,33 @@ fn occurs(
// unspecialized lambda vars excluded because they are not explicitly part of the
// type (they only matter after being resolved).
occurs_union(subs, root_var, seen, solved)
occurs_union(subs, root_var, ctx, safe!(UnionLabels<Symbol>, solved))
}
RangedNumber(_range_vars) => Ok(()),
})();
seen.pop();
// Cache the variable's property of having no cycle, but only if it indeed has no cycle.
if result.is_ok() {
subs.set_mark_unchecked(root_var, Mark::VISITED_IN_OCCURS_CHECK);
}
ctx.seen.pop();
result
}
}
#[inline(always)]
fn occurs_union<L: Label>(
subs: &Subs,
subs: &mut Subs,
root_var: Variable,
seen: &mut Vec<Variable>,
ctx: &mut OccursScratchpad,
tags: &UnionLabels<L>,
) -> Result<(), (Variable, Vec<Variable>)> {
for slice_index in tags.variables() {
let slice = subs[slice_index];
for var_index in slice {
let var = subs[var_index];
short_circuit_help(subs, root_var, seen, var)?;
short_circuit_help(subs, root_var, ctx, var)?;
}
}
Ok(())
@ -3555,16 +3613,16 @@ fn occurs_union<L: Label>(
#[inline(always)]
fn short_circuit<'a, T>(
subs: &Subs,
subs: &mut Subs,
root_key: Variable,
seen: &mut Vec<Variable>,
ctx: &mut OccursScratchpad,
iter: T,
) -> Result<(), (Variable, Vec<Variable>)>
where
T: Iterator<Item = &'a Variable>,
{
for var in iter {
short_circuit_help(subs, root_key, seen, *var)?;
short_circuit_help(subs, root_key, ctx, *var)?;
}
Ok(())
@ -3572,12 +3630,12 @@ where
#[inline(always)]
fn short_circuit_help(
subs: &Subs,
subs: &mut Subs,
root_key: Variable,
seen: &mut Vec<Variable>,
ctx: &mut OccursScratchpad,
var: Variable,
) -> Result<(), (Variable, Vec<Variable>)> {
if let Err((v, mut vec)) = occurs(subs, seen, var) {
if let Err((v, mut vec)) = occurs(subs, ctx, var) {
vec.push(root_key);
return Err((v, vec));
}