Obligation checking of floating point for Eq succeeds only with Dec

This commit is contained in:
Ayaz Hafiz 2022-12-01 11:31:41 -06:00
parent e8492f279e
commit ed7d4f8f63
No known key found for this signature in database
GPG Key ID: 0E2A37416A25EF58

View File

@ -488,19 +488,10 @@ struct NotDerivable {
struct Descend(bool); struct Descend(bool);
enum FPDerivable {
/// Whether the floating point type is derivable is based on its ground or unbound type.
Descend,
/// The FP type is never derivable.
No(NotDerivableContext),
}
trait DerivableVisitor { trait DerivableVisitor {
const ABILITY: Symbol; const ABILITY: Symbol;
const ABILITY_SLICE: SubsSlice<Symbol>; const ABILITY_SLICE: SubsSlice<Symbol>;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable;
#[inline(always)] #[inline(always)]
fn is_derivable_builtin_opaque(_symbol: Symbol) -> bool { fn is_derivable_builtin_opaque(_symbol: Symbol) -> bool {
false false
@ -602,6 +593,18 @@ trait DerivableVisitor {
}) })
} }
#[inline(always)]
fn visit_floating_point_content(
var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Err(NotDerivable {
var,
context: NotDerivableContext::NoContext,
})
}
#[inline(always)] #[inline(always)]
fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { fn visit_ranged_number(var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Err(NotDerivable { Err(NotDerivable {
@ -736,12 +739,10 @@ trait DerivableVisitor {
stack.push(real_var); stack.push(real_var);
} }
Alias(Symbol::NUM_FLOATINGPOINT, _alias_variables, real_var, AliasKind::Opaque) => { Alias(Symbol::NUM_FLOATINGPOINT, _alias_variables, real_var, AliasKind::Opaque) => {
match Self::IS_FLOATING_POINT_DERIVABLE { let descend = Self::visit_floating_point_content(var, subs, real_var)?;
FPDerivable::Descend => { if descend.0 {
// Decay to a ground // Decay to a ground
stack.push(real_var) stack.push(real_var)
}
FPDerivable::No(context) => return Err(NotDerivable { var, context }),
} }
} }
Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => { Alias(opaque, _alias_variables, _real_var, AliasKind::Opaque) => {
@ -788,8 +789,6 @@ impl DerivableVisitor for DeriveEncoding {
const ABILITY: Symbol = Symbol::ENCODE_ENCODING; const ABILITY: Symbol = Symbol::ENCODE_ENCODING;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_ENCODING; const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_ENCODING;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)] #[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool { fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol) is_builtin_number_alias(symbol)
@ -862,6 +861,15 @@ impl DerivableVisitor for DeriveEncoding {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(()) Ok(())
} }
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
} }
struct DeriveDecoding; struct DeriveDecoding;
@ -869,8 +877,6 @@ impl DerivableVisitor for DeriveDecoding {
const ABILITY: Symbol = Symbol::DECODE_DECODING; const ABILITY: Symbol = Symbol::DECODE_DECODING;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_DECODING; const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_DECODING;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)] #[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool { fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol) is_builtin_number_alias(symbol)
@ -954,6 +960,15 @@ impl DerivableVisitor for DeriveDecoding {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(()) Ok(())
} }
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
} }
struct DeriveHash; struct DeriveHash;
@ -961,8 +976,6 @@ impl DerivableVisitor for DeriveHash {
const ABILITY: Symbol = Symbol::HASH_HASH_ABILITY; const ABILITY: Symbol = Symbol::HASH_HASH_ABILITY;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_HASH; const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_HASH;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable = FPDerivable::Descend;
#[inline(always)] #[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool { fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_number_alias(symbol) is_builtin_number_alias(symbol)
@ -1046,6 +1059,15 @@ impl DerivableVisitor for DeriveHash {
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
Ok(()) Ok(())
} }
#[inline(always)]
fn visit_floating_point_content(
_var: Variable,
_subs: &mut Subs,
_content_var: Variable,
) -> Result<Descend, NotDerivable> {
Ok(Descend(false))
}
} }
struct DeriveEq; struct DeriveEq;
@ -1053,9 +1075,6 @@ impl DerivableVisitor for DeriveEq {
const ABILITY: Symbol = Symbol::BOOL_EQ; const ABILITY: Symbol = Symbol::BOOL_EQ;
const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_EQ; const ABILITY_SLICE: SubsSlice<Symbol> = Subs::AB_EQ;
const IS_FLOATING_POINT_DERIVABLE: FPDerivable =
FPDerivable::No(NotDerivableContext::Eq(NotDerivableEq::FloatingPoint));
#[inline(always)] #[inline(always)]
fn is_derivable_builtin_opaque(symbol: Symbol) -> bool { fn is_derivable_builtin_opaque(symbol: Symbol) -> bool {
is_builtin_int_alias(symbol) || is_builtin_dec_alias(symbol) is_builtin_int_alias(symbol) || is_builtin_dec_alias(symbol)
@ -1144,6 +1163,32 @@ impl DerivableVisitor for DeriveEq {
} }
} }
fn visit_floating_point_content(
var: Variable,
subs: &mut Subs,
content_var: Variable,
) -> Result<Descend, NotDerivable> {
use roc_unify::unify::{unify, Mode};
// Of the floating-point types,
// only Dec implements Eq.
let mut env = Env::new(subs);
let unified = unify(
&mut env,
content_var,
Variable::DECIMAL,
Mode::EQ,
Polarity::Pos,
);
match unified {
roc_unify::unify::Unified::Success { .. } => Ok(Descend(false)),
roc_unify::unify::Unified::Failure(..) => Err(NotDerivable {
var,
context: NotDerivableContext::Eq(NotDerivableEq::FloatingPoint),
}),
}
}
#[inline(always)] #[inline(always)]
fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> { fn visit_ranged_number(_var: Variable, _range: NumericRange) -> Result<(), NotDerivable> {
// Ranged numbers are allowed, because they are always possibly ints - floats can not have // Ranged numbers are allowed, because they are always possibly ints - floats can not have