From 9582832c2cb8465b2190103cc6945cd5d70dee39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mois=C3=A9s=20Ackerman?= <6054733+akrmn@users.noreply.github.com> Date: Mon, 6 Feb 2023 14:01:40 +0100 Subject: [PATCH] Fix try-catch scoping bug in `daml-script` (#16190) * _tryCatch sig * CatchPayload now only continues via 'continue' field * fmt * Add test cases * update docs for CatchPayload * fix type of CatchPayload.continue * Add x type parameter to CatchPayload, drop toLedgerValue --- daml-script/daml/Daml/Script.daml | 34 ++++++--- .../daml/lf/engine/script/Runner.scala | 19 +++-- .../daml/lf/engine/script/ScriptF.scala | 6 +- daml-script/test/daml/TestExceptions.daml | 37 ++++++++++ .../engine/script/test/AbstractFuncIT.scala | 69 +++++++++++++++++++ 5 files changed, 146 insertions(+), 19 deletions(-) diff --git a/daml-script/daml/Daml/Script.daml b/daml-script/daml/Daml/Script.daml index 157008b9c2..dc5904c43d 100644 --- a/daml-script/daml/Daml/Script.daml +++ b/daml-script/daml/Daml/Script.daml @@ -2,6 +2,7 @@ -- SPDX-License-Identifier: Apache-2.0 {-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE CPP #-} +{-# LANGUAGE InstanceSigs #-} module Daml.Script ( Script @@ -148,18 +149,28 @@ data ScriptF a #ifndef DAML_EXCEPTIONS deriving Functor #else - | Catch (CatchPayload a) + | Catch (CatchPayload LedgerValue a) | Throw ThrowPayload deriving Functor --- The slightly odd nesting here is required to preserve scoping. --- (try x catch f) >>= f should not pull `f` in the try block -data CatchPayload a = CatchPayload +-- `(try x catch f) >>= g` roughly translates to +-- CatchPayload with +-- act = \() -> x +-- handle = f +-- continue = g +data CatchPayload x a = CatchPayload with - act : () -> Free ScriptF (a, ()) - handle : AnyException -> Optional a + -- TODO(MA): simplify to `act : () -> Free ScriptF x` + act : () -> Free ScriptF (Free ScriptF x, ()) + handle : AnyException -> Optional (Free ScriptF x) + continue : Free ScriptF x -> a deriving Functor +-- This wraps a `CatchPayload` whose inner action returns an `x` into an +-- existential - we forget what `x` was. +castCatchPayload : CatchPayload x a -> CatchPayload LedgerValue a +castCatchPayload = error "foobar" -- gets replaced by the identity-function in script/Runner.scala + data ThrowPayload = ThrowPayload with exc: AnyException @@ -168,9 +179,14 @@ instance ActionThrow Script where throw e = lift (Free (Throw (ThrowPayload (toAnyException e)))) instance ActionCatch Script where - _tryCatch act handle = lift $ Free $ Catch $ CatchPayload with - act = \() -> fmap (first pure) $ runScript (act ()) () - handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e) + _tryCatch : forall t. (() -> Script t) -> (AnyException -> Optional (Script t)) -> Script t + _tryCatch act handle = lift $ Free $ Catch $ castCatchPayload payload + where + payload : CatchPayload t (Free ScriptF t) + payload = CatchPayload with + act = \() -> fmap (first pure) $ runScript (act ()) () + handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e) + continue = identity #endif data QueryACS a = QueryACS diff --git a/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/Runner.scala b/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/Runner.scala index 5ae02c28c8..a1d3579ed0 100644 --- a/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/Runner.scala +++ b/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/Runner.scala @@ -393,23 +393,27 @@ private[lf] class Runner( timeMode: ScriptTimeMode, ) extends StrictLogging { - // We overwrite the definition of fromLedgerValue with an identity function. + // We overwrite the definition of 'fromLedgerValue' with an identity function. // This is a type error but Speedy doesn’t care about the types and the only thing we do // with the result is convert it to ledger values/record so this is safe. + // We do the same substitution for 'castCatchPayload' to circumvent Daml's + // lack of existential types. private val extendedCompiledPackages = { - val fromLedgerValue: PartialFunction[SDefinitionRef, SDefinition] = { + val damlScriptDefs: PartialFunction[SDefinitionRef, SDefinition] = { case LfDefRef(id) if id == script.scriptIds.damlScript("fromLedgerValue") => SDefinition(SEMakeClo(Array(), 1, SELocA(0))) + case LfDefRef(id) if id == script.scriptIds.damlScript("castCatchPayload") => + SDefinition(SEMakeClo(Array(), 1, SELocA(0))) } new CompiledPackages(Runner.compilerConfig) { override def getDefinition(dref: SDefinitionRef): Option[SDefinition] = - fromLedgerValue.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition) + damlScriptDefs.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition) // FIXME: avoid override of non abstract method override def pkgInterface: PackageInterface = compiledPackages.pkgInterface override def packageIds: collection.Set[PackageId] = compiledPackages.packageIds // FIXME: avoid override of non abstract method override def definitions: PartialFunction[SDefinitionRef, SDefinition] = - fromLedgerValue.orElse(compiledPackages.definitions) + damlScriptDefs.orElse(compiledPackages.definitions) } } @@ -477,9 +481,9 @@ private[lf] class Runner( ) .flatMap { scriptF => scriptF match { - case ScriptF.Catch(act, handle) => + case ScriptF.Catch(act, handle, continue) => run(SEAppAtomic(SEValue(act), Array(SEValue(SUnit)))).transformWith { - case Success(v) => Future.successful(SEValue(v)) + case Success(v) => Future.successful(SEApp(SEValue(continue), Array(v))) case Failure( exce @ Runner.InterpretationError( SError.SErrorDamlException(IE.UnhandledException(typ, value)) @@ -499,7 +503,8 @@ private[lf] class Runner( .flatMap { case SOptional(None) => Future.failed(exce) - case SOptional(Some(free)) => Future.successful(SEValue(free)) + case SOptional(Some(free)) => + Future.successful(SEApp(SEValue(continue), Array(free))) case e => Future.failed( new ConverterException(s"Expected SOptional but got $e") diff --git a/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/ScriptF.scala b/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/ScriptF.scala index 68929f9480..12364e77e1 100644 --- a/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/ScriptF.scala +++ b/daml-script/runner/src/main/scala/com/digitalasset/daml/lf/engine/script/ScriptF.scala @@ -42,7 +42,7 @@ object ScriptF { cause, ) - final case class Catch(act: SValue, handle: SValue) extends ScriptF + final case class Catch(act: SValue, handle: SValue, continue: SValue) extends ScriptF final case class Throw(exc: SAny) extends ScriptF sealed trait Cmd extends ScriptF { @@ -946,8 +946,8 @@ object ScriptF { private def parseCatch(v: SValue): Either[String, Catch] = { v match { - case SRecord(_, _, ArrayList(act, handle)) => - Right(Catch(act, handle)) + case SRecord(_, _, ArrayList(act, handle, continue)) => + Right(Catch(act, handle, continue)) case _ => Left(s"Expected Catch payload but got $v") } diff --git a/daml-script/test/daml/TestExceptions.daml b/daml-script/test/daml/TestExceptions.daml index 85f6ded2f5..3903e73416 100644 --- a/daml-script/test/daml/TestExceptions.daml +++ b/daml-script/test/daml/TestExceptions.daml @@ -70,3 +70,40 @@ test = script do catch Test2 _ -> pure 42 exc === 42 + +-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132) +try_catch_then_error : Script () +try_catch_then_error = do + {- 1 -} wasThrown <- + {- 2 -} try do pure False + {- 3 -} catch (_ : AnyException) -> pure True + {- 4 -} if wasThrown + {- 5 -} then pure () + {- 6 -} else error "expected exception" + +-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132) +try_catch_then_fail : Script () +try_catch_then_fail = do + {- 1 -} wasThrown <- + {- 2 -} try do pure False + {- 3 -} catch (_ : AnyException) -> pure True + {- 4 -} if wasThrown + {- 5 -} then pure () + {- 6 -} else fail "expected exception" + +-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132) +try_catch_then_abort : Script () +try_catch_then_abort = do + {- 1 -} wasThrown <- + {- 2 -} try do pure False + {- 3 -} catch (_ : AnyException) -> pure True + {- 4 -} if wasThrown + {- 5 -} then pure () + {- 6 -} else abort "expected exception" + +try_catch_recover : Script () +try_catch_recover = do + x <- + try do error "uh-oh" + catch (e : AnyException) -> pure (message e <> "!") + x === "uh-oh!" diff --git a/daml-script/test/src/test-utils/com/daml/lf/engine/script/test/AbstractFuncIT.scala b/daml-script/test/src/test-utils/com/daml/lf/engine/script/test/AbstractFuncIT.scala index 61db003d32..4289e8134c 100644 --- a/daml-script/test/src/test-utils/com/daml/lf/engine/script/test/AbstractFuncIT.scala +++ b/daml-script/test/src/test-utils/com/daml/lf/engine/script/test/AbstractFuncIT.scala @@ -7,6 +7,7 @@ import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll import com.daml.lf.data.Ref._ import com.daml.lf.data.{FrontStack, FrontStackCons, Numeric} import com.daml.lf.engine.script.{ScriptF, StackTrace} +import com.daml.lf.engine.script.Runner.InterpretationError import com.daml.lf.speedy.SValue import com.daml.lf.speedy.SValue._ import io.grpc.{Status, StatusRuntimeException} @@ -331,6 +332,74 @@ abstract class AbstractFuncIT } } } + "Exceptions:try_catch_then_error" should { + "fail" in { + for { + clients <- participantClients() + exception <- recoverToExceptionIf[InterpretationError]( + run( + clients, + QualifiedName.assertFromString("TestExceptions:try_catch_then_error"), + dar = devDar, + ) + ).map(_.toString) + } yield { + exception should include("Unhandled Daml exception") + exception should include("GeneralError") + exception should include("expected exception") + } + } + } + "Exceptions:try_catch_then_fail" should { + "fail" in { + for { + clients <- participantClients() + exception <- recoverToExceptionIf[InterpretationError]( + run( + clients, + QualifiedName.assertFromString("TestExceptions:try_catch_then_fail"), + dar = devDar, + ) + ).map(_.toString) + } yield { + exception should include("Unhandled Daml exception") + exception should include("GeneralError") + exception should include("expected exception") + } + } + } + "Exceptions:try_catch_then_abort" should { + "fail" in { + for { + clients <- participantClients() + exception <- recoverToExceptionIf[InterpretationError]( + run( + clients, + QualifiedName.assertFromString("TestExceptions:try_catch_then_abort"), + dar = devDar, + ) + ).map(_.toString) + } yield { + exception should include("Unhandled Daml exception") + exception should include("GeneralError") + exception should include("expected exception") + } + } + } + "Exceptions:try_catch_recover" should { + "succeed" in { + for { + clients <- participantClients() + v <- run( + clients, + QualifiedName.assertFromString("TestExceptions:try_catch_recover"), + dar = devDar, + ) + } yield { + v shouldBe (SUnit) + } + } + } "Interface:test_queryInterface" should { "succeed" in { for {