Obligation checking of floating point for Eq succeeds only with Dec

This commit is contained in:
Ayaz Hafiz 2022-12-01 11:31:41 -06:00
parent e8492f279e
commit ed7d4f8f63
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -488,19 +488,10 @@ struct NotDerivable {
struct Descend(bool);
enum FPDerivable {
/// Whether the floating point type is derivable is based on its ground or unbound type.
Descend,
/// The FP type is never derivable.
No(NotDerivableContext),
}
trait DerivableVisitor {
const ABILITY: Symbol;
const ABILITY_SLICE: SubsSlice<Symbol>;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable;
#[inline(always)]
fn is_derivable_builtin_opaque(_symbol: Symbol) -> bool {
false
@ -602,6 +593,18 @@ trait DerivableVisitor {
})
}
#[inline(always)]
fn visit_floating_point_content(
var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)]
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Err(NotDerivable {
@ -736,12 +739,10 @@ trait DerivableVisitor {
stack.push(real_var);
}
Alias(Symbol::NUM_FLOATINGPOINT, _alias_variables, real_var, AliasKind::Opaque) => {
match Self::IS_FLOATING_POINT_DERIVABLE {
FPDerivable::Descend => {
// Decay to a ground
stack.push(real_var)
}
FPDerivable::No(context) => return Err(NotDerivable { var, context }),
let descend = Self::visit_floating_point_content(var, subs, real_var)?;
if descend.0 {
// Decay to a ground
stack.push(real_var)
}
}
Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
@ -788,8 +789,6 @@ impl DerivableVisitor for DeriveEncoding {
const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_ENCODING;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol)
@ -862,6 +861,15 @@ impl DerivableVisitor for DeriveEncoding {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
}
struct DeriveDecoding;
@ -869,8 +877,6 @@ impl DerivableVisitor for DeriveDecoding {
const ABILITY: Symbol = Symbol::DECODE_DECODING;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_DECODING;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol)
@ -954,6 +960,15 @@ impl DerivableVisitor for DeriveDecoding {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
}
struct DeriveHash;
@ -961,8 +976,6 @@ impl DerivableVisitor for DeriveHash {
const ABILITY: Symbol = Symbol::HASH_HASH_ABILITY;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_HASH;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol)
@ -1046,6 +1059,15 @@ impl DerivableVisitor for DeriveHash {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(())
}
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
}
struct DeriveEq;
@ -1053,9 +1075,6 @@ impl DerivableVisitor for DeriveEq {
const ABILITY: Symbol = Symbol::BOOL_EQ;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_EQ;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable =
FPDerivable::No(NotDerivableContext::Eq(NotDerivableEq::FloatingPoint));
#[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_int_alias(symbol) || is_builtin_dec_alias(symbol)
@ -1144,6 +1163,32 @@ impl DerivableVisitor for DeriveEq {
}
}
fn visit_floating_point_content(
var: Variable,
subs: &mut Subs,
content_var: Variable,
) -> Result<Descend, NotDerivable> {
use roc_unify::unify::{unify, Mode};
// Of the floating-point types,
// only Dec implements Eq.
let mut env = Env::new(subs);
let unified = unify(
&mut env,
content_var,
Variable::DECIMAL,
Mode::EQ,
Polarity::Pos,
);
match unified {
roc_unify::unify::Unified::Success { .. } => Ok(Descend(false)),
roc_unify::unify::Unified::Failure(..) => Err(NotDerivable {
var,
context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint),
}),
}
}
#[inline(always)]
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
// Ranged numbers are allowed, because they are always possibly ints - floats can not have