diff --git a/ast/src/solve_type.rs b/ast/src/solve_type.rs index 60ce284cd0..9c87d37801 100644 --- a/ast/src/solve_type.rs +++ b/ast/src/solve_type.rs @@ -231,6 +231,7 @@ fn solve<'a>( Success { vars, must_implement_ability: _, + lambda_sets_to_specialize: _, // TODO ignored } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -328,6 +329,7 @@ fn solve<'a>( Success { vars, must_implement_ability: _, + lambda_sets_to_specialize: _, // TODO ignored } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -403,6 +405,7 @@ fn solve<'a>( Success { vars, must_implement_ability: _, + lambda_sets_to_specialize: _, // TODO ignored } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); @@ -715,6 +718,7 @@ fn solve<'a>( Success { vars, must_implement_ability: _, + lambda_sets_to_specialize: _, // TODO ignored } => { // TODO(abilities) record deferred ability checks introduce(subs, rank, pools, &vars); diff --git a/compiler/can/src/abilities.rs b/compiler/can/src/abilities.rs index 980222beba..732a46d41b 100644 --- a/compiler/can/src/abilities.rs +++ b/compiler/can/src/abilities.rs @@ -50,10 +50,29 @@ impl AbilityMemberData { pub type SolvedSpecializations = VecMap<(Symbol, Symbol), MemberSpecialization>; /// A particular specialization of an ability member. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, Clone)] pub struct MemberSpecialization { pub symbol: Symbol, - pub region: Region, + + /// Solved lambda sets for an ability member specialization. For example, if we have + /// + /// Default has default : {} -[[] + a:default:1]-> a | a has Default + /// + /// A := {} + /// default = \{} -[[closA]]-> @A {} + /// + /// and this [MemberSpecialization] is for `A`, then there is a mapping of + /// `1` to the variable representing `[[closA]]`. + specialization_lambda_sets: VecMap, +} + +impl MemberSpecialization { + pub fn new(symbol: Symbol, specialization_lambda_sets: VecMap) -> Self { + Self { + symbol, + specialization_lambda_sets, + } + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -199,13 +218,13 @@ impl AbilitiesStore { /// "ability member" has a "specialization" for type "type". pub fn iter_specializations( &self, - ) -> impl Iterator + '_ { - self.declared_specializations.iter().map(|(k, v)| (*k, *v)) + ) -> impl Iterator + '_ { + self.declared_specializations.iter().map(|(k, v)| (*k, v)) } /// Retrieves the specialization of `member` for `typ`, if it exists. - pub fn get_specialization(&self, member: Symbol, typ: Symbol) -> Option { - self.declared_specializations.get(&(member, typ)).copied() + pub fn get_specialization(&self, member: Symbol, typ: Symbol) -> Option<&MemberSpecialization> { + self.declared_specializations.get(&(member, typ)) } pub fn members_of_ability(&self, ability: Symbol) -> Option<&[Symbol]> { @@ -298,9 +317,9 @@ impl AbilitiesStore { declared_specializations .iter() .filter(|((member, _), _)| members.contains(member)) - .for_each(|(&(member, typ), &specialization)| { + .for_each(|(&(member, typ), specialization)| { new.register_specializing_symbol(specialization.symbol, member); - new.register_specialization_for_type(member, typ, specialization); + new.register_specialization_for_type(member, typ, specialization.clone()); }); } @@ -338,9 +357,10 @@ impl AbilitiesStore { for ((member, typ), specialization) in declared_specializations.into_iter() { let old_specialization = self .declared_specializations - .insert((member, typ), specialization); + .insert((member, typ), specialization.clone()); debug_assert!( - old_specialization.is_none() || old_specialization.unwrap() == specialization + old_specialization.is_none() + || old_specialization.unwrap().symbol == specialization.symbol ); } @@ -360,4 +380,15 @@ impl AbilitiesStore { _ => internal_error!("{:?} is not imported!", member), } } + + pub fn get_specializaton_lambda_set( + &self, + opaque: Symbol, + ability_member: Symbol, + region: u8, + ) -> Option { + self.get_specialization(ability_member, opaque) + .and_then(|spec| spec.specialization_lambda_sets.get(®ion)) + .copied() + } } diff --git a/compiler/load_internal/src/file.rs b/compiler/load_internal/src/file.rs index 53b538434f..1c2d52b066 100644 --- a/compiler/load_internal/src/file.rs +++ b/compiler/load_internal/src/file.rs @@ -3528,7 +3528,11 @@ impl<'a> BuildTask<'a> { .. } = exposed; for ((member, typ), specialization) in solved_specializations.iter() { - abilities_store.register_specialization_for_type(*member, *typ, *specialization); + abilities_store.register_specialization_for_type( + *member, + *typ, + specialization.clone(), + ); } } @@ -3685,6 +3689,7 @@ fn run_solve_solve( // module. member.module_id() == module_id || typ.module_id() == module_id }) + .map(|(key, specialization)| (key, specialization.clone())) .collect(); let is_specialization_symbol = diff --git a/compiler/solve/src/ability.rs b/compiler/solve/src/ability.rs index 7754876ffb..a4b1349dbc 100644 --- a/compiler/solve/src/ability.rs +++ b/compiler/solve/src/ability.rs @@ -625,6 +625,7 @@ pub fn resolve_ability_specialization( .member_def(ability_member) .expect("Not an ability member symbol"); + // Figure out the ability we're resolving in a temporary subs snapshot. let snapshot = subs.snapshot(); let signature_var = member_def @@ -632,8 +633,8 @@ pub fn resolve_ability_specialization( .unwrap_or_else(|| internal_error!("Signature var not resolved for {:?}", ability_member)); instantiate_rigids(subs, signature_var); - let (_, must_implement_ability) = unify(subs, specialization_var, signature_var, Mode::EQ) - .expect_success( + let (_vars, must_implement_ability, _lambda_sets_to_specialize) = + unify(subs, specialization_var, signature_var, Mode::EQ).expect_success( "If resolving a specialization, the specialization must be known to typecheck.", ); diff --git a/compiler/solve/src/solve.rs b/compiler/solve/src/solve.rs index 003b292a7c..692aa5c13b 100644 --- a/compiler/solve/src/solve.rs +++ b/compiler/solve/src/solve.rs @@ -9,6 +9,7 @@ use roc_can::constraint::{Constraints, Cycle, LetConstraint, OpportunisticResolv use roc_can::expected::{Expected, PExpected}; use roc_can::expr::PendingDerives; use roc_collections::all::MutMap; +use roc_collections::{VecMap, VecSet}; use roc_debug_flags::dbg_do; #[cfg(debug_assertions)] use roc_debug_flags::ROC_VERIFY_RIGID_LET_GENERALIZED; @@ -19,9 +20,9 @@ use roc_problem::can::CycleEntry; use roc_region::all::{Loc, Region}; use roc_types::solved_types::Solved; use roc_types::subs::{ - self, AliasVariables, Content, Descriptor, FlatType, GetSubsSlice, Mark, OptVariable, Rank, - RecordFields, Subs, SubsIndex, SubsSlice, UnionLabels, UnionLambdas, UnionTags, Variable, - VariableSubsSlice, + self, AliasVariables, Content, Descriptor, FlatType, GetSubsSlice, LambdaSet, Mark, + OptVariable, Rank, RecordFields, Subs, SubsIndex, SubsSlice, UlsOfVar, UnionLabels, + UnionLambdas, UnionTags, Variable, VariableSubsSlice, }; use roc_types::types::Type::{self, *}; use roc_types::types::{ @@ -527,6 +528,10 @@ fn run_in_place( let pending_derives = PendingDerivesTable::new(subs, aliases, pending_derives); let mut deferred_obligations = DeferredObligations::new(pending_derives); + // Because we don't know what ability specializations are available until the entire module is + // solved, we must wait to solve unspecialized lambda sets then. + let mut deferred_uls_to_resolve = UlsOfVar::default(); + let state = solve( &arena, constraints, @@ -539,6 +544,7 @@ fn run_in_place( constraint, abilities_store, &mut deferred_obligations, + &mut deferred_uls_to_resolve, ); // Now that the module has been solved, we can run through and check all @@ -547,6 +553,8 @@ fn run_in_place( let (obligation_problems, _derived) = deferred_obligations.check_all(subs, abilities_store); problems.extend(obligation_problems); + compact_deferred_lambda_sets(subs, &mut pools, abilities_store, deferred_uls_to_resolve); + state.env } @@ -599,6 +607,7 @@ fn solve( constraint: &Constraint, abilities_store: &mut AbilitiesStore, deferred_obligations: &mut DeferredObligations, + deferred_uls_to_resolve: &mut UlsOfVar, ) -> State { let initial = Work::Constraint { env: &Env::default(), @@ -659,6 +668,7 @@ fn solve( abilities_store, problems, deferred_obligations, + deferred_uls_to_resolve, *symbol, *loc_var, ); @@ -764,6 +774,7 @@ fn solve( abilities_store, problems, deferred_obligations, + deferred_uls_to_resolve, *symbol, *loc_var, ); @@ -815,6 +826,7 @@ fn solve( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -823,6 +835,7 @@ fn solve( AbilityImplError::BadExpr(*region, category.clone(), actual), ); } + deferred_uls_to_resolve.union(lambda_sets_to_specialize); state } @@ -867,9 +880,12 @@ fn solve( vars, // ERROR NOT REPORTED must_implement_ability: _, + lambda_sets_to_specialize, } => { introduce(subs, rank, pools, &vars); + deferred_uls_to_resolve.union(lambda_sets_to_specialize); + state } Failure(vars, _actual_type, _expected_type, _bad_impls) => { @@ -922,6 +938,7 @@ fn solve( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -934,6 +951,7 @@ fn solve( ), ); } + deferred_uls_to_resolve.union(lambda_sets_to_specialize); state } @@ -999,6 +1017,7 @@ fn solve( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -1007,6 +1026,7 @@ fn solve( AbilityImplError::BadPattern(*region, category.clone(), actual), ); } + deferred_uls_to_resolve.union(lambda_sets_to_specialize); state } @@ -1161,6 +1181,7 @@ fn solve( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { introduce(subs, rank, pools, &vars); if !must_implement_ability.is_empty() { @@ -1173,6 +1194,7 @@ fn solve( ), ); } + deferred_uls_to_resolve.union(lambda_sets_to_specialize); state } @@ -1265,6 +1287,7 @@ fn solve( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { subs.commit_snapshot(snapshot); @@ -1273,6 +1296,8 @@ fn solve( internal_error!("Didn't expect ability vars to land here"); } + deferred_uls_to_resolve.union(lambda_sets_to_specialize); + // Case 1: unify error types, but don't check exhaustiveness. // Case 2: run exhaustiveness to check for redundant branches. should_check_exhaustiveness = !already_have_error; @@ -1493,16 +1518,17 @@ fn check_ability_specialization( abilities_store: &mut AbilitiesStore, problems: &mut Vec, deferred_obligations: &mut DeferredObligations, + deferred_uls_to_resolve: &mut UlsOfVar, symbol: Symbol, symbol_loc_var: Loc, ) { // If the symbol specializes an ability member, we need to make sure that the // inferred type for the specialization actually aligns with the expected // implementation. - if let Some((root_symbol, root_data)) = abilities_store.root_name_and_def(symbol) { - let root_signature_var = root_data - .signature_var() - .unwrap_or_else(|| internal_error!("Signature var not resolved for {:?}", root_symbol)); + if let Some((ability_member, root_data)) = abilities_store.root_name_and_def(symbol) { + let root_signature_var = root_data.signature_var().unwrap_or_else(|| { + internal_error!("Signature var not resolved for {:?}", ability_member) + }); let parent_ability = root_data.parent_ability; // Check if they unify - if they don't, then the claimed specialization isn't really one, @@ -1521,6 +1547,7 @@ fn check_ability_specialization( Success { vars, must_implement_ability, + lambda_sets_to_specialize, } => { let specialization_type = type_implementing_specialization(&must_implement_ability, parent_ability); @@ -1532,13 +1559,20 @@ fn check_ability_specialization( subs.commit_snapshot(snapshot); introduce(subs, rank, pools, &vars); + let (other_lambda_sets_to_specialize, specialization_lambda_sets) = + find_specializaton_lambda_sets( + subs, + opaque, + ability_member, + lambda_sets_to_specialize, + ); + deferred_uls_to_resolve.union(other_lambda_sets_to_specialize); + let specialization_region = symbol_loc_var.region; - let specialization = MemberSpecialization { - symbol, - region: specialization_region, - }; + let specialization = + MemberSpecialization::new(symbol, specialization_lambda_sets); abilities_store.register_specialization_for_type( - root_symbol, + ability_member, opaque, specialization, ); @@ -1568,7 +1602,7 @@ fn check_ability_specialization( region: symbol_loc_var.region, typ, ability: parent_ability, - member: root_symbol, + member: ability_member, }; problems.push(problem); @@ -1586,13 +1620,13 @@ fn check_ability_specialization( let (actual_type, _problems) = subs.var_to_error_type(symbol_loc_var.value); let reason = Reason::GeneralizedAbilityMemberSpecialization { - member_name: root_symbol, + member_name: ability_member, def_region: root_data.region, }; let problem = TypeError::BadExpr( symbol_loc_var.region, - Category::AbilityMemberSpecialization(root_symbol), + Category::AbilityMemberSpecialization(ability_member), actual_type, Expected::ForReason(reason, expected_type, symbol_loc_var.region), ); @@ -1607,14 +1641,14 @@ fn check_ability_specialization( introduce(subs, rank, pools, &vars); let reason = Reason::InvalidAbilityMemberSpecialization { - member_name: root_symbol, + member_name: ability_member, def_region: root_data.region, unimplemented_abilities, }; let problem = TypeError::BadExpr( symbol_loc_var.region, - Category::AbilityMemberSpecialization(root_symbol), + Category::AbilityMemberSpecialization(ability_member), actual_type, Expected::ForReason(reason, expected_type, symbol_loc_var.region), ); @@ -1631,6 +1665,206 @@ fn check_ability_specialization( } } +/// Finds the lambda sets in an ability member specialization. +/// +/// Suppose we have +/// +/// Default has default : {} -[[] + a:default:1]-> a | a has Default +/// +/// A := {} +/// default = \{} -[[closA]]-> @A {} +/// +/// Now after solving the `default` specialization we have unified it with the ability signature, +/// yielding +/// +/// {} -[[closA] + A:default:1]-> A +/// +/// But really, what we want is to only keep around the original lambda sets, and associate +/// `A:default:1` to resolve to the lambda set `[[closA]]`. There might be other unspecialized lambda +/// sets in the lambda sets for this implementation, which we need to account for as well; that is, +/// it may really be `[[closA] + v123:otherAbilityMember:4 + ...]`. +#[inline(always)] +fn find_specializaton_lambda_sets( + subs: &mut Subs, + opaque: Symbol, + ability_member: Symbol, + uls: UlsOfVar, +) -> (UlsOfVar, VecMap) { + // unspecialized lambda sets that don't belong to our specialization, and should be resolved + // later. + let mut leftover_uls = UlsOfVar::default(); + let mut specialization_lambda_sets: VecMap = VecMap::default(); + + for (spec_var, lambda_sets) in uls.drain() { + if !matches!(subs.get_content_without_compacting(spec_var), Content::Alias(name, _, _, AliasKind::Opaque) if *name == opaque) + { + // These lambda sets aren't resolved to the current specialization, they need to be + // solved at a later time. + leftover_uls.extend(spec_var, lambda_sets); + continue; + } + + for lambda_set in lambda_sets { + let &LambdaSet { + solved, + recursion_var, + unspecialized, + } = match subs.get_content_without_compacting(lambda_set) { + Content::LambdaSet(lambda_set) => lambda_set, + _ => internal_error!("Not a lambda set"), + }; + + // Figure out the unspecailized lambda set that corresponds to our specialization + // (`A:default:1` in the example), and those that need to stay part of the lambda set. + let mut split_index_and_region = None; + let uls_slice = subs.get_subs_slice(unspecialized).to_owned(); + for (i, &Uls(var, _sym, region)) in uls_slice.iter().enumerate() { + if var == spec_var { + debug_assert!(split_index_and_region.is_none()); + debug_assert!(_sym == ability_member, "unspecialized lambda set var is the same as the specialization, but points to a different ability member"); + split_index_and_region = Some((i, region)); + } + } + + let (split_index, specialized_lset_region) = + split_index_and_region.expect("no unspecialization lambda set found"); + let (uls_before, uls_after) = + (&uls_slice[0..split_index], &uls_slice[split_index + 1..]); + let new_unspecialized = SubsSlice::extend_new( + &mut subs.unspecialized_lambda_sets, + uls_before.into_iter().chain(uls_after.into_iter()).copied(), + ); + + subs.set_content( + lambda_set, + Content::LambdaSet(LambdaSet { + solved, + recursion_var, + unspecialized: new_unspecialized, + }), + ); + + let old_specialized = + specialization_lambda_sets.insert(specialized_lset_region, lambda_set); + debug_assert!( + old_specialized.is_none(), + "Specialization of lambda set already exists" + ); + } + } + + (leftover_uls, specialization_lambda_sets) +} + +fn compact_deferred_lambda_sets( + subs: &mut Subs, + pools: &mut Pools, + abilities_store: &AbilitiesStore, + uls_of_var: UlsOfVar, +) { + let mut seen = VecSet::default(); + for (_, lambda_sets) in uls_of_var.drain() { + for lset in lambda_sets { + let root_lset = subs.get_root_key_without_compacting(lset); + if seen.contains(&root_lset) { + continue; + } + + compact_lambda_set(subs, pools, abilities_store, root_lset); + seen.insert(root_lset); + } + } +} + +fn compact_lambda_set( + subs: &mut Subs, + pools: &mut Pools, + abilities_store: &AbilitiesStore, + lambda_set: Variable, +) { + let LambdaSet { + solved, + recursion_var, + unspecialized, + } = subs.get_lambda_set(lambda_set); + + if unspecialized.is_empty() { + return; + } + + let mut new_unspecialized = vec![]; + let mut specialized_to_unify_with = Vec::with_capacity(1); + for uls_index in unspecialized.into_iter() { + let uls @ Uls(var, member, region) = subs[uls_index]; + + use Content::*; + let opaque = match subs.get_content_without_compacting(var) { + FlexAbleVar(_, _) => { + /* not specialized yet */ + new_unspecialized.push(uls); + continue; + } + Structure(_) | Alias(_, _, _, AliasKind::Structural) => { + // TODO: figure out a convention for references to structural types in the + // unspecialized lambda set. This may very well happen, for example + // + // Default has default : {} -> a | a has Default + // + // {a, b} = default {} + // # ^^^^^^^ {} -[{a: t1, b: t2}:default:1] + continue; + } + Alias(opaque, _, _, AliasKind::Opaque) => opaque, + Error => { + /* skip */ + continue; + } + RigidVar(..) + | RigidAbleVar(..) + | FlexVar(..) + | RecursionVar { .. } + | LambdaSet(..) + | RangedNumber(_, _) => { + internal_error!("unexpected") + } + }; + + let specialized_lambda_set = abilities_store + .get_specializaton_lambda_set(*opaque, member, region) + .expect("lambda set not resolved, or opaque doesn't specialize"); + + compact_lambda_set(subs, pools, abilities_store, specialized_lambda_set); + + specialized_to_unify_with.push(specialized_lambda_set); + } + + let new_unspecialized_slice = + SubsSlice::extend_new(&mut subs.unspecialized_lambda_sets, new_unspecialized); + let partial_compacted_lambda_set = Content::LambdaSet(LambdaSet { + solved, + recursion_var, + unspecialized: new_unspecialized_slice, + }); + subs.set_content(lambda_set, partial_compacted_lambda_set); + + for other_specialized in specialized_to_unify_with.into_iter() { + let (vars, must_implement_ability, lambda_sets_to_specialize) = + unify(subs, lambda_set, other_specialized, Mode::EQ) + .expect_success("lambda sets don't unify"); + + introduce(subs, subs.get_rank(lambda_set), pools, &vars); + + debug_assert!( + must_implement_ability.is_empty(), + "didn't expect abilities instantiated in this position" + ); + debug_assert!( + lambda_sets_to_specialize.is_empty(), + "didn't expect more lambda sets in this position" + ); + } +} + #[derive(Debug)] enum LocalDefVarsVec { Stack(arrayvec::ArrayVec), diff --git a/compiler/types/src/subs.rs b/compiler/types/src/subs.rs index ab012f130a..df83cf3313 100644 --- a/compiler/types/src/subs.rs +++ b/compiler/types/src/subs.rs @@ -310,21 +310,49 @@ impl Subs { /// Mapping of variables to [Content::LambdaSet]s containing unspecialized lambda sets depending on /// that variable. -#[derive(Clone, Default)] +#[derive(Clone, Default, Debug)] pub struct UlsOfVar(VecMap>); impl UlsOfVar { pub fn add(&mut self, var: Variable, dependent_lambda_set: Variable) -> bool { + // TODO: should we be checking root key here? let set = self.0.get_or_insert(var, Default::default); set.insert(dependent_lambda_set) } + pub fn extend( + &mut self, + var: Variable, + dependent_lambda_sets: impl IntoIterator, + ) { + // TODO: should we be checking root key here? + let set = self.0.get_or_insert(var, Default::default); + set.extend(dependent_lambda_sets); + } + + pub fn union(&mut self, other: Self) { + for (key, lset) in other.drain() { + self.extend(key, lset); + } + } + pub fn remove_dependents( &mut self, var: Variable, ) -> Option> { + // TODO: should we be checking root key here? self.0.remove(&var).map(|(_, v)| v) } + + pub fn drain(self) -> impl Iterator)> { + self.0 + .into_iter() + .map(|(v, set): (Variable, VecSet)| (v, set.into_iter())) + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } } #[derive(Clone)] @@ -1956,6 +1984,13 @@ impl Subs { pub fn vars_since_snapshot(&mut self, snapshot: &Snapshot) -> core::ops::Range { self.utable.vars_since_snapshot(snapshot) } + + pub fn get_lambda_set(&self, lambda_set: Variable) -> LambdaSet { + match self.get_content_without_compacting(lambda_set) { + Content::LambdaSet(lambda_set) => *lambda_set, + _ => internal_error!("not a lambda set"), + } + } } #[inline(always)] diff --git a/compiler/unify/src/unify.rs b/compiler/unify/src/unify.rs index 96addbca58..5306f3fcac 100644 --- a/compiler/unify/src/unify.rs +++ b/compiler/unify/src/unify.rs @@ -9,8 +9,8 @@ use roc_types::num::NumericRange; use roc_types::subs::Content::{self, *}; use roc_types::subs::{ AliasVariables, Descriptor, ErrorTypeContext, FlatType, GetSubsSlice, LambdaSet, Mark, - OptVariable, RecordFields, Subs, SubsIndex, SubsSlice, UnionLabels, UnionLambdas, UnionTags, - Variable, VariableSubsSlice, + OptVariable, RecordFields, Subs, SubsIndex, SubsSlice, UlsOfVar, UnionLabels, UnionLambdas, + UnionTags, Variable, VariableSubsSlice, }; use roc_types::types::{AliasKind, DoesNotImplementAbility, ErrorType, Mismatch, RecordField}; @@ -143,18 +143,23 @@ pub enum Unified { Success { vars: Pool, must_implement_ability: MustImplementConstraints, + lambda_sets_to_specialize: UlsOfVar, }, Failure(Pool, ErrorType, ErrorType, DoesNotImplementAbility), BadType(Pool, roc_types::types::Problem), } impl Unified { - pub fn expect_success(self, err_msg: &'static str) -> (Pool, MustImplementConstraints) { + pub fn expect_success( + self, + err_msg: &'static str, + ) -> (Pool, MustImplementConstraints, UlsOfVar) { match self { Unified::Success { vars, must_implement_ability, - } => (vars, must_implement_ability), + lambda_sets_to_specialize, + } => (vars, must_implement_ability, lambda_sets_to_specialize), _ => internal_error!("{}", err_msg), } } @@ -212,6 +217,9 @@ pub struct Outcome { /// We defer these checks until the end of a solving phase. /// NOTE: this vector is almost always empty! must_implement_ability: MustImplementConstraints, + /// We defer resolution of these lambda sets to the caller of [unify]. + /// See also [merge_flex_able_with_concrete]. + lambda_sets_to_specialize: UlsOfVar, } impl Outcome { @@ -219,6 +227,8 @@ impl Outcome { self.mismatches.extend(other.mismatches); self.must_implement_ability .extend(other.must_implement_ability); + self.lambda_sets_to_specialize + .union(other.lambda_sets_to_specialize); } } @@ -228,12 +238,14 @@ pub fn unify(subs: &mut Subs, var1: Variable, var2: Variable, mode: Mode) -> Uni let Outcome { mismatches, must_implement_ability, + lambda_sets_to_specialize, } = unify_pool(subs, &mut vars, var1, var2, mode); if mismatches.is_empty() { Unified::Success { vars, must_implement_ability, + lambda_sets_to_specialize, } } else { let error_context = if mismatches.contains(&Mismatch::TypeNotInRange) { @@ -450,6 +462,7 @@ fn check_valid_range(subs: &mut Subs, var: Variable, range: NumericRange) -> Out let outcome = Outcome { mismatches: vec![Mismatch::TypeNotInRange], must_implement_ability: Default::default(), + lambda_sets_to_specialize: Default::default(), }; return outcome; @@ -595,12 +608,14 @@ fn unify_opaque( } FlexAbleVar(_, ability) if args.is_empty() => { // Opaque type wins - let mut outcome = merge(subs, ctx, Alias(symbol, args, real_var, kind)); - outcome.must_implement_ability.push(MustImplementAbility { - typ: Obligated::Opaque(symbol), - ability: *ability, - }); - outcome + merge_flex_able_with_concrete( + subs, + ctx, + ctx.second, + *ability, + Alias(symbol, args, real_var, kind), + Obligated::Opaque(symbol), + ) } Alias(_, _, other_real_var, AliasKind::Structural) => { unify_pool(subs, pool, ctx.first, *other_real_var, ctx.mode) @@ -673,13 +688,15 @@ fn unify_structure( outcome } FlexAbleVar(_, ability) => { - let mut outcome = merge(subs, ctx, Structure(*flat_type)); - let must_implement_ability = MustImplementAbility { - typ: Obligated::Adhoc(ctx.first), - ability: *ability, - }; - outcome.must_implement_ability.push(must_implement_ability); - outcome + // Structure wins + merge_flex_able_with_concrete( + subs, + ctx, + ctx.second, + *ability, + Structure(*flat_type), + Obligated::Adhoc(ctx.first), + ) } // _name has an underscore because it's unused in --release builds RigidVar(_name) => { @@ -2028,6 +2045,7 @@ fn unify_flex( merge(subs, ctx, FlexAbleVar(opt_name, *ability)) } + // TODO: not accounting for ability bound here! RigidVar(_) | RigidAbleVar(_, _) | RecursionVar { .. } @@ -2093,12 +2111,14 @@ fn unify_flex_able( Alias(name, args, _real_var, AliasKind::Opaque) => { if args.is_empty() { // Opaque type wins - let mut outcome = merge(subs, ctx, *other); - outcome.must_implement_ability.push(MustImplementAbility { - typ: Obligated::Opaque(*name), + merge_flex_able_with_concrete( + subs, + ctx, + ctx.first, ability, - }); - outcome + *other, + Obligated::Opaque(*name), + ) } else { mismatch!("FlexAble vs Opaque with type vars") } @@ -2106,18 +2126,51 @@ fn unify_flex_able( Structure(_) | Alias(_, _, _, AliasKind::Structural) | RangedNumber(..) => { // Structural type wins. - let mut outcome = merge(subs, ctx, *other); - outcome.must_implement_ability.push(MustImplementAbility { - typ: Obligated::Adhoc(ctx.second), + merge_flex_able_with_concrete( + subs, + ctx, + ctx.first, ability, - }); - outcome + *other, + Obligated::Adhoc(ctx.second), + ) } Error => merge(subs, ctx, Error), } } +fn merge_flex_able_with_concrete( + subs: &mut Subs, + ctx: &Context, + flex_able_var: Variable, + ability: Symbol, + concrete_content: Content, + concrete_obligation: Obligated, +) -> Outcome { + let mut outcome = merge(subs, ctx, concrete_content); + let must_implement_ability = MustImplementAbility { + typ: concrete_obligation, + ability, + }; + outcome.must_implement_ability.push(must_implement_ability); + + // Figure which, if any, lambda sets should be specialized thanks to the flex able var + // being instantiated. Now as much as I would love to do that here, we don't, because we might + // be in the middle of solving a module and not resolved all available ability implementations + // yet! Instead we chuck it up in the [Outcome] and let our caller do the resolution. + // + // If we ever organize ability implementations so that they are well-known before any other + // unification is done, they can be solved in-band here! + if let Some(uls_of_concrete) = subs.uls_of_var.remove_dependents(flex_able_var) { + outcome + .lambda_sets_to_specialize + .extend(flex_able_var, uls_of_concrete); + } + + outcome +} + #[inline(always)] fn unify_recursion( subs: &mut Subs,