Fix try-catch scoping bug in daml-script (#16190)

* _tryCatch sig

* CatchPayload now only continues via 'continue' field

* fmt

* Add test cases

* update docs for CatchPayload

* fix type of CatchPayload.continue

* Add x type parameter to CatchPayload, drop toLedgerValue
This commit is contained in:
Moisés Ackerman 2023-02-06 14:01:40 +01:00 committed by GitHub
parent fc8eadf164
commit 9582832c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 146 additions and 19 deletions

View File

@ -2,6 +2,7 @@
-- SPDX-License-Identifier: Apache-2.0
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE InstanceSigs #-}
module Daml.Script
( Script
@ -148,18 +149,28 @@ data ScriptF a
#ifndef DAML_EXCEPTIONS
deriving Functor
#else
| Catch (CatchPayload a)
| Catch (CatchPayload LedgerValue a)
| Throw ThrowPayload
deriving Functor
-- The slightly odd nesting here is required to preserve scoping.
-- (try x catch f) >>= f should not pull `f` in the try block
data CatchPayload a = CatchPayload
-- `(try x catch f) >>= g` roughly translates to
-- CatchPayload with
-- act = \() -> x
-- handle = f
-- continue = g
data CatchPayload x a = CatchPayload
with
act : () -> Free ScriptF (a, ())
handle : AnyException -> Optional a
-- TODO(MA): simplify to `act : () -> Free ScriptF x`
act : () -> Free ScriptF (Free ScriptF x, ())
handle : AnyException -> Optional (Free ScriptF x)
continue : Free ScriptF x -> a
deriving Functor
-- This wraps a `CatchPayload` whose inner action returns an `x` into an
-- existential - we forget what `x` was.
castCatchPayload : CatchPayload x a -> CatchPayload LedgerValue a
castCatchPayload = error "foobar" -- gets replaced by the identity-function in script/Runner.scala
data ThrowPayload = ThrowPayload
with
exc: AnyException
@ -168,9 +179,14 @@ instance ActionThrow Script where
throw e = lift (Free (Throw (ThrowPayload (toAnyException e))))
instance ActionCatch Script where
_tryCatch act handle = lift $ Free $ Catch $ CatchPayload with
act = \() -> fmap (first pure) $ runScript (act ()) ()
handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e)
_tryCatch : forall t. (() -> Script t) -> (AnyException -> Optional (Script t)) -> Script t
_tryCatch act handle = lift $ Free $ Catch $ castCatchPayload payload
where
payload : CatchPayload t (Free ScriptF t)
payload = CatchPayload with
act = \() -> fmap (first pure) $ runScript (act ()) ()
handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e)
continue = identity
#endif
data QueryACS a = QueryACS

View File

@ -393,23 +393,27 @@ private[lf] class Runner(
timeMode: ScriptTimeMode,
) extends StrictLogging {
// We overwrite the definition of fromLedgerValue with an identity function.
// We overwrite the definition of 'fromLedgerValue' with an identity function.
// This is a type error but Speedy doesnt care about the types and the only thing we do
// with the result is convert it to ledger values/record so this is safe.
// We do the same substitution for 'castCatchPayload' to circumvent Daml's
// lack of existential types.
private val extendedCompiledPackages = {
val fromLedgerValue: PartialFunction[SDefinitionRef, SDefinition] = {
val damlScriptDefs: PartialFunction[SDefinitionRef, SDefinition] = {
case LfDefRef(id) if id == script.scriptIds.damlScript("fromLedgerValue") =>
SDefinition(SEMakeClo(Array(), 1, SELocA(0)))
case LfDefRef(id) if id == script.scriptIds.damlScript("castCatchPayload") =>
SDefinition(SEMakeClo(Array(), 1, SELocA(0)))
}
new CompiledPackages(Runner.compilerConfig) {
override def getDefinition(dref: SDefinitionRef): Option[SDefinition] =
fromLedgerValue.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition)
damlScriptDefs.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition)
// FIXME: avoid override of non abstract method
override def pkgInterface: PackageInterface = compiledPackages.pkgInterface
override def packageIds: collection.Set[PackageId] = compiledPackages.packageIds
// FIXME: avoid override of non abstract method
override def definitions: PartialFunction[SDefinitionRef, SDefinition] =
fromLedgerValue.orElse(compiledPackages.definitions)
damlScriptDefs.orElse(compiledPackages.definitions)
}
}
@ -477,9 +481,9 @@ private[lf] class Runner(
)
.flatMap { scriptF =>
scriptF match {
case ScriptF.Catch(act, handle) =>
case ScriptF.Catch(act, handle, continue) =>
run(SEAppAtomic(SEValue(act), Array(SEValue(SUnit)))).transformWith {
case Success(v) => Future.successful(SEValue(v))
case Success(v) => Future.successful(SEApp(SEValue(continue), Array(v)))
case Failure(
exce @ Runner.InterpretationError(
SError.SErrorDamlException(IE.UnhandledException(typ, value))
@ -499,7 +503,8 @@ private[lf] class Runner(
.flatMap {
case SOptional(None) =>
Future.failed(exce)
case SOptional(Some(free)) => Future.successful(SEValue(free))
case SOptional(Some(free)) =>
Future.successful(SEApp(SEValue(continue), Array(free)))
case e =>
Future.failed(
new ConverterException(s"Expected SOptional but got $e")

View File

@ -42,7 +42,7 @@ object ScriptF {
cause,
)
final case class Catch(act: SValue, handle: SValue) extends ScriptF
final case class Catch(act: SValue, handle: SValue, continue: SValue) extends ScriptF
final case class Throw(exc: SAny) extends ScriptF
sealed trait Cmd extends ScriptF {
@ -946,8 +946,8 @@ object ScriptF {
private def parseCatch(v: SValue): Either[String, Catch] = {
v match {
case SRecord(_, _, ArrayList(act, handle)) =>
Right(Catch(act, handle))
case SRecord(_, _, ArrayList(act, handle, continue)) =>
Right(Catch(act, handle, continue))
case _ => Left(s"Expected Catch payload but got $v")
}

View File

@ -70,3 +70,40 @@ test = script do
catch
Test2 _ -> pure 42
exc === 42
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
try_catch_then_error : Script ()
try_catch_then_error = do
{- 1 -} wasThrown <-
{- 2 -} try do pure False
{- 3 -} catch (_ : AnyException) -> pure True
{- 4 -} if wasThrown
{- 5 -} then pure ()
{- 6 -} else error "expected exception"
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
try_catch_then_fail : Script ()
try_catch_then_fail = do
{- 1 -} wasThrown <-
{- 2 -} try do pure False
{- 3 -} catch (_ : AnyException) -> pure True
{- 4 -} if wasThrown
{- 5 -} then pure ()
{- 6 -} else fail "expected exception"
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
try_catch_then_abort : Script ()
try_catch_then_abort = do
{- 1 -} wasThrown <-
{- 2 -} try do pure False
{- 3 -} catch (_ : AnyException) -> pure True
{- 4 -} if wasThrown
{- 5 -} then pure ()
{- 6 -} else abort "expected exception"
try_catch_recover : Script ()
try_catch_recover = do
x <-
try do error "uh-oh"
catch (e : AnyException) -> pure (message e <> "!")
x === "uh-oh!"

View File

@ -7,6 +7,7 @@ import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
import com.daml.lf.data.Ref._
import com.daml.lf.data.{FrontStack, FrontStackCons, Numeric}
import com.daml.lf.engine.script.{ScriptF, StackTrace}
import com.daml.lf.engine.script.Runner.InterpretationError
import com.daml.lf.speedy.SValue
import com.daml.lf.speedy.SValue._
import io.grpc.{Status, StatusRuntimeException}
@ -331,6 +332,74 @@ abstract class AbstractFuncIT
}
}
}
"Exceptions:try_catch_then_error" should {
"fail" in {
for {
clients <- participantClients()
exception <- recoverToExceptionIf[InterpretationError](
run(
clients,
QualifiedName.assertFromString("TestExceptions:try_catch_then_error"),
dar = devDar,
)
).map(_.toString)
} yield {
exception should include("Unhandled Daml exception")
exception should include("GeneralError")
exception should include("expected exception")
}
}
}
"Exceptions:try_catch_then_fail" should {
"fail" in {
for {
clients <- participantClients()
exception <- recoverToExceptionIf[InterpretationError](
run(
clients,
QualifiedName.assertFromString("TestExceptions:try_catch_then_fail"),
dar = devDar,
)
).map(_.toString)
} yield {
exception should include("Unhandled Daml exception")
exception should include("GeneralError")
exception should include("expected exception")
}
}
}
"Exceptions:try_catch_then_abort" should {
"fail" in {
for {
clients <- participantClients()
exception <- recoverToExceptionIf[InterpretationError](
run(
clients,
QualifiedName.assertFromString("TestExceptions:try_catch_then_abort"),
dar = devDar,
)
).map(_.toString)
} yield {
exception should include("Unhandled Daml exception")
exception should include("GeneralError")
exception should include("expected exception")
}
}
}
"Exceptions:try_catch_recover" should {
"succeed" in {
for {
clients <- participantClients()
v <- run(
clients,
QualifiedName.assertFromString("TestExceptions:try_catch_recover"),
dar = devDar,
)
} yield {
v shouldBe (SUnit)
}
}
}
"Interface:test_queryInterface" should {
"succeed" in {
for {