Bugfix handle more specialization instances

This commit is contained in:
Ayaz Hafiz 2022-05-05 11:12:50 -04:00
parent de924de266
commit 19e8b37402
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -770,6 +770,26 @@ impl<'a> Procs<'a> {
needed_symbol_specializations: BumpMap::new_in(arena),
}
}
/// Expects and removes a single specialization symbol for the given requested symbol.
fn remove_single_symbol_specialization(&mut self, symbol: Symbol) -> Option<Symbol> {
let mut specialized_symbols = self
.needed_symbol_specializations
.drain_filter(|(sym, _), _| sym == &symbol);
let specialization_symbol = specialized_symbols
.next()
.map(|(_, (_, specialized_symbol))| specialized_symbol);
debug_assert_eq!(
specialized_symbols.count(),
0,
"Symbol {:?} has multiple specializations",
symbol
);
specialization_symbol
}
}
#[derive(Clone, Debug, PartialEq)]
@ -2468,11 +2488,11 @@ fn specialize_external<'a>(
// An argument from the closure list may have taken on a specialized symbol
// name during the evaluation of the def body. If this is the case, load the
// specialized name rather than the original captured name!
let get_specialized_name = |symbol, layout| {
let mut get_specialized_name = |symbol, layout| {
procs
.needed_symbol_specializations
.get(&(symbol, layout))
.map(|(_, specialized)| *specialized)
.remove(&(symbol, layout))
.map(|(_, specialized)| specialized)
.unwrap_or(symbol)
};
@ -3304,7 +3324,7 @@ pub fn with_hole<'a>(
} else {
// this may be a destructure pattern
let (mono_pattern, assignments) =
match from_can_pattern(env, layout_cache, &def.loc_pattern.value) {
match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) {
Ok(v) => v,
Err(_runtime_error) => {
// todo
@ -5492,6 +5512,7 @@ pub fn from_can<'a>(
}
LetNonRec(def, cont, outer_annotation) => {
if let roc_can::pattern::Pattern::Identifier(symbol) = &def.loc_pattern.value {
// dbg!(symbol, &def.loc_expr.value);
match def.loc_expr.value {
roc_can::expr::Expr::Closure(closure_data) => {
register_capturing_closure(env, procs, layout_cache, *symbol, closure_data);
@ -5706,7 +5727,7 @@ pub fn from_can<'a>(
// this may be a destructure pattern
let (mono_pattern, assignments) =
match from_can_pattern(env, layout_cache, &def.loc_pattern.value) {
match from_can_pattern(env, procs, layout_cache, &def.loc_pattern.value) {
Ok(v) => v,
Err(_) => todo!(),
};
@ -5737,8 +5758,22 @@ pub fn from_can<'a>(
// layer on any default record fields
for (symbol, variable, expr) in assignments {
let specialization_symbol = procs
.remove_single_symbol_specialization(symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(symbol);
let hole = env.arena.alloc(stmt);
stmt = with_hole(env, expr, variable, procs, layout_cache, symbol, hole);
stmt = with_hole(
env,
expr,
variable,
procs,
layout_cache,
specialization_symbol,
hole,
);
}
if let roc_can::expr::Expr::Var(outer_symbol) = def.loc_expr.value {
@ -5772,6 +5807,7 @@ pub fn from_can<'a>(
fn to_opt_branches<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
branches: std::vec::Vec<roc_can::expr::WhenBranch>,
exhaustive_mark: ExhaustiveMark,
layout_cache: &mut LayoutCache<'a>,
@ -5798,7 +5834,7 @@ fn to_opt_branches<'a>(
}
for loc_pattern in when_branch.patterns {
match from_can_pattern(env, layout_cache, &loc_pattern.value) {
match from_can_pattern(env, procs, layout_cache, &loc_pattern.value) {
Ok((mono_pattern, assignments)) => {
loc_branches.push((
Loc::at(loc_pattern.region, mono_pattern.clone()),
@ -5876,7 +5912,7 @@ fn from_can_when<'a>(
// We can't know what to return!
return Stmt::RuntimeError("Hit a 0-branch when expression");
}
let opt_branches = to_opt_branches(env, branches, exhaustive_mark, layout_cache);
let opt_branches = to_opt_branches(env, procs, branches, exhaustive_mark, layout_cache);
let cond_layout =
return_on_layout_error!(env, layout_cache.from_var(env.arena, cond_var, env.subs));
@ -6341,7 +6377,15 @@ fn store_pattern_help<'a>(
match can_pat {
Identifier(symbol) => {
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
// An identifier in a pattern can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(env.arena, &mut stmt, specialization_symbol, outer_symbol);
}
Underscore => {
// do nothing
@ -6402,7 +6446,18 @@ fn store_pattern_help<'a>(
for destruct in destructs {
match &destruct.typ {
DestructType::Required(symbol) => {
substitute_in_exprs(env.arena, &mut stmt, *symbol, outer_symbol);
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
substitute_in_exprs(
env.arena,
&mut stmt,
specialization_symbol,
outer_symbol,
);
}
DestructType::Guard(guard_pattern) => {
return store_pattern_help(
@ -6480,10 +6535,11 @@ fn store_tag_pattern<'a>(
match argument {
Identifier(symbol) => {
// TODO: use procs.remove_single_symbol_specialization
let symbol = procs
.needed_symbol_specializations
.get(&(*symbol, arg_layout))
.map(|(_, sym)| *sym)
.remove(&(*symbol, arg_layout))
.map(|(_, sym)| sym)
.unwrap_or(*symbol);
// store immediately in the given symbol
@ -6562,8 +6618,19 @@ fn store_newtype_pattern<'a>(
match argument {
Identifier(symbol) => {
// store immediately in the given symbol
stmt = Stmt::Let(*symbol, load, arg_layout, env.arena.alloc(stmt));
// store immediately in the given symbol, removing it specialization if it had any
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
arg_layout,
env.arena.alloc(stmt),
);
is_productive = true;
}
Underscore => {
@ -6625,11 +6692,35 @@ fn store_record_destruct<'a>(
match &destruct.typ {
DestructType::Required(symbol) => {
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt));
// A destructure can define at most one specialization!
// Remove any requested specializations for this name now, since this is the definition site.
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
}
DestructType::Guard(guard_pattern) => match &guard_pattern {
Identifier(symbol) => {
stmt = Stmt::Let(*symbol, load, destruct.layout, env.arena.alloc(stmt));
let specialization_symbol = procs
.remove_single_symbol_specialization(*symbol)
// Can happen when the symbol was never used under this body, and hence has no
// requested specialization.
.unwrap_or(*symbol);
stmt = Stmt::Let(
specialization_symbol,
load,
destruct.layout,
env.arena.alloc(stmt),
);
}
Underscore => {
// important that this is special-cased to do nothing: mono record patterns will extract all the
@ -6816,48 +6907,53 @@ where
return build_rest(env, procs, layout_cache);
}
// Otherwise we're dealing with an alias to something that doesn't need to be specialized, or
// whose usages will already be specialized in the rest of the program.
if procs.is_imported_module_thunk(right) {
let result = build_rest(env, procs, layout_cache);
if procs.partial_procs.contains_key(right) {
// This is an alias to a function defined in this module.
// Attach the alias, then build the rest of the module, so that we reference and specialize
// the correct proc.
procs.partial_procs.insert_alias(left, right);
return build_rest(env, procs, layout_cache);
}
// Otherwise we're dealing with an alias whose usages will tell us what specializations we
// need. So let's figure those out first.
let result = build_rest(env, procs, layout_cache);
// The specializations we wanted of the symbol on the LHS of this alias.
let needed_specializations_of_left = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == &left)
.collect::<std::vec::Vec<_>>();
if procs.is_imported_module_thunk(right) {
// if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer.
add_needed_external(procs, env, variable, right);
let mut result = result;
for (_, (variable, left)) in needed_specializations_of_left.into_iter() {
add_needed_external(procs, env, variable, right);
let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
let layout = return_on_layout_error!(env, res_layout);
let res_layout = layout_cache.from_var(env.arena, variable, env.subs);
let layout = return_on_layout_error!(env, res_layout);
force_thunk(env, right, layout, left, env.arena.alloc(result))
result = force_thunk(env, right, layout, left, env.arena.alloc(result));
}
result
} else if env.is_imported_symbol(right) {
let result = build_rest(env, procs, layout_cache);
// if this is an imported symbol, then we must make sure it is
// specialized, and wrap the original in a function pointer.
add_needed_external(procs, env, variable, right);
// then we must construct its closure; since imported symbols have no closure, we use the empty struct
let_empty_struct(left, env.arena.alloc(result))
} else if procs.partial_procs.contains_key(right) {
// This is an alias to a function defined in this module.
// Attach the alias, then build the rest of the module, so that we reference and specialize
// the correct proc.
procs.partial_procs.insert_alias(left, right);
build_rest(env, procs, layout_cache)
} else {
// This should be a fully specialized value. Replace the alias with the original symbol.
let mut result = build_rest(env, procs, layout_cache);
// We need to lift all specializations of "left" to be specializations of "right".
let to_update = procs
.needed_symbol_specializations
.drain_filter(|(s, _), _| s == &left)
.collect::<std::vec::Vec<_>>();
let mut scratchpad_update_specializations = std::vec::Vec::new();
let left_had_specialization_symbols = !to_update.is_empty();
let left_had_specialization_symbols = !needed_specializations_of_left.is_empty();
for ((_, layout), (specialized_var, specialized_sym)) in to_update.into_iter() {
for ((_, layout), (specialized_var, specialized_sym)) in
needed_specializations_of_left.into_iter()
{
let old_specialized_sym = procs
.needed_symbol_specializations
.insert((right, layout), (specialized_var, specialized_sym));
@ -6867,6 +6963,7 @@ where
}
}
let mut result = result;
if left_had_specialization_symbols {
// If the symbol is specialized, only the specializations need to be updated.
for (old_specialized_sym, specialized_sym) in
@ -7894,6 +7991,7 @@ pub struct WhenBranch<'a> {
#[allow(clippy::type_complexity)]
fn from_can_pattern<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
can_pattern: &roc_can::pattern::Pattern,
) -> Result<
@ -7904,13 +8002,14 @@ fn from_can_pattern<'a>(
RuntimeError,
> {
let mut assignments = Vec::new_in(env.arena);
let pattern = from_can_pattern_help(env, layout_cache, can_pattern, &mut assignments)?;
let pattern = from_can_pattern_help(env, procs, layout_cache, can_pattern, &mut assignments)?;
Ok((pattern, assignments))
}
fn from_can_pattern_help<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
can_pattern: &roc_can::pattern::Pattern,
assignments: &mut Vec<'a, (Symbol, Variable, roc_can::expr::Expr)>,
@ -8105,7 +8204,13 @@ fn from_can_pattern_help<'a>(
let mut mono_args = Vec::with_capacity_in(arguments.len(), env.arena);
for ((_, loc_pat), layout) in arguments.iter().zip(field_layouts.iter()) {
mono_args.push((
from_can_pattern_help(env, layout_cache, &loc_pat.value, assignments)?,
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
)?,
*layout,
));
}
@ -8183,6 +8288,7 @@ fn from_can_pattern_help<'a>(
mono_args.push((
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
@ -8228,6 +8334,7 @@ fn from_can_pattern_help<'a>(
mono_args.push((
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
@ -8271,6 +8378,7 @@ fn from_can_pattern_help<'a>(
mono_args.push((
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
@ -8344,6 +8452,7 @@ fn from_can_pattern_help<'a>(
mono_args.push((
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
@ -8400,6 +8509,7 @@ fn from_can_pattern_help<'a>(
mono_args.push((
from_can_pattern_help(
env,
procs,
layout_cache,
&loc_pat.value,
assignments,
@ -8430,8 +8540,13 @@ fn from_can_pattern_help<'a>(
let arg_layout = layout_cache
.from_var(env.arena, *arg_var, env.subs)
.unwrap();
let mono_arg_pattern =
from_can_pattern_help(env, layout_cache, &loc_arg_pattern.value, assignments)?;
let mono_arg_pattern = from_can_pattern_help(
env,
procs,
layout_cache,
&loc_arg_pattern.value,
assignments,
)?;
Ok(Pattern::OpaqueUnwrap {
opaque: *opaque,
argument: Box::new((mono_arg_pattern, arg_layout)),
@ -8474,6 +8589,7 @@ fn from_can_pattern_help<'a>(
// this field is destructured by the pattern
mono_destructs.push(from_can_record_destruct(
env,
procs,
layout_cache,
&destruct.value,
field_layout,
@ -8565,6 +8681,7 @@ fn from_can_pattern_help<'a>(
fn from_can_record_destruct<'a>(
env: &mut Env<'a, '_>,
procs: &mut Procs<'a>,
layout_cache: &mut LayoutCache<'a>,
can_rd: &roc_can::pattern::RecordDestruct,
field_layout: Layout<'a>,
@ -8581,7 +8698,7 @@ fn from_can_record_destruct<'a>(
DestructType::Required(can_rd.symbol)
}
roc_can::pattern::DestructType::Guard(_, loc_pattern) => DestructType::Guard(
from_can_pattern_help(env, layout_cache, &loc_pattern.value, assignments)?,
from_can_pattern_help(env, procs, layout_cache, &loc_pattern.value, assignments)?,
),
},
})