diff --git a/compiler/can/src/annotation.rs b/compiler/can/src/annotation.rs index 09395d3e3f..72f2f89f8f 100644 --- a/compiler/can/src/annotation.rs +++ b/compiler/can/src/annotation.rs @@ -32,6 +32,20 @@ impl<'a> NamedOrAbleVariable<'a> { NamedOrAbleVariable::Able(av) => av.first_seen, } } + + pub fn name(&self) -> &Lowercase { + match self { + NamedOrAbleVariable::Named(nv) => &nv.name, + NamedOrAbleVariable::Able(av) => &av.name, + } + } + + pub fn variable(&self) -> Variable { + match self { + NamedOrAbleVariable::Named(nv) => nv.variable, + NamedOrAbleVariable::Able(av) => av.variable, + } + } } /// A named type variable, not bound to an ability. @@ -148,19 +162,13 @@ impl IntroducedVariables { .map(|(_, var)| var) } + pub fn iter_named(&self) -> impl Iterator { + (self.named.iter().map(NamedOrAbleVariable::Named)) + .chain(self.able.iter().map(NamedOrAbleVariable::Able)) + } + pub fn named_var_by_name(&self, name: &Lowercase) -> Option { - if let Some(nav) = self - .named - .iter() - .find(|nv| &nv.name == name) - .map(NamedOrAbleVariable::Named) - { - return Some(nav); - } - self.able - .iter() - .find(|av| &av.name == name) - .map(NamedOrAbleVariable::Able) + self.iter_named().find(|v| v.name() == name) } pub fn collect_able(&self) -> Vec { diff --git a/compiler/constrain/src/expr.rs b/compiler/constrain/src/expr.rs index c9b41c1da8..1b820bf762 100644 --- a/compiler/constrain/src/expr.rs +++ b/compiler/constrain/src/expr.rs @@ -1684,18 +1684,18 @@ fn instantiate_rigids( let mut new_rigid_variables: Vec = Vec::new(); let mut rigid_substitution: MutMap = MutMap::default(); - for named in introduced_vars.named.iter() { + for named in introduced_vars.iter_named() { use std::collections::hash_map::Entry::*; - match ftv.entry(named.name.clone()) { + match ftv.entry(named.name().clone()) { Occupied(occupied) => { let existing_rigid = occupied.get(); - rigid_substitution.insert(named.variable, *existing_rigid); + rigid_substitution.insert(named.variable(), *existing_rigid); } Vacant(vacant) => { // It's possible to use this rigid in nested defs - vacant.insert(named.variable); - new_rigid_variables.push(named.variable); + vacant.insert(named.variable()); + new_rigid_variables.push(named.variable()); } } } diff --git a/compiler/mono/src/ir.rs b/compiler/mono/src/ir.rs index 5bb784d3c4..099f5f8dfc 100644 --- a/compiler/mono/src/ir.rs +++ b/compiler/mono/src/ir.rs @@ -4750,6 +4750,7 @@ fn get_specialization<'a>( symbol: Symbol, ) -> Option { use roc_solve::ability::type_implementing_member; + use roc_solve::solve::instantiate_rigids; use roc_unify::unify::unify; match env.abilities_store.member_def(symbol) { @@ -4759,6 +4760,7 @@ fn get_specialization<'a>( } Some(member) => { let snapshot = env.subs.snapshot(); + instantiate_rigids(env.subs, member.signature_var); let (_, must_implement_ability) = unify( env.subs, symbol_var, diff --git a/compiler/test_gen/src/gen_abilities.rs b/compiler/test_gen/src/gen_abilities.rs index 384ba85b36..0fa3f804ab 100644 --- a/compiler/test_gen/src/gen_abilities.rs +++ b/compiler/test_gen/src/gen_abilities.rs @@ -88,6 +88,31 @@ fn alias_member_specialization() { #[test] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn ability_constrained_in_non_member_usage() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ result ] to "./platform" + + Hash has + hash : a -> U64 | a has Hash + + mulHashes : a, a -> U64 | a has Hash + mulHashes = \x, y -> hash x * hash y + + Id := U64 + hash = \$Id n -> n + + result = mulHashes ($Id 5) ($Id 7) + "# + ), + 35, + u64 + ) +} + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] +fn ability_constrained_in_non_member_usage_inferred() { assert_evals_to!( indoc!( r#" @@ -112,6 +137,34 @@ fn ability_constrained_in_non_member_usage() { #[test] #[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] fn ability_constrained_in_non_member_multiple_specializations() { + assert_evals_to!( + indoc!( + r#" + app "test" provides [ result ] to "./platform" + + Hash has + hash : a -> U64 | a has Hash + + mulHashes : a, b -> U64 | a has Hash, b has Hash + mulHashes = \x, y -> hash x * hash y + + Id := U64 + hash = \$Id n -> n + + Three := {} + hash = \$Three _ -> 3 + + result = mulHashes ($Id 100) ($Three {}) + "# + ), + 300, + u64 + ) +} + +#[test] +#[cfg(any(feature = "gen-llvm", feature = "gen-wasm"))] +fn ability_constrained_in_non_member_multiple_specializations_inferred() { assert_evals_to!( indoc!( r#"