implement Num.compare

This commit is contained in:
Folkert 2020-09-08 19:40:18 +02:00
parent 4c995b12a6
commit 1b42831973
8 changed files with 225 additions and 2 deletions

View File

@ -265,6 +265,15 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
),
);
// compare : Num a, Num a -> [ LT, EQ, GT ]
add_type(
Symbol::NUM_COMPARE,
SolvedType::Func(
vec![num_type(flex(TVAR1)), num_type(flex(TVAR1))],
Box::new(ordering_type()),
),
);
// toFloat : Num a -> Float
add_type(
Symbol::NUM_TO_FLOAT,
@ -722,6 +731,19 @@ fn bool_type() -> SolvedType {
SolvedType::Apply(Symbol::BOOL_BOOL, Vec::new())
}
#[inline(always)]
fn ordering_type() -> SolvedType {
// [ LT, EQ, GT ]
SolvedType::TagUnion(
vec![
(TagName::Global("GT".into()), vec![]),
(TagName::Global("EQ".into()), vec![]),
(TagName::Global("LT".into()), vec![]),
],
Box::new(SolvedType::EmptyTagUnion),
)
}
#[inline(always)]
fn str_type() -> SolvedType {
SolvedType::Apply(Symbol::STR_STR, Vec::new())

View File

@ -317,6 +317,12 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
// isGte or (>=) : Num a, Num a -> Bool
add_num_comparison(Symbol::NUM_GTE);
// compare : Num a, Num a -> [ LT, EQ, GT ]
add_type(Symbol::NUM_COMPARE, {
let_tvars! { u, v, w, num };
unique_function(vec![num_type(u, num), num_type(v, num)], ordering_type(w))
});
// toFloat : Num a -> Float
add_type(Symbol::NUM_TO_FLOAT, {
let_tvars! { star1, star2, a };
@ -1205,3 +1211,22 @@ fn map_type(u: VarId, key: VarId, value: VarId) -> SolvedType {
],
)
}
#[inline(always)]
fn ordering_type(u: VarId) -> SolvedType {
// [ LT, EQ, GT ]
SolvedType::Apply(
Symbol::ATTR_ATTR,
vec![
flex(u),
SolvedType::TagUnion(
vec![
(TagName::Global("GT".into()), vec![]),
(TagName::Global("EQ".into()), vec![]),
(TagName::Global("LT".into()), vec![]),
],
Box::new(SolvedType::EmptyTagUnion),
),
],
)
}

View File

@ -73,6 +73,7 @@ pub fn builtin_defs(var_store: &mut VarStore) -> MutMap<Symbol, Def> {
Symbol::NUM_GTE => num_gte,
Symbol::NUM_LT => num_lt,
Symbol::NUM_LTE => num_lte,
Symbol::NUM_COMPARE => num_compare,
Symbol::NUM_SIN => num_sin,
Symbol::NUM_COS => num_cos,
Symbol::NUM_TAN => num_tan,
@ -262,6 +263,11 @@ fn num_lte(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_bool_binop(symbol, var_store, LowLevel::NumLte)
}
/// Num.compare : Num a, Num a -> [ LT, EQ, GT ]
fn num_compare(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_bool_binop(symbol, var_store, LowLevel::NumCompare)
}
/// Num.sin : Float -> Float
fn num_sin(symbol: Symbol, var_store: &mut VarStore) -> Def {
let float_var = var_store.fresh();

View File

@ -185,7 +185,7 @@ pub fn construct_optimization_passes<'a>(
}
OptLevel::Optimize => {
// this threshold seems to do what we want
pmb.set_inliner_with_threshold(2);
pmb.set_inliner_with_threshold(275);
// TODO figure out which of these actually help
@ -1650,6 +1650,88 @@ fn run_low_level<'a, 'ctx, 'env>(
}
}
}
NumCompare => {
use inkwell::FloatPredicate;
use inkwell::IntPredicate;
debug_assert_eq!(args.len(), 2);
let (lhs_arg, lhs_layout) = load_symbol_and_layout(env, scope, &args[0]);
let (rhs_arg, rhs_layout) = load_symbol_and_layout(env, scope, &args[1]);
match (lhs_layout, rhs_layout) {
(Layout::Builtin(lhs_builtin), Layout::Builtin(rhs_builtin))
if lhs_builtin == rhs_builtin =>
{
use roc_mono::layout::Builtin::*;
let tag_eq = env.context.i8_type().const_int(0 as u64, false);
let tag_gt = env.context.i8_type().const_int(1 as u64, false);
let tag_lt = env.context.i8_type().const_int(2 as u64, false);
match lhs_builtin {
Int128 | Int64 | Int32 | Int16 | Int8 => {
let are_equal = env.builder.build_int_compare(
IntPredicate::EQ,
lhs_arg.into_int_value(),
rhs_arg.into_int_value(),
"int_eq",
);
let is_less_than = env.builder.build_int_compare(
IntPredicate::SLT,
lhs_arg.into_int_value(),
rhs_arg.into_int_value(),
"int_compare",
);
let step1 =
env.builder
.build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt");
env.builder.build_select(
are_equal,
tag_eq,
step1.into_int_value(),
"lt_or_gt",
)
}
Float128 | Float64 | Float32 | Float16 => {
let are_equal = env.builder.build_float_compare(
FloatPredicate::OEQ,
lhs_arg.into_float_value(),
rhs_arg.into_float_value(),
"float_eq",
);
let is_less_than = env.builder.build_float_compare(
FloatPredicate::OLT,
lhs_arg.into_float_value(),
rhs_arg.into_float_value(),
"float_compare",
);
let step1 =
env.builder
.build_select(is_less_than, tag_lt, tag_gt, "lt_or_gt");
env.builder.build_select(
are_equal,
tag_eq,
step1.into_int_value(),
"lt_or_gt",
)
}
_ => {
unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid builtin layout: ({:?})", op, lhs_layout);
}
}
}
_ => {
unreachable!("Compiler bug: tried to run numeric operation {:?} on invalid layouts. The 2 layouts were: ({:?}) and ({:?})", op, lhs_layout, rhs_layout);
}
}
}
NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte | NumRemUnchecked
| NumDivUnchecked => {
debug_assert_eq!(args.len(), 2);

View File

@ -581,4 +581,88 @@ mod gen_num {
fn float_to_float() {
assert_evals_to!("Num.toFloat 0.5", 0.5, f64);
}
#[test]
fn int_compare() {
assert_evals_to!(
indoc!(
r#"
when Num.compare 0 1 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
0,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 1 1 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
1,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 1 0 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
2,
i64
);
}
#[test]
fn float_compare() {
assert_evals_to!(
indoc!(
r#"
when Num.compare 0 3.14 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
0,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 3.14 3.14 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
1,
i64
);
assert_evals_to!(
indoc!(
r#"
when Num.compare 3.14 0 is
LT -> 0
EQ -> 1
GT -> 2
"#
),
2,
i64
);
}
}

View File

@ -25,6 +25,7 @@ pub enum LowLevel {
NumGte,
NumLt,
NumLte,
NumCompare,
NumDivUnchecked,
NumRemUnchecked,
NumAbs,

View File

@ -639,6 +639,7 @@ define_builtins! {
34 NUM_MOD_FLOAT: "modFloat"
35 NUM_SQRT: "sqrt"
36 NUM_ROUND: "round"
37 NUM_COMPARE: "compare"
}
2 BOOL: "Bool" => {
0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias

View File

@ -522,7 +522,9 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
ListWalkRight => arena.alloc_slice_copy(&[borrowed, irrelevant, owned]),
Eq | NotEq | And | Or | NumAdd | NumSub | NumMul | NumGt | NumGte | NumLt | NumLte
| NumDivUnchecked | NumRemUnchecked => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
| NumCompare | NumDivUnchecked | NumRemUnchecked => {
arena.alloc_slice_copy(&[irrelevant, irrelevant])
}
NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumRound | NumToFloat | Not => {
arena.alloc_slice_copy(&[irrelevant])