mirror of
https://github.com/roc-lang/roc.git
synced 2024-11-10 10:02:38 +03:00
Switch obligation checking to use a visitor
My hope is this will make obligation checking for other abilities easier to add when we do so
This commit is contained in:
parent
fa14146054
commit
a7bc8cf4f2
@ -5,6 +5,7 @@ use roc_error_macros::internal_error;
|
||||
use roc_module::symbol::Symbol;
|
||||
use roc_region::all::{Loc, Region};
|
||||
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
|
||||
use roc_types::num::NumericRange;
|
||||
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
|
||||
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
|
||||
use roc_unify::unify::{Env, MustImplementConstraints};
|
||||
@ -253,7 +254,12 @@ impl ObligationCache {
|
||||
// independent queries.
|
||||
|
||||
let opt_can_derive_builtin = match ability {
|
||||
Symbol::ENCODE_ENCODING => Some(self.can_derive_encoding(subs, abilities_store, var)),
|
||||
Symbol::ENCODE_ENCODING => Some(DeriveEncoding::is_derivable(
|
||||
self,
|
||||
abilities_store,
|
||||
subs,
|
||||
var,
|
||||
)),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
@ -262,7 +268,7 @@ impl ObligationCache {
|
||||
// can derive!
|
||||
None
|
||||
}
|
||||
Some(Err(failure_var)) => Some(if failure_var == var {
|
||||
Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var {
|
||||
UnderivableReason::SurfaceNotDerivable
|
||||
} else {
|
||||
let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
|
||||
@ -391,16 +397,128 @@ impl ObligationCache {
|
||||
let check_has_fake = self.derive_cache.insert(derive_key, root_result);
|
||||
debug_assert_eq!(check_has_fake, Some(fake_fulfilled));
|
||||
}
|
||||
}
|
||||
|
||||
// If we have a lot of these, consider using a visitor.
|
||||
// It will be very similar for most types (can't derive functions, can't derive unbound type
|
||||
// variables, can only derive opaques if they have an impl, etc).
|
||||
fn can_derive_encoding(
|
||||
&mut self,
|
||||
subs: &mut Subs,
|
||||
#[inline(always)]
|
||||
#[rustfmt::skip]
|
||||
fn is_builtin_number_alias(symbol: Symbol) -> bool {
|
||||
matches!(symbol,
|
||||
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
|
||||
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
|
||||
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
|
||||
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
|
||||
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
|
||||
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
|
||||
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
|
||||
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
|
||||
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
|
||||
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
|
||||
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
|
||||
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
|
||||
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
|
||||
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
|
||||
)
|
||||
}
|
||||
|
||||
enum DerivableError {
|
||||
NotDerivable(Variable),
|
||||
}
|
||||
|
||||
struct Descend(bool);
|
||||
|
||||
trait DerivableVisitor {
|
||||
const ABILITY: Symbol;
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_flex(var: Variable) -> Result<(), DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_rigid(var: Variable) -> Result<(), DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
|
||||
if ability != Self::ABILITY {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
|
||||
if ability != Self::ABILITY {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_recursion(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_func(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_record(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_record(var: Variable) -> Result<(), DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn is_derivable(
|
||||
obligation_cache: &mut ObligationCache,
|
||||
abilities_store: &AbilitiesStore,
|
||||
subs: &Subs,
|
||||
var: Variable,
|
||||
) -> Result<(), Variable> {
|
||||
) -> Result<(), DerivableError> {
|
||||
let mut stack = vec![var];
|
||||
let mut seen_recursion_vars = vec![];
|
||||
|
||||
@ -418,102 +536,93 @@ impl ObligationCache {
|
||||
let content = subs.get_content_without_compacting(var);
|
||||
|
||||
use Content::*;
|
||||
use DerivableError::*;
|
||||
use FlatType::*;
|
||||
match content {
|
||||
FlexVar(_) | RigidVar(_) => return Err(var),
|
||||
FlexAbleVar(_, ability) | RigidAbleVar(_, ability) => {
|
||||
if *ability != Symbol::ENCODE_ENCODING {
|
||||
return Err(var);
|
||||
}
|
||||
// Any concrete type this variables is instantiated with will also gain a "does
|
||||
// implement" check so this is okay.
|
||||
}
|
||||
match *content {
|
||||
FlexVar(_) => Self::visit_flex(var)?,
|
||||
RigidVar(_) => Self::visit_rigid(var)?,
|
||||
FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?,
|
||||
RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?,
|
||||
RecursionVar {
|
||||
structure,
|
||||
opt_name: _,
|
||||
} => {
|
||||
seen_recursion_vars.push(var);
|
||||
stack.push(*structure);
|
||||
let descend = Self::visit_recursion(var)?;
|
||||
if descend.0 {
|
||||
seen_recursion_vars.push(var);
|
||||
stack.push(structure);
|
||||
}
|
||||
}
|
||||
Structure(flat_type) => match flat_type {
|
||||
Apply(
|
||||
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
|
||||
vars,
|
||||
) => push_var_slice!(*vars),
|
||||
Apply(..) => return Err(var),
|
||||
Func(..) => {
|
||||
return Err(var);
|
||||
}
|
||||
Record(fields, var) => {
|
||||
push_var_slice!(fields.variables());
|
||||
stack.push(*var);
|
||||
}
|
||||
TagUnion(tags, ext_var) => {
|
||||
for i in tags.variables() {
|
||||
push_var_slice!(subs[i]);
|
||||
Apply(symbol, vars) => {
|
||||
let descend = Self::visit_apply(var, symbol)?;
|
||||
if descend.0 {
|
||||
push_var_slice!(vars)
|
||||
}
|
||||
stack.push(*ext_var);
|
||||
}
|
||||
FunctionOrTagUnion(_, _, var) => stack.push(*var),
|
||||
RecursiveTagUnion(rec_var, tags, ext_var) => {
|
||||
seen_recursion_vars.push(*rec_var);
|
||||
for i in tags.variables() {
|
||||
push_var_slice!(subs[i]);
|
||||
Func(args, _clos, ret) => {
|
||||
let descend = Self::visit_func(var)?;
|
||||
if descend.0 {
|
||||
push_var_slice!(args);
|
||||
stack.push(ret);
|
||||
}
|
||||
stack.push(*ext_var);
|
||||
}
|
||||
EmptyRecord | EmptyTagUnion => {
|
||||
// yes
|
||||
Record(fields, ext) => {
|
||||
let descend = Self::visit_record(var)?;
|
||||
if descend.0 {
|
||||
push_var_slice!(fields.variables());
|
||||
stack.push(ext);
|
||||
}
|
||||
}
|
||||
Erroneous(_) => return Err(var),
|
||||
TagUnion(tags, ext) => {
|
||||
let descend = Self::visit_tag_union(var)?;
|
||||
if descend.0 {
|
||||
for i in tags.variables() {
|
||||
push_var_slice!(subs[i]);
|
||||
}
|
||||
stack.push(ext);
|
||||
}
|
||||
}
|
||||
FunctionOrTagUnion(_tag_name, _fn_name, ext) => {
|
||||
let descend = Self::visit_function_or_tag_union(var)?;
|
||||
if descend.0 {
|
||||
stack.push(ext);
|
||||
}
|
||||
}
|
||||
RecursiveTagUnion(rec, tags, ext) => {
|
||||
let descend = Self::visit_recursive_tag_union(var)?;
|
||||
if descend.0 {
|
||||
seen_recursion_vars.push(rec);
|
||||
for i in tags.variables() {
|
||||
push_var_slice!(subs[i]);
|
||||
}
|
||||
stack.push(ext);
|
||||
}
|
||||
}
|
||||
EmptyRecord => Self::visit_empty_record(var)?,
|
||||
EmptyTagUnion => Self::visit_empty_tag_union(var)?,
|
||||
|
||||
Erroneous(_) => return Err(NotDerivable(var)),
|
||||
},
|
||||
#[rustfmt::skip]
|
||||
Alias(
|
||||
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
|
||||
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
|
||||
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
|
||||
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
|
||||
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
|
||||
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
|
||||
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
|
||||
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
|
||||
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
|
||||
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
|
||||
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
|
||||
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
|
||||
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
|
||||
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) => {
|
||||
// yes
|
||||
}
|
||||
Alias(
|
||||
Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT,
|
||||
_,
|
||||
real_var,
|
||||
_,
|
||||
) => stack.push(*real_var),
|
||||
Alias(name, _, _, AliasKind::Opaque) => {
|
||||
let opaque = *name;
|
||||
if self
|
||||
.check_opaque_and_read(abilities_store, opaque, Symbol::ENCODE_ENCODING)
|
||||
Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
|
||||
if obligation_cache
|
||||
.check_opaque_and_read(abilities_store, opaque, Self::ABILITY)
|
||||
.is_err()
|
||||
{
|
||||
return Err(var);
|
||||
return Err(NotDerivable(var));
|
||||
}
|
||||
}
|
||||
Alias(_, arguments, real_type_var, _) => {
|
||||
push_var_slice!(arguments.all_variables());
|
||||
stack.push(*real_type_var);
|
||||
Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => {
|
||||
let descend = Self::visit_alias(var, symbol)?;
|
||||
if descend.0 {
|
||||
stack.push(real_var);
|
||||
}
|
||||
}
|
||||
RangedNumber(..) => {
|
||||
// yes, all numbers can
|
||||
}
|
||||
LambdaSet(..) => return Err(var),
|
||||
RangedNumber(range) => Self::visit_ranged_number(var, range)?,
|
||||
|
||||
LambdaSet(..) => return Err(NotDerivable(var)),
|
||||
Error => {
|
||||
return Err(var);
|
||||
return Err(NotDerivable(var));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -522,6 +631,66 @@ impl ObligationCache {
|
||||
}
|
||||
}
|
||||
|
||||
struct DeriveEncoding;
|
||||
impl DerivableVisitor for DeriveEncoding {
|
||||
const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||
if matches!(
|
||||
symbol,
|
||||
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
|
||||
) {
|
||||
Ok(Descend(true))
|
||||
} else {
|
||||
Err(DerivableError::NotDerivable(var))
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
|
||||
if is_builtin_number_alias(symbol) {
|
||||
Ok(Descend(false))
|
||||
} else {
|
||||
Ok(Descend(true))
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines what type implements an ability member of a specialized signature, given the
|
||||
/// [MustImplementAbility] constraints of the signature.
|
||||
pub fn type_implementing_specialization(
|
||||
|
Loading…
Reference in New Issue
Block a user