LF: Engine support for BigNumeric (#9220)

This is part of #8719

CHANGELOG_BEGIN
CHANGELOG_END
This commit is contained in:
Remy 2021-03-25 11:52:30 +01:00 committed by GitHub
parent b6f7b78990
commit 6f84b35fc0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 238 additions and 15 deletions

View File

@ -192,7 +192,7 @@ abstract class NumericModule {
*/
final def toString(x: BigDecimal): String = {
val s = x.toPlainString
if (x.scale == 0) s + "." else s
if (x.scale <= 0) s + "." else s
}
private val validScaledFormat =

View File

@ -14,4 +14,26 @@ module BuiltinMod {
val greater: forall (a: *). a -> a -> Bool =
GREATER;
val scaleBigNumeric: BigNumeric -> Int64 =
SCALE_BIGNUMERIC;
val precisionBigNumeric: BigNumeric -> Int64 =
PRECISION_BIGNUMERIC;
val addBigNumeric: BigNumeric -> BigNumeric -> BigNumeric =
ADD_BIGNUMERIC;
val subBigNumeric: BigNumeric -> BigNumeric -> BigNumeric =
SUB_BIGNUMERIC;
val mulBigNumeric: BigNumeric -> BigNumeric -> BigNumeric =
MUL_BIGNUMERIC;
val divBigNumeric: Int64 -> RoundingMode -> BigNumeric -> BigNumeric -> BigNumeric =
DIV_BIGNUMERIC;
val shiftBigNumeric: Int64 -> BigNumeric -> BigNumeric =
SHIFT_BIGNUMERIC;
val toNumericBigNumeric: forall (n: nat). BigNumeric -> Numeric n =
TO_NUMERIC_BIGNUMERIC;
val toBigNumericNumeric: forall (n: nat). Numeric n -> BigNumeric =
TO_BIGNUMERIC_NUMERIC;
val roundingMode: RoundingMode =
ROUNDING_UP;
}

View File

@ -576,11 +576,15 @@ private[lf] final class Compiler(
case BGenMapValues => SBGenMapValues
case BGenMapSize => SBGenMapSize
case BAddBigNumeric | BDivBigNumeric | BMulBigNumeric | BPrecisionBigNumeric |
BScaleBigNumeric | BShiftBigNumeric | BSubBigNumeric | BToBigNumericNumeric |
BToNumericBigNumeric =>
// TODO https://github.com/digital-asset/daml/issues/8719
sys.error(s"builtin $bf not supported")
case BScaleBigNumeric => SBScaleBigNumeric
case BPrecisionBigNumeric => SBPrecisionBigNumeric
case BAddBigNumeric => SBAddBigNumeric
case BSubBigNumeric => SBSubBigNumeric
case BDivBigNumeric => SBDivBigNumeric
case BMulBigNumeric => SBMulBigNumeric
case BShiftBigNumeric => SBShiftBigNumeric
case BToBigNumericNumeric => SBToBigNumericNumeric
case BToNumericBigNumeric => SBToNumericBigNumeric
// Unstable Text Primitives
case BTextToUpper => SBTextToUpper
@ -630,10 +634,7 @@ private[lf] final class Compiler(
case PLTimestamp(ts) => STimestamp(ts)
case PLParty(p) => SParty(p)
case PLDate(d) => SDate(d)
case PLRoundingMode(_) =>
// TODO https://github.com/digital-asset/daml/issues/8719
sys.error("RoundingMode not supported")
case PLRoundingMode(roundingMode) => SInt64(roundingMode.ordinal.toLong)
})
// ERecUpd(_, f2, ERecUpd(_, f1, e0, e1), e2) => (e0, [f1, f2], [e1, e2])

View File

@ -315,6 +315,7 @@ private[lf] object SBuiltin {
case SParty(p) => p
case SUnit => s"<unit>"
case SDate(date) => date.toString
case SBigNumeric(x) => Numeric.toString(x)
case SContractId(_) | SNumeric(_) => crash("litToText: literal not supported")
})
}
@ -852,6 +853,90 @@ private[lf] object SBuiltin {
}
}
final object SBScaleBigNumeric extends SBuiltinPure(1) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SInt64 = {
val x = args.get(0).asInstanceOf[SBigNumeric].value
SInt64(x.scale().toLong)
}
}
final object SBPrecisionBigNumeric extends SBuiltinPure(1) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SInt64 = {
val x = args.get(0).asInstanceOf[SBigNumeric].value
SInt64(x.precision().toLong)
}
}
final object SBAddBigNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val x = args.get(0).asInstanceOf[SBigNumeric].value
val y = args.get(1).asInstanceOf[SBigNumeric].value
rightOrArithmeticError("overflow/underflow", SBigNumeric.fromBigDecimal(x add y))
}
}
final object SBSubBigNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val x = args.get(0).asInstanceOf[SBigNumeric].value
val y = args.get(1).asInstanceOf[SBigNumeric].value
rightOrArithmeticError("overflow/underflow", SBigNumeric.fromBigDecimal(x subtract y))
}
}
final object SBMulBigNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val x = args.get(0).asInstanceOf[SBigNumeric].value
val y = args.get(1).asInstanceOf[SBigNumeric].value
rightOrArithmeticError("overflow/underflow", SBigNumeric.fromBigDecimal(x multiply y))
}
}
final object SBDivBigNumeric extends SBuiltinPure(4) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val scale = rightOrCrash(SBigNumeric.checkScale(args.get(0).asInstanceOf[SInt64].value))
val roundingMode =
java.math.RoundingMode.valueOf(args.get(1).asInstanceOf[SInt64].value.toInt)
val x = args.get(2).asInstanceOf[SBigNumeric].value
val y = args.get(3).asInstanceOf[SBigNumeric].value
val result =
try {
x.divide(y, scale, roundingMode)
} catch {
case e: ArithmeticException =>
throw DamlEArithmeticError(e.getMessage)
}
rightOrArithmeticError("overflow/underflow", SBigNumeric.fromBigDecimal(result))
}
}
final object SBShiftBigNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val shifting = args.get(0).asInstanceOf[SInt64].value
if (shifting.abs > SBigNumeric.MaxPrecision) throw DamlEArithmeticError("overflow/underflow")
val x = args.get(1).asInstanceOf[SBigNumeric].value
rightOrArithmeticError(
"overflow/underflow",
SBigNumeric.fromBigDecimal(x.scaleByPowerOfTen(-shifting.toInt)),
)
}
}
final object SBToBigNumericNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SBigNumeric = {
val x = args.get(1).asInstanceOf[SNumeric].value
// should not fail
rightOrArithmeticError("overflow/underflow", SBigNumeric.fromBigDecimal(x))
}
}
final object SBToNumericBigNumeric extends SBuiltinPure(2) {
override private[speedy] def executePure(args: util.ArrayList[SValue]): SNumeric = {
val scale = args.get(0).asInstanceOf[STNat].n
val x = args.get(1).asInstanceOf[SNumeric].value
rightOrArithmeticError("overflow/underflow", Numeric.fromBigDecimal(scale, x).map(SNumeric))
}
}
/** $checkPrecondition
* :: arg (template argument)
* -> Bool (false if ensure failed)
@ -1838,7 +1923,7 @@ private[lf] object SBuiltin {
case _ => crash(s"Invalid cached contract: $v")
}
private[this] def rightOrArithmeticError[A](message: String, mb: Either[String, A]): A =
private[this] def rightOrArithmeticError[A](message: => String, mb: Either[String, A]): A =
mb.fold(_ => throw DamlEArithmeticError(s"$message"), identity)
private[this] def rightOrCrash[A](either: Either[String, A]) =

View File

@ -1,7 +1,8 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.lf.speedy
package com.daml.lf
package speedy
import java.util
@ -14,6 +15,7 @@ import com.daml.lf.value.{Value => V}
import scala.jdk.CollectionConverters._
import scala.collection.compat._
import scala.collection.immutable.TreeMap
import scala.util.hashing.MurmurHash3
/** Speedy values. These are the value types recognized by the
* machine. In addition to the usual types present in the LF value,
@ -66,6 +68,8 @@ sealed trait SValue {
throw SErrorCrash("SValue.toValue: unexpected SStruct")
case SAny(_, _) =>
throw SErrorCrash("SValue.toValue: unexpected SAny")
case SBigNumeric(_) =>
throw SErrorCrash("SValue.toValue: unexpected SBigNumeric")
case SAnyException(_, _, _) =>
throw SErrorCrash("SValue.toValue: unexpected SAnyException")
case STypeRep(_) =>
@ -210,7 +214,49 @@ object SValue {
// with SValue and we can remove one layer of indirection.
sealed trait SPrimLit extends SValue with Equals
final case class SInt64(value: Long) extends SPrimLit
// TODO https://github.com/digital-asset/daml/issues/8719
// try to factorize SNumeric and SBigNumeric
// note it seems that scale is relevant in SNumeric but lost in SBigNumeric
final case class SNumeric(value: Numeric) extends SPrimLit
final class SBigNumeric private (val value: java.math.BigDecimal) extends SPrimLit {
override def canEqual(that: Any): Boolean = that match {
case _: SBigNumeric => true
case _ => false
}
override def equals(obj: Any): Boolean = obj match {
case that: SBigNumeric => this.value == that.value
case _ => false
}
override def hashCode(): Int = MurmurHash3.mix(getClass.hashCode(), value.hashCode())
override def toString: String = s"SBigNumeric($value)"
}
object SBigNumeric {
// TODO https://github.com/digital-asset/daml/issues/8719
// Decide what are the actual bound for BigDecimal
val MaxScale = 1 << 10
val MaxPrecision = MaxScale << 2
def unapply(value: SBigNumeric): Some[java.math.BigDecimal] =
Some(value.value)
def fromBigDecimal(x: java.math.BigDecimal): Either[String, SBigNumeric] = {
val norm = x.stripTrailingZeros()
Either.cond(
test = norm.scale <= MaxScale && norm.precision + norm.scale < MaxScale,
right = new SBigNumeric(norm),
left = "non valid BigNumeric",
)
}
def assertFromBigDecimal(x: java.math.BigDecimal): SBigNumeric =
data.assertRight(fromBigDecimal(x))
def checkScale(s: Long): Either[String, Int] =
Either.cond(test = s.abs <= MaxScale, right = s.toInt, left = "invalide scale")
}
final case class SText(value: String) extends SPrimLit
final case class STimestamp(value: Time.Timestamp) extends SPrimLit
final case class SParty(value: Party) extends SPrimLit

View File

@ -1127,8 +1127,8 @@ private[lf] object Speedy {
case SValue.SContractId(_) | SValue.SDate(_) | SValue.SNumeric(_) | SValue.SInt64(_) |
SValue.SParty(_) | SValue.SText(_) | SValue.STimestamp(_) | SValue.SStruct(_, _) |
SValue.SGenMap(_, _) | SValue.SRecord(_, _, _) | SValue.SAny(_, _) | SValue.STypeRep(_) |
SValue.STNat(_) | _: SValue.SPAP | SValue.SToken | SValue.SBuiltinException(_, _) |
SValue.SAnyException(_, _, _) =>
SValue.STNat(_) | SValue.SBigNumeric(_) | _: SValue.SPAP | SValue.SToken |
SValue.SBuiltinException(_, _) | SValue.SAnyException(_, _, _) =>
crash("Match on non-matchable value")
}

View File

@ -47,6 +47,8 @@ object Ordering extends scala.math.Ordering[SValue] {
diff = x compareTo y
case (SNumeric(x), SNumeric(y)) =>
diff = x compareTo y
case (SBigNumeric(x), SBigNumeric(y)) =>
diff = x compareTo y
case (SText(x), SText(y)) =>
diff = Utf8.Ordering.compare(x, y)
case (SDate(x), SDate(y)) =>

View File

@ -205,6 +205,8 @@ class SBuiltinTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChe
}
"Decimal operations" - {
// TODO https://github.com/digital-asset/daml/issues/8719
// Add extensive test for BigNumeric builtins
val maxDecimal = Decimal.MaxValue
@ -364,7 +366,7 @@ class SBuiltinTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChe
}
}
"Decimal binary operations compute proper results" in {
"Numeric binary operations compute proper results" in {
def round(x: BigDecimal) = n(10, x.setScale(10, BigDecimal.RoundingMode.HALF_EVEN))
@ -399,6 +401,71 @@ class SBuiltinTest extends AnyFreeSpec with Matchers with TableDrivenPropertyChe
}
}
"BigNumeric binary operations compute proper results" in {
import java.math.BigDecimal
import scala.math.{BigDecimal => BigDec}
import SBigNumeric.assertFromBigDecimal
val testCases = Table[String, (BigDecimal, BigDecimal) => Option[SValue]](
("builtin", "reference"),
("ADD_BIGNUMERIC", (a, b) => Some(assertFromBigDecimal(a add b))),
("SUB_BIGNUMERIC", (a, b) => Some(assertFromBigDecimal(a subtract b))),
("MUL_BIGNUMERIC ", (a, b) => Some(assertFromBigDecimal(a multiply b))),
(
"DIV_BIGNUMERIC 10 ROUNDING_HALF_EVEN",
{
case (a, b) if b.signum != 0 =>
Some(assertFromBigDecimal(a.divide(b, 10, java.math.RoundingMode.HALF_EVEN)))
case _ => None
},
),
("LESS_EQ @BigNumeric", (a, b) => Some(SBool(BigDec(a) <= BigDec(b)))),
("GREATER_EQ @BigNumeric", (a, b) => Some(SBool(BigDec(a) >= BigDec(b)))),
("LESS @BigNumeric", (a, b) => Some(SBool(BigDec(a) < BigDec(b)))),
("GREATER @BigNumeric", (a, b) => Some(SBool(BigDec(a) > BigDec(b)))),
("EQUAL @BigNumeric", (a, b) => Some(SBool(BigDec(a) == BigDec(b)))),
)
forEvery(testCases) { (builtin, ref) =>
forEvery(decimals) { a =>
forEvery(decimals) { b =>
val actualResult = eval(
e"$builtin (TO_BIGNUMERIC_NUMERIC @10 ${s(10, a)}) (TO_BIGNUMERIC_NUMERIC @10 ${s(10, b)})"
)
val expectedResult = ref(n(10, a), n(10, b))
if (actualResult.toOption != expectedResult)
actualResult shouldBe expectedResult
}
}
}
}
"SHIFT_BIGNUMERIC" - {
import java.math.BigDecimal
import SBigNumeric.assertFromBigDecimal
"returns proper result" in {
val testCases = Table[Int, Int, String, String](
("input scale", "output scale", "input", "output"),
(0, 1, s(0, Numeric.maxValue(0)), s(1, Numeric.maxValue(1))),
(0, 37, tenPowerOf(1, 0), tenPowerOf(-36, 37)),
(20, 10, tenPowerOf(15, 20), tenPowerOf(5, 30)),
(20, -10, tenPowerOf(15, 20), tenPowerOf(25, 10)),
(10, 10, tenPowerOf(-5, 10), tenPowerOf(-15, 20)),
(20, -10, tenPowerOf(-5, 20), tenPowerOf(5, 10)),
(10, 10, tenPowerOf(10, 10), tenPowerOf(0, 20)),
(20, -10, tenPowerOf(10, 20), tenPowerOf(20, 10)),
)
forEvery(testCases) { (inputScale, shifting, input, output) =>
eval(
e"SHIFT_BIGNUMERIC $shifting (TO_BIGNUMERIC_NUMERIC @$inputScale $input)"
) shouldBe Right(assertFromBigDecimal(new BigDecimal(output)))
}
}
}
"TO_TEXT_NUMERIC" - {
"returns proper results" in {
forEvery(decimals) { a =>