diff --git a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SBuiltin.scala b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SBuiltin.scala index e64463a866a..e683366f87e 100644 --- a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SBuiltin.scala +++ b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SBuiltin.scala @@ -1395,11 +1395,28 @@ private[lf] object SBuiltin { checkToken(token) args.get(0) match { case SOptional(opt) => - opt match { - case None => - unwindToHandler(machine, payload) //re-throw - case Some(handler) => - machine.enterApplication(handler, Array(SEValue(token))) + machine.withOnLedger("SBTryHandler") { onLedger => + opt match { + case None => + onLedger.ptx = onLedger.ptx.abortTry + unwindToHandler(machine, payload) //re-throw + case Some(handler) => + payload match { + case SAnyException(typ, _, sv) => + // TODO https://github.com/digital-asset/daml/issues/8020 + // Must convert speedy value to LF value, but currently this crashes on SBuiltinException + // so make a hack workaround: + val value = + sv match { + case _: SBuiltinException => V.ValueInt64(999) + case _ => sv.toValue + } + onLedger.ptx = onLedger.ptx.rollbackTry(typ, value) + case _ => + crash(s"SBTryHandler, expected payload to be SAnyException: $payload") + } + machine.enterApplication(handler, Array(SEValue(token))) + } } case v => crash(s"invalid argument to SBTryHandler (expected SOptional): $v") diff --git a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SExpr.scala b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SExpr.scala index 4eb5236cf29..37f52475b29 100644 --- a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SExpr.scala +++ b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/SExpr.scala @@ -378,6 +378,9 @@ object SExpr { def execute(machine: Machine): Unit = { machine.pushKont(KTryCatchHandler(machine, handler)) machine.ctrl = body + machine.withOnLedger("SETryCatch") { onLedger => + onLedger.ptx = onLedger.ptx.beginTry + } } } diff --git a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Speedy.scala b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Speedy.scala index 873882d2db5..147a33438f6 100644 --- a/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Speedy.scala +++ b/daml-lf/interpreter/src/main/scala/com/digitalasset/daml/lf/speedy/Speedy.scala @@ -169,7 +169,7 @@ private[lf] object Speedy { private[lf] def withOnLedger[T](op: String)(f: OnLedger => T): T = ledgerMode match { - case onLedger @ OnLedger(_, _, _, _, _, _, _, _) => f(onLedger) + case onLedger: OnLedger => f(onLedger) case OffLedger => throw SRequiresOnLedger(op) } @@ -816,6 +816,26 @@ private[lf] object Speedy { ) } + @throws[PackageNotFound] + @throws[CompilationError] + // Construct a machine for running an update expression (testing -- avoiding scenarios) + def fromUpdateExpr( + compiledPackages: CompiledPackages, + transactionSeed: crypto.Hash, + updateE: Expr, + committer: Party, + ): Machine = { + val updateSE: SExpr = compiledPackages.compiler.unsafeCompile(updateE) + Machine( + compiledPackages = compiledPackages, + submissionTime = Time.Timestamp.MinValue, + initialSeeding = InitialSeeding.TransactionSeed(transactionSeed), + expr = SEApp(updateSE, Array(SEValue.Token)), + globalCids = Set.empty, + committers = Set(committer), + ) + } + @throws[PackageNotFound] @throws[CompilationError] // Construct a machine for running scenario. @@ -1298,6 +1318,9 @@ private[lf] object Speedy { def execute(v: SValue) = { restore() + machine.withOnLedger("KTryCatchHandler") { onLedger => + onLedger.ptx = onLedger.ptx.endTry + } machine.returnValue = v } } diff --git a/daml-lf/interpreter/src/test/scala/com/digitalasset/daml/lf/speedy/RollbackTest.scala b/daml-lf/interpreter/src/test/scala/com/digitalasset/daml/lf/speedy/RollbackTest.scala new file mode 100644 index 00000000000..fd89cc19244 --- /dev/null +++ b/daml-lf/interpreter/src/test/scala/com/digitalasset/daml/lf/speedy/RollbackTest.scala @@ -0,0 +1,187 @@ +// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.lf +package speedy + +import com.daml.lf.PureCompiledPackages +import com.daml.lf.data.ImmArray +import com.daml.lf.data.Ref.Party +import com.daml.lf.language.Ast._ +import com.daml.lf.speedy.Compiler.FullStackTrace +import com.daml.lf.speedy.PartialTransaction._ +import com.daml.lf.speedy.SResult._ +import com.daml.lf.testing.parser.Implicits._ +import com.daml.lf.transaction.Node +import com.daml.lf.transaction.SubmittedTransaction +import com.daml.lf.validation.Validation +import com.daml.lf.value.Value +import com.daml.lf.value.Value._ + +import org.scalatest.matchers.should.Matchers +import org.scalatest.prop.TableDrivenPropertyChecks +import org.scalatest.wordspec.AnyWordSpec + +class ExceptionTest extends AnyWordSpec with Matchers with TableDrivenPropertyChecks { + + private def typeAndCompile(pkg: Package): PureCompiledPackages = { + val rawPkgs = Map(defaultParserParameters.defaultPackageId -> pkg) + Validation.checkPackage(rawPkgs, defaultParserParameters.defaultPackageId, pkg) + data.assertRight( + PureCompiledPackages(rawPkgs, Compiler.Config.Default.copy(stacktracing = FullStackTrace)) + ) + } + + private def runUpdateExprGetTx( + pkgs1: PureCompiledPackages + )(e: Expr, party: Party): SubmittedTransaction = { + def transactionSeed: crypto.Hash = crypto.Hash.hashPrivateKey("RollbackTest.scala") + val machine = Speedy.Machine.fromUpdateExpr(pkgs1, transactionSeed, e, party) + val res = machine.run() + res match { + case _: SResultFinalValue => + machine.withOnLedger("RollbackTest") { onLedger => + onLedger.ptx.finish match { + case IncompleteTransaction(_) => + sys.error("unexpected IncompleteTransaction") + case CompleteTransaction(tx) => + tx + } + } + case _ => + sys.error(s"unexpected res: $res") + } + } + + val pkgs: PureCompiledPackages = typeAndCompile(p""" + module M { + + record @serializable MyException = { message: Text } ; + + exception MyException = { + message \(e: M:MyException) -> M:MyException {message} e + }; + + record @serializable T1 = { party: Party, info: Int64 } ; + template (record : T1) = { + precondition True, + signatories Cons @Party [(M:T1 {party} record)] (Nil @Party), + observers Nil @Party, + agreement "Agreement", + choices { + } + }; + + val create0 : Party -> Update Unit = \(party: Party) -> + upure @Unit (); + + val create1 : Party -> Update Unit = \(party: Party) -> + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 } + in upure @Unit (); + + val create2 : Party -> Update Unit = \(party: Party) -> + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 } + in upure @Unit (); + + val create3 : Party -> Update Unit = \(party: Party) -> + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 }; + x3: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 300 } + in upure @Unit (); + + val create3nested : Party -> Update Unit = \(party: Party) -> + ubind + u1: Unit <- + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 } + in upure @Unit (); + x3: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 300 } + in upure @Unit (); + + val create3catchNoThrow : Party -> Update Unit = \(party: Party) -> + ubind + u1: Unit <- + try @Unit + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 } + in upure @Unit () + catch e -> Some @(Update Unit) (upure @Unit ()) + ; + x3: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 300 } + in upure @Unit (); + + val create3throwAndCatch : Party -> Update Unit = \(party: Party) -> + ubind + u1: Unit <- + try @Unit + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 } + in throw @(Update Unit) @M:MyException (M:MyException {message = "oops"}) + catch e -> Some @(Update Unit) (upure @Unit ()) + ; + x3: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 300 } + in upure @Unit (); + + val create3throwAndOuterCatch : Party -> Update Unit = \(party: Party) -> + ubind + u1: Unit <- + try @Unit + try @Unit + ubind + x1: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 100 }; + x2: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 200 } + in throw @(Update Unit) @M:MyException (M:MyException {message = "oops"}) + catch e -> None @(Update Unit) + catch e -> Some @(Update Unit) (upure @Unit ()) + ; + x3: ContractId M:T1 <- create @M:T1 M:T1 { party = party, info = 300 } + in upure @Unit (); + + } + """) + + val testCases = Table[String, List[Long]]( + ("expression", "expected-number-of-contracts"), + ("create0", Nil), + ("create1", List(100)), + ("create2", List(100, 200)), + ("create3", List(100, 200, 300)), + ("create3nested", List(100, 200, 300)), + ("create3catchNoThrow", List(100, 200, 300)), + ("create3throwAndCatch", List(300)), + ("create3throwAndOuterCatch", List(300)), + ) + + forEvery(testCases) { (exp: String, expected: List[Long]) => + s"""$exp, contracts expected: $expected """ in { + val party = Party.assertFromString("Alice") + val lit: PrimLit = PLParty(party) + val arg: Expr = EPrimLit(lit) + val example: Expr = EApp(e"M:$exp", arg) + val tx: SubmittedTransaction = runUpdateExprGetTx(pkgs)(example, party) + val ids: Seq[Long] = contractValuesInOrder(tx) + ids shouldBe expected + } + } + + private def contractValuesInOrder(tx: SubmittedTransaction): Seq[Long] = { + tx.fold(Vector.empty[Long]) { + case (acc, (_, create: Node.NodeCreate[Value.ContractId])) => + create.arg match { + case ValueRecord(_, ImmArray(_, (Some("info"), ValueInt64(n)))) => + acc :+ n + case _ => + sys.error(s"unexpected create.arg: ${create.arg}") + } + case (acc, _) => acc + } + } + +}