Implement saturated add/subtract

This commit is contained in:
ayazhafiz 2022-01-10 22:37:08 -05:00
parent 4ea91b54eb
commit 2c41c43aea
10 changed files with 177 additions and 66 deletions

View File

@ -501,7 +501,16 @@ add : Num a, Num a -> Num a
##
## This is the same as [Num.add] except if the operation overflows, instead of
## panicking or returning ∞ or -∞, it will return `Err Overflow`.
addCheckOverflow : Num a, Num a -> Result (Num a) [ Overflow ]*
addChecked : Num a, Num a -> Result (Num a) [ Overflow ]*
## Add two numbers, clamping on the maximum representable number rather than
## overflowing.
##
## This is the same as [Num.add] except for the saturating behavior if the
## addition is to overflow.
## For example, if `x : U8` is 200 and `y : U8` is 100, `addSaturated x y` will
## yield 255, the maximum value of a `U8`.
addSaturated : Num a, Num a -> Num a
## Subtract two numbers of the same type.
##
@ -528,7 +537,16 @@ sub : Num a, Num a -> Num a
##
## This is the same as [Num.sub] except if the operation overflows, instead of
## panicking or returning ∞ or -∞, it will return `Err Overflow`.
subCheckOverflow : Num a, Num a -> Result (Num a) [ Overflow ]*
subChecked : Num a, Num a -> Result (Num a) [ Overflow ]*
## Subtract two numbers, clamping on the minimum representable number rather
## than overflowing.
##
## This is the same as [Num.sub] except for the saturating behavior if the
## subtraction is to overflow.
## For example, if `x : U8` is 10 and `y : U8` is 20, `subSaturated x y` will
## yield 0, the minimum value of a `U8`.
subSaturated : Num a, Num a -> Num a
## Multiply two numbers of the same type.
##

View File

@ -141,6 +141,13 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
Box::new(int_type(flex(TVAR1))),
);
// addSaturated : Num a, Num a -> Num a
add_top_level_function_type!(
Symbol::NUM_ADD_SATURATED,
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
);
// sub or (-) : Num a, Num a -> Num a
add_top_level_function_type!(
Symbol::NUM_SUB,
@ -162,6 +169,13 @@ pub fn types() -> MutMap<Symbol, (SolvedType, Region)> {
Box::new(result_type(num_type(flex(TVAR1)), overflow())),
);
// subSaturated : Num a, Num a -> Num a
add_top_level_function_type!(
Symbol::NUM_SUB_SATURATED,
vec![int_type(flex(TVAR1)), int_type(flex(TVAR1))],
Box::new(int_type(flex(TVAR1))),
);
// mul or (*) : Num a, Num a -> Num a
add_top_level_function_type!(
Symbol::NUM_MUL,

View File

@ -156,9 +156,11 @@ pub fn builtin_defs_map(symbol: Symbol, var_store: &mut VarStore) -> Option<Def>
NUM_ADD => num_add,
NUM_ADD_CHECKED => num_add_checked,
NUM_ADD_WRAP => num_add_wrap,
NUM_ADD_SATURATED => num_add_saturated,
NUM_SUB => num_sub,
NUM_SUB_WRAP => num_sub_wrap,
NUM_SUB_CHECKED => num_sub_checked,
NUM_SUB_SATURATED => num_sub_saturated,
NUM_MUL => num_mul,
NUM_MUL_WRAP => num_mul_wrap,
NUM_MUL_CHECKED => num_mul_checked,
@ -641,6 +643,11 @@ fn num_add_checked(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_overflow_checked(symbol, var_store, LowLevel::NumAddChecked)
}
/// Num.addSaturated : Int a, Int a -> Int a
fn num_add_saturated(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_binop(symbol, var_store, LowLevel::NumAddSaturated)
}
/// Num.sub : Num a, Num a -> Num a
fn num_sub(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_binop(symbol, var_store, LowLevel::NumSub)
@ -656,6 +663,11 @@ fn num_sub_checked(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_overflow_checked(symbol, var_store, LowLevel::NumSubChecked)
}
/// Num.subSaturated : Int a, Int a -> Int a
fn num_sub_saturated(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_binop(symbol, var_store, LowLevel::NumSubSaturated)
}
/// Num.mul : Num a, Num a -> Num a
fn num_mul(symbol: Symbol, var_store: &mut VarStore) -> Def {
num_binop(symbol, var_store, LowLevel::NumMul)

View File

@ -493,6 +493,12 @@ fn add_int_intrinsic<'ctx, F>(
};
}
check!(IntWidth::U8, ctx.i8_type());
check!(IntWidth::U16, ctx.i16_type());
check!(IntWidth::U32, ctx.i32_type());
check!(IntWidth::U64, ctx.i64_type());
check!(IntWidth::U128, ctx.i128_type());
check!(IntWidth::I8, ctx.i8_type());
check!(IntWidth::I16, ctx.i16_type());
check!(IntWidth::I32, ctx.i32_type());
@ -579,6 +585,14 @@ fn add_intrinsics<'ctx>(ctx: &'ctx Context, module: &Module<'ctx>) {
ctx.struct_type(&fields, false)
.fn_type(&[t.into(), t.into()], false)
});
add_int_intrinsic(ctx, module, &LLVM_ADD_SATURATED, |t| {
t.fn_type(&[t.into(), t.into()], false)
});
add_int_intrinsic(ctx, module, &LLVM_SUB_SATURATED, |t| {
t.fn_type(&[t.into(), t.into()], false)
});
}
const LLVM_POW: IntrinsicName = float_intrinsic!("llvm.pow");
@ -609,6 +623,9 @@ const LLVM_SUB_WITH_OVERFLOW: IntrinsicName =
const LLVM_MUL_WITH_OVERFLOW: IntrinsicName =
int_intrinsic!("llvm.smul.with.overflow", "llvm.umul.with.overflow");
const LLVM_ADD_SATURATED: IntrinsicName = int_intrinsic!("llvm.sadd.sat", "llvm.uadd.sat");
const LLVM_SUB_SATURATED: IntrinsicName = int_intrinsic!("llvm.ssub.sat", "llvm.usub.sat");
fn add_intrinsic<'ctx>(
module: &Module<'ctx>,
intrinsic_name: &str,
@ -5809,8 +5826,9 @@ fn run_low_level<'a, 'ctx, 'env>(
}
NumAdd | NumSub | NumMul | NumLt | NumLte | NumGt | NumGte | NumRemUnchecked
| NumIsMultipleOf | NumAddWrap | NumAddChecked | NumDivUnchecked | NumDivCeilUnchecked
| NumPow | NumPowInt | NumSubWrap | NumSubChecked | NumMulWrap | NumMulChecked => {
| NumIsMultipleOf | NumAddWrap | NumAddChecked | NumAddSaturated | NumDivUnchecked
| NumDivCeilUnchecked | NumPow | NumPowInt | NumSubWrap | NumSubChecked
| NumSubSaturated | NumMulWrap | NumMulChecked => {
debug_assert_eq!(args.len(), 2);
let (lhs_arg, lhs_layout) = load_symbol_and_layout(scope, &args[0]);
@ -6381,6 +6399,9 @@ fn build_int_binop<'a, 'ctx, 'env>(
&LLVM_ADD_WITH_OVERFLOW[int_width],
&[lhs.into(), rhs.into()],
),
NumAddSaturated => {
env.call_intrinsic(&LLVM_ADD_SATURATED[int_width], &[lhs.into(), rhs.into()])
}
NumSub => {
let result = env
.call_intrinsic(
@ -6396,6 +6417,9 @@ fn build_int_binop<'a, 'ctx, 'env>(
&LLVM_SUB_WITH_OVERFLOW[int_width],
&[lhs.into(), rhs.into()],
),
NumSubSaturated => {
env.call_intrinsic(&LLVM_SUB_SATURATED[int_width], &[lhs.into(), rhs.into()])
}
NumMul => {
let result = env
.call_intrinsic(

View File

@ -136,6 +136,7 @@ pub fn dispatch_low_level<'a>(
},
NumToStr => return NotImplemented,
NumAddChecked => return NotImplemented,
NumAddSaturated => return NotImplemented,
NumSub => match ret_layout {
WasmLayout::Primitive(value_type, _) => match value_type {
I32 => code_builder.i32_sub(),
@ -168,6 +169,7 @@ pub fn dispatch_low_level<'a>(
},
},
NumSubChecked => return NotImplemented,
NumSubSaturated => return NotImplemented,
NumMul => match ret_layout {
WasmLayout::Primitive(value_type, _) => match value_type {
I32 => code_builder.i32_mul(),

View File

@ -69,9 +69,11 @@ pub enum LowLevel {
NumAdd,
NumAddWrap,
NumAddChecked,
NumAddSaturated,
NumSub,
NumSubWrap,
NumSubChecked,
NumSubSaturated,
NumMul,
NumMulWrap,
NumMulChecked,
@ -269,9 +271,11 @@ impl LowLevelWrapperType {
Symbol::NUM_ADD => CanBeReplacedBy(NumAdd),
Symbol::NUM_ADD_WRAP => CanBeReplacedBy(NumAddWrap),
Symbol::NUM_ADD_CHECKED => WrapperIsRequired,
Symbol::NUM_ADD_SATURATED => CanBeReplacedBy(NumAddSaturated),
Symbol::NUM_SUB => CanBeReplacedBy(NumSub),
Symbol::NUM_SUB_WRAP => CanBeReplacedBy(NumSubWrap),
Symbol::NUM_SUB_CHECKED => WrapperIsRequired,
Symbol::NUM_SUB_SATURATED => CanBeReplacedBy(NumSubSaturated),
Symbol::NUM_MUL => CanBeReplacedBy(NumMul),
Symbol::NUM_MUL_WRAP => CanBeReplacedBy(NumMulWrap),
Symbol::NUM_MUL_CHECKED => WrapperIsRequired,

View File

@ -940,59 +940,61 @@ define_builtins! {
52 NUM_FLOOR: "floor"
53 NUM_ADD_WRAP: "addWrap"
54 NUM_ADD_CHECKED: "addChecked"
55 NUM_ATAN: "atan"
56 NUM_ACOS: "acos"
57 NUM_ASIN: "asin"
58 NUM_AT_SIGNED128: "@Signed128"
59 NUM_SIGNED128: "Signed128" imported
60 NUM_AT_SIGNED64: "@Signed64"
61 NUM_SIGNED64: "Signed64" imported
62 NUM_AT_SIGNED32: "@Signed32"
63 NUM_SIGNED32: "Signed32" imported
64 NUM_AT_SIGNED16: "@Signed16"
65 NUM_SIGNED16: "Signed16" imported
66 NUM_AT_SIGNED8: "@Signed8"
67 NUM_SIGNED8: "Signed8" imported
68 NUM_AT_UNSIGNED128: "@Unsigned128"
69 NUM_UNSIGNED128: "Unsigned128" imported
70 NUM_AT_UNSIGNED64: "@Unsigned64"
71 NUM_UNSIGNED64: "Unsigned64" imported
72 NUM_AT_UNSIGNED32: "@Unsigned32"
73 NUM_UNSIGNED32: "Unsigned32" imported
74 NUM_AT_UNSIGNED16: "@Unsigned16"
75 NUM_UNSIGNED16: "Unsigned16" imported
76 NUM_AT_UNSIGNED8: "@Unsigned8"
77 NUM_UNSIGNED8: "Unsigned8" imported
78 NUM_AT_BINARY64: "@Binary64"
79 NUM_BINARY64: "Binary64" imported
80 NUM_AT_BINARY32: "@Binary32"
81 NUM_BINARY32: "Binary32" imported
82 NUM_BITWISE_AND: "bitwiseAnd"
83 NUM_BITWISE_XOR: "bitwiseXor"
84 NUM_BITWISE_OR: "bitwiseOr"
85 NUM_SHIFT_LEFT: "shiftLeftBy"
86 NUM_SHIFT_RIGHT: "shiftRightBy"
87 NUM_SHIFT_RIGHT_ZERO_FILL: "shiftRightZfBy"
88 NUM_SUB_WRAP: "subWrap"
89 NUM_SUB_CHECKED: "subChecked"
90 NUM_MUL_WRAP: "mulWrap"
91 NUM_MUL_CHECKED: "mulChecked"
92 NUM_INT: "Int" imported
93 NUM_FLOAT: "Float" imported
94 NUM_AT_NATURAL: "@Natural"
95 NUM_NATURAL: "Natural" imported
96 NUM_NAT: "Nat" imported
97 NUM_INT_CAST: "intCast"
98 NUM_MAX_I128: "maxI128"
99 NUM_IS_MULTIPLE_OF: "isMultipleOf"
100 NUM_AT_DECIMAL: "@Decimal"
101 NUM_DECIMAL: "Decimal" imported
102 NUM_DEC: "Dec" imported // the Num.Dectype alias
103 NUM_BYTES_TO_U16: "bytesToU16"
104 NUM_BYTES_TO_U32: "bytesToU32"
105 NUM_CAST_TO_NAT: "#castToNat"
106 NUM_DIV_CEIL: "divCeil"
107 NUM_TO_STR: "toStr"
55 NUM_ADD_SATURATED: "addSaturated"
56 NUM_ATAN: "atan"
57 NUM_ACOS: "acos"
58 NUM_ASIN: "asin"
59 NUM_AT_SIGNED128: "@Signed128"
60 NUM_SIGNED128: "Signed128" imported
61 NUM_AT_SIGNED64: "@Signed64"
62 NUM_SIGNED64: "Signed64" imported
63 NUM_AT_SIGNED32: "@Signed32"
64 NUM_SIGNED32: "Signed32" imported
65 NUM_AT_SIGNED16: "@Signed16"
66 NUM_SIGNED16: "Signed16" imported
67 NUM_AT_SIGNED8: "@Signed8"
68 NUM_SIGNED8: "Signed8" imported
69 NUM_AT_UNSIGNED128: "@Unsigned128"
70 NUM_UNSIGNED128: "Unsigned128" imported
71 NUM_AT_UNSIGNED64: "@Unsigned64"
72 NUM_UNSIGNED64: "Unsigned64" imported
73 NUM_AT_UNSIGNED32: "@Unsigned32"
74 NUM_UNSIGNED32: "Unsigned32" imported
75 NUM_AT_UNSIGNED16: "@Unsigned16"
76 NUM_UNSIGNED16: "Unsigned16" imported
77 NUM_AT_UNSIGNED8: "@Unsigned8"
78 NUM_UNSIGNED8: "Unsigned8" imported
79 NUM_AT_BINARY64: "@Binary64"
80 NUM_BINARY64: "Binary64" imported
81 NUM_AT_BINARY32: "@Binary32"
82 NUM_BINARY32: "Binary32" imported
83 NUM_BITWISE_AND: "bitwiseAnd"
84 NUM_BITWISE_XOR: "bitwiseXor"
85 NUM_BITWISE_OR: "bitwiseOr"
86 NUM_SHIFT_LEFT: "shiftLeftBy"
87 NUM_SHIFT_RIGHT: "shiftRightBy"
88 NUM_SHIFT_RIGHT_ZERO_FILL: "shiftRightZfBy"
89 NUM_SUB_WRAP: "subWrap"
90 NUM_SUB_CHECKED: "subChecked"
91 NUM_SUB_SATURATED: "subSaturated"
92 NUM_MUL_WRAP: "mulWrap"
93 NUM_MUL_CHECKED: "mulChecked"
94 NUM_INT: "Int" imported
95 NUM_FLOAT: "Float" imported
96 NUM_AT_NATURAL: "@Natural"
97 NUM_NATURAL: "Natural" imported
98 NUM_NAT: "Nat" imported
99 NUM_INT_CAST: "intCast"
100 NUM_MAX_I128: "maxI128"
101 NUM_IS_MULTIPLE_OF: "isMultipleOf"
102 NUM_AT_DECIMAL: "@Decimal"
103 NUM_DECIMAL: "Decimal" imported
104 NUM_DEC: "Dec" imported // the Num.Dectype alias
105 NUM_BYTES_TO_U16: "bytesToU16"
106 NUM_BYTES_TO_U32: "bytesToU32"
107 NUM_CAST_TO_NAT: "#castToNat"
108 NUM_DIV_CEIL: "divCeil"
109 NUM_TO_STR: "toStr"
}
2 BOOL: "Bool" => {
0 BOOL_BOOL: "Bool" imported // the Bool.Bool type alias

View File

@ -972,11 +972,13 @@ pub fn lowlevel_borrow_signature(arena: &Bump, op: LowLevel) -> &[bool] {
Eq | NotEq => arena.alloc_slice_copy(&[borrowed, borrowed]),
And | Or | NumAdd | NumAddWrap | NumAddChecked | NumSub | NumSubWrap | NumSubChecked
| NumMul | NumMulWrap | NumMulChecked | NumGt | NumGte | NumLt | NumLte | NumCompare
| NumDivUnchecked | NumDivCeilUnchecked | NumRemUnchecked | NumIsMultipleOf | NumPow
| NumPowInt | NumBitwiseAnd | NumBitwiseXor | NumBitwiseOr | NumShiftLeftBy
| NumShiftRightBy | NumShiftRightZfBy => arena.alloc_slice_copy(&[irrelevant, irrelevant]),
And | Or | NumAdd | NumAddWrap | NumAddChecked | NumAddSaturated | NumSub | NumSubWrap
| NumSubChecked | NumSubSaturated | NumMul | NumMulWrap | NumMulChecked | NumGt
| NumGte | NumLt | NumLte | NumCompare | NumDivUnchecked | NumDivCeilUnchecked
| NumRemUnchecked | NumIsMultipleOf | NumPow | NumPowInt | NumBitwiseAnd
| NumBitwiseXor | NumBitwiseOr | NumShiftLeftBy | NumShiftRightBy | NumShiftRightZfBy => {
arena.alloc_slice_copy(&[irrelevant, irrelevant])
}
NumToStr | NumAbs | NumNeg | NumSin | NumCos | NumSqrtUnchecked | NumLogUnchecked
| NumRound | NumCeiling | NumFloor | NumToFloat | Not | NumIsFinite | NumAtan | NumAcos

View File

@ -2155,3 +2155,39 @@ fn u8_mul_greater_than_i8() {
u8
)
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn add_saturated() {
assert_evals_to!(
indoc!(
r#"
x : U8
x = 200
y : U8
y = 200
Num.addSaturated x y
"#
),
255,
u8
)
}
#[test]
#[cfg(any(feature = "gen-llvm"))]
fn sub_saturated() {
assert_evals_to!(
indoc!(
r#"
x : U8
x = 10
y : U8
y = 20
Num.subSaturated x y
"#
),
0,
u8
)
}

View File

@ -30,10 +30,7 @@ quicksortHelp = \list, order, low, high ->
when partition low high list order is
Pair partitionIndex partitioned ->
partitioned
|> \lst ->
# TODO: this will be nicer if we have Num.subSaturated
high1 = if partitionIndex == 0 then 0 else partitionIndex - 1
quicksortHelp lst order low high1
|> quicksortHelp order low (Num.subSaturated partitionIndex 1)
|> quicksortHelp order (partitionIndex + 1) high
else
list