Switch obligation checking to use a visitor

My hope is this will make obligation checking for other abilities easier
to add when we do so
This commit is contained in:
Ayaz Hafiz 2022-08-01 12:52:32 -05:00
parent fa14146054
commit a7bc8cf4f2
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -5,6 +5,7 @@ use roc_error_macros::internal_error;
use roc_module::symbol::Symbol;
use roc_region::all::{Loc, Region};
use roc_solve_problem::{TypeError, UnderivableReason, Unfulfilled};
use roc_types::num::NumericRange;
use roc_types::subs::{instantiate_rigids, Content, FlatType, GetSubsSlice, Rank, Subs, Variable};
use roc_types::types::{AliasKind, Category, MemberImpl, PatternCategory};
use roc_unify::unify::{Env, MustImplementConstraints};
@ -253,7 +254,12 @@ impl ObligationCache {
// independent queries.
let opt_can_derive_builtin = match ability {
Symbol::ENCODE_ENCODING => Some(self.can_derive_encoding(subs, abilities_store, var)),
Symbol::ENCODE_ENCODING => Some(DeriveEncoding::is_derivable(
self,
abilities_store,
subs,
var,
)),
_ => None,
};
@ -262,7 +268,7 @@ impl ObligationCache {
// can derive!
None
}
Some(Err(failure_var)) => Some(if failure_var == var {
Some(Err(DerivableError::NotDerivable(failure_var))) => Some(if failure_var == var {
UnderivableReason::SurfaceNotDerivable
} else {
let (error_type, _skeletons) = subs.var_to_error_type(failure_var);
@ -391,16 +397,128 @@ impl ObligationCache {
let check_has_fake = self.derive_cache.insert(derive_key, root_result);
debug_assert_eq!(check_has_fake, Some(fake_fulfilled));
}
}
// If we have a lot of these, consider using a visitor.
// It will be very similar for most types (can't derive functions, can't derive unbound type
// variables, can only derive opaques if they have an impl, etc).
fn can_derive_encoding(
&mut self,
subs: &mut Subs,
#[inline(always)]
#[rustfmt::skip]
fn is_builtin_number_alias(symbol: Symbol) -> bool {
matches!(symbol,
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
)
}
enum DerivableError {
NotDerivable(Variable),
}
struct Descend(bool);
trait DerivableVisitor {
const ABILITY: Symbol;
#[inline(always)]
fn visit_flex(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_rigid(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_flex_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
} else {
Ok(())
}
}
#[inline(always)]
fn visit_rigid_able(var: Variable, ability: Symbol) -> Result<(), DerivableError> {
if ability != Self::ABILITY {
Err(DerivableError::NotDerivable(var))
} else {
Ok(())
}
}
#[inline(always)]
fn visit_recursion(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_apply(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_func(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_record(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_recursive_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_function_or_tag_union(var: Variable) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_empty_record(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_empty_tag_union(var: Variable) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_alias(var: Variable, _symbol: Symbol) -> Result<Descend, DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Err(DerivableError::NotDerivable(var))
}
#[inline(always)]
fn is_derivable(
obligation_cache: &mut ObligationCache,
abilities_store: &AbilitiesStore,
subs: &Subs,
var: Variable,
) -> Result<(), Variable> {
) -> Result<(), DerivableError> {
let mut stack = vec![var];
let mut seen_recursion_vars = vec![];
@ -418,102 +536,93 @@ impl ObligationCache {
let content = subs.get_content_without_compacting(var);
use Content::*;
use DerivableError::*;
use FlatType::*;
match content {
FlexVar(_) | RigidVar(_) => return Err(var),
FlexAbleVar(_, ability) | RigidAbleVar(_, ability) => {
if *ability != Symbol::ENCODE_ENCODING {
return Err(var);
}
// Any concrete type this variables is instantiated with will also gain a "does
// implement" check so this is okay.
}
match *content {
FlexVar(_) => Self::visit_flex(var)?,
RigidVar(_) => Self::visit_rigid(var)?,
FlexAbleVar(_, ability) => Self::visit_flex_able(var, ability)?,
RigidAbleVar(_, ability) => Self::visit_rigid_able(var, ability)?,
RecursionVar {
structure,
opt_name: _,
} => {
seen_recursion_vars.push(var);
stack.push(*structure);
let descend = Self::visit_recursion(var)?;
if descend.0 {
seen_recursion_vars.push(var);
stack.push(structure);
}
}
Structure(flat_type) => match flat_type {
Apply(
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
vars,
) => push_var_slice!(*vars),
Apply(..) => return Err(var),
Func(..) => {
return Err(var);
}
Record(fields, var) => {
push_var_slice!(fields.variables());
stack.push(*var);
}
TagUnion(tags, ext_var) => {
for i in tags.variables() {
push_var_slice!(subs[i]);
Apply(symbol, vars) => {
let descend = Self::visit_apply(var, symbol)?;
if descend.0 {
push_var_slice!(vars)
}
stack.push(*ext_var);
}
FunctionOrTagUnion(_, _, var) => stack.push(*var),
RecursiveTagUnion(rec_var, tags, ext_var) => {
seen_recursion_vars.push(*rec_var);
for i in tags.variables() {
push_var_slice!(subs[i]);
Func(args, _clos, ret) => {
let descend = Self::visit_func(var)?;
if descend.0 {
push_var_slice!(args);
stack.push(ret);
}
stack.push(*ext_var);
}
EmptyRecord | EmptyTagUnion => {
// yes
Record(fields, ext) => {
let descend = Self::visit_record(var)?;
if descend.0 {
push_var_slice!(fields.variables());
stack.push(ext);
}
}
Erroneous(_) => return Err(var),
TagUnion(tags, ext) => {
let descend = Self::visit_tag_union(var)?;
if descend.0 {
for i in tags.variables() {
push_var_slice!(subs[i]);
}
stack.push(ext);
}
}
FunctionOrTagUnion(_tag_name, _fn_name, ext) => {
let descend = Self::visit_function_or_tag_union(var)?;
if descend.0 {
stack.push(ext);
}
}
RecursiveTagUnion(rec, tags, ext) => {
let descend = Self::visit_recursive_tag_union(var)?;
if descend.0 {
seen_recursion_vars.push(rec);
for i in tags.variables() {
push_var_slice!(subs[i]);
}
stack.push(ext);
}
}
EmptyRecord => Self::visit_empty_record(var)?,
EmptyTagUnion => Self::visit_empty_tag_union(var)?,
Erroneous(_) => return Err(NotDerivable(var)),
},
#[rustfmt::skip]
Alias(
Symbol::NUM_U8 | Symbol::NUM_UNSIGNED8
| Symbol::NUM_U16 | Symbol::NUM_UNSIGNED16
| Symbol::NUM_U32 | Symbol::NUM_UNSIGNED32
| Symbol::NUM_U64 | Symbol::NUM_UNSIGNED64
| Symbol::NUM_U128 | Symbol::NUM_UNSIGNED128
| Symbol::NUM_I8 | Symbol::NUM_SIGNED8
| Symbol::NUM_I16 | Symbol::NUM_SIGNED16
| Symbol::NUM_I32 | Symbol::NUM_SIGNED32
| Symbol::NUM_I64 | Symbol::NUM_SIGNED64
| Symbol::NUM_I128 | Symbol::NUM_SIGNED128
| Symbol::NUM_NAT | Symbol::NUM_NATURAL
| Symbol::NUM_F32 | Symbol::NUM_BINARY32
| Symbol::NUM_F64 | Symbol::NUM_BINARY64
| Symbol::NUM_DEC | Symbol::NUM_DECIMAL,
_,
_,
_,
) => {
// yes
}
Alias(
Symbol::NUM_NUM | Symbol::NUM_INTEGER | Symbol::NUM_FLOATINGPOINT,
_,
real_var,
_,
) => stack.push(*real_var),
Alias(name, _, _, AliasKind::Opaque) => {
let opaque = *name;
if self
.check_opaque_and_read(abilities_store, opaque, Symbol::ENCODE_ENCODING)
Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
if obligation_cache
.check_opaque_and_read(abilities_store, opaque, Self::ABILITY)
.is_err()
{
return Err(var);
return Err(NotDerivable(var));
}
}
Alias(_, arguments, real_type_var, _) => {
push_var_slice!(arguments.all_variables());
stack.push(*real_type_var);
Alias(symbol, _alias_variables, real_var, AliasKind::Structural) => {
let descend = Self::visit_alias(var, symbol)?;
if descend.0 {
stack.push(real_var);
}
}
RangedNumber(..) => {
// yes, all numbers can
}
LambdaSet(..) => return Err(var),
RangedNumber(range) => Self::visit_ranged_number(var, range)?,
LambdaSet(..) => return Err(NotDerivable(var)),
Error => {
return Err(var);
return Err(NotDerivable(var));
}
}
}
@ -522,6 +631,66 @@ impl ObligationCache {
}
}
struct DeriveEncoding;
impl DerivableVisitor for DeriveEncoding {
const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
#[inline(always)]
fn visit_recursion(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_apply(var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if matches!(
symbol,
Symbol::LIST_LIST | Symbol::SET_SET | Symbol::DICT_DICT | Symbol::STR_STR,
) {
Ok(Descend(true))
} else {
Err(DerivableError::NotDerivable(var))
}
}
fn visit_record(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_recursive_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
fn visit_function_or_tag_union(_var: Variable) -> Result<Descend, DerivableError> {
Ok(Descend(true))
}
#[inline(always)]
fn visit_empty_record(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
#[inline(always)]
fn visit_empty_tag_union(_var: Variable) -> Result<(), DerivableError> {
Ok(())
}
fn visit_alias(_var: Variable, symbol: Symbol) -> Result<Descend, DerivableError> {
if is_builtin_number_alias(symbol) {
Ok(Descend(false))
} else {
Ok(Descend(true))
}
}
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), DerivableError> {
Ok(())
}
}
/// Determines what type implements an ability member of a specialized signature, given the
/// [MustImplementAbility] constraints of the signature.
pub fn type_implementing_specialization(