mirror of
https://github.com/digital-asset/daml.git
synced 2024-09-20 09:17:43 +03:00
Fix try-catch scoping bug in daml-script
(#16190)
* _tryCatch sig * CatchPayload now only continues via 'continue' field * fmt * Add test cases * update docs for CatchPayload * fix type of CatchPayload.continue * Add x type parameter to CatchPayload, drop toLedgerValue
This commit is contained in:
parent
fc8eadf164
commit
9582832c2c
@ -2,6 +2,7 @@
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
{-# LANGUAGE AllowAmbiguousTypes #-}
|
||||
{-# LANGUAGE CPP #-}
|
||||
{-# LANGUAGE InstanceSigs #-}
|
||||
|
||||
module Daml.Script
|
||||
( Script
|
||||
@ -148,18 +149,28 @@ data ScriptF a
|
||||
#ifndef DAML_EXCEPTIONS
|
||||
deriving Functor
|
||||
#else
|
||||
| Catch (CatchPayload a)
|
||||
| Catch (CatchPayload LedgerValue a)
|
||||
| Throw ThrowPayload
|
||||
deriving Functor
|
||||
|
||||
-- The slightly odd nesting here is required to preserve scoping.
|
||||
-- (try x catch f) >>= f should not pull `f` in the try block
|
||||
data CatchPayload a = CatchPayload
|
||||
-- `(try x catch f) >>= g` roughly translates to
|
||||
-- CatchPayload with
|
||||
-- act = \() -> x
|
||||
-- handle = f
|
||||
-- continue = g
|
||||
data CatchPayload x a = CatchPayload
|
||||
with
|
||||
act : () -> Free ScriptF (a, ())
|
||||
handle : AnyException -> Optional a
|
||||
-- TODO(MA): simplify to `act : () -> Free ScriptF x`
|
||||
act : () -> Free ScriptF (Free ScriptF x, ())
|
||||
handle : AnyException -> Optional (Free ScriptF x)
|
||||
continue : Free ScriptF x -> a
|
||||
deriving Functor
|
||||
|
||||
-- This wraps a `CatchPayload` whose inner action returns an `x` into an
|
||||
-- existential - we forget what `x` was.
|
||||
castCatchPayload : CatchPayload x a -> CatchPayload LedgerValue a
|
||||
castCatchPayload = error "foobar" -- gets replaced by the identity-function in script/Runner.scala
|
||||
|
||||
data ThrowPayload = ThrowPayload
|
||||
with
|
||||
exc: AnyException
|
||||
@ -168,9 +179,14 @@ instance ActionThrow Script where
|
||||
throw e = lift (Free (Throw (ThrowPayload (toAnyException e))))
|
||||
|
||||
instance ActionCatch Script where
|
||||
_tryCatch act handle = lift $ Free $ Catch $ CatchPayload with
|
||||
act = \() -> fmap (first pure) $ runScript (act ()) ()
|
||||
handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e)
|
||||
_tryCatch : forall t. (() -> Script t) -> (AnyException -> Optional (Script t)) -> Script t
|
||||
_tryCatch act handle = lift $ Free $ Catch $ castCatchPayload payload
|
||||
where
|
||||
payload : CatchPayload t (Free ScriptF t)
|
||||
payload = CatchPayload with
|
||||
act = \() -> fmap (first pure) $ runScript (act ()) ()
|
||||
handle = \e -> fmap (\s -> fmap fst $ runScript s ()) (handle e)
|
||||
continue = identity
|
||||
#endif
|
||||
|
||||
data QueryACS a = QueryACS
|
||||
|
@ -393,23 +393,27 @@ private[lf] class Runner(
|
||||
timeMode: ScriptTimeMode,
|
||||
) extends StrictLogging {
|
||||
|
||||
// We overwrite the definition of fromLedgerValue with an identity function.
|
||||
// We overwrite the definition of 'fromLedgerValue' with an identity function.
|
||||
// This is a type error but Speedy doesn’t care about the types and the only thing we do
|
||||
// with the result is convert it to ledger values/record so this is safe.
|
||||
// We do the same substitution for 'castCatchPayload' to circumvent Daml's
|
||||
// lack of existential types.
|
||||
private val extendedCompiledPackages = {
|
||||
val fromLedgerValue: PartialFunction[SDefinitionRef, SDefinition] = {
|
||||
val damlScriptDefs: PartialFunction[SDefinitionRef, SDefinition] = {
|
||||
case LfDefRef(id) if id == script.scriptIds.damlScript("fromLedgerValue") =>
|
||||
SDefinition(SEMakeClo(Array(), 1, SELocA(0)))
|
||||
case LfDefRef(id) if id == script.scriptIds.damlScript("castCatchPayload") =>
|
||||
SDefinition(SEMakeClo(Array(), 1, SELocA(0)))
|
||||
}
|
||||
new CompiledPackages(Runner.compilerConfig) {
|
||||
override def getDefinition(dref: SDefinitionRef): Option[SDefinition] =
|
||||
fromLedgerValue.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition)
|
||||
damlScriptDefs.andThen(Some(_)).applyOrElse(dref, compiledPackages.getDefinition)
|
||||
// FIXME: avoid override of non abstract method
|
||||
override def pkgInterface: PackageInterface = compiledPackages.pkgInterface
|
||||
override def packageIds: collection.Set[PackageId] = compiledPackages.packageIds
|
||||
// FIXME: avoid override of non abstract method
|
||||
override def definitions: PartialFunction[SDefinitionRef, SDefinition] =
|
||||
fromLedgerValue.orElse(compiledPackages.definitions)
|
||||
damlScriptDefs.orElse(compiledPackages.definitions)
|
||||
}
|
||||
}
|
||||
|
||||
@ -477,9 +481,9 @@ private[lf] class Runner(
|
||||
)
|
||||
.flatMap { scriptF =>
|
||||
scriptF match {
|
||||
case ScriptF.Catch(act, handle) =>
|
||||
case ScriptF.Catch(act, handle, continue) =>
|
||||
run(SEAppAtomic(SEValue(act), Array(SEValue(SUnit)))).transformWith {
|
||||
case Success(v) => Future.successful(SEValue(v))
|
||||
case Success(v) => Future.successful(SEApp(SEValue(continue), Array(v)))
|
||||
case Failure(
|
||||
exce @ Runner.InterpretationError(
|
||||
SError.SErrorDamlException(IE.UnhandledException(typ, value))
|
||||
@ -499,7 +503,8 @@ private[lf] class Runner(
|
||||
.flatMap {
|
||||
case SOptional(None) =>
|
||||
Future.failed(exce)
|
||||
case SOptional(Some(free)) => Future.successful(SEValue(free))
|
||||
case SOptional(Some(free)) =>
|
||||
Future.successful(SEApp(SEValue(continue), Array(free)))
|
||||
case e =>
|
||||
Future.failed(
|
||||
new ConverterException(s"Expected SOptional but got $e")
|
||||
|
@ -42,7 +42,7 @@ object ScriptF {
|
||||
cause,
|
||||
)
|
||||
|
||||
final case class Catch(act: SValue, handle: SValue) extends ScriptF
|
||||
final case class Catch(act: SValue, handle: SValue, continue: SValue) extends ScriptF
|
||||
final case class Throw(exc: SAny) extends ScriptF
|
||||
|
||||
sealed trait Cmd extends ScriptF {
|
||||
@ -946,8 +946,8 @@ object ScriptF {
|
||||
|
||||
private def parseCatch(v: SValue): Either[String, Catch] = {
|
||||
v match {
|
||||
case SRecord(_, _, ArrayList(act, handle)) =>
|
||||
Right(Catch(act, handle))
|
||||
case SRecord(_, _, ArrayList(act, handle, continue)) =>
|
||||
Right(Catch(act, handle, continue))
|
||||
case _ => Left(s"Expected Catch payload but got $v")
|
||||
}
|
||||
|
||||
|
@ -70,3 +70,40 @@ test = script do
|
||||
catch
|
||||
Test2 _ -> pure 42
|
||||
exc === 42
|
||||
|
||||
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
|
||||
try_catch_then_error : Script ()
|
||||
try_catch_then_error = do
|
||||
{- 1 -} wasThrown <-
|
||||
{- 2 -} try do pure False
|
||||
{- 3 -} catch (_ : AnyException) -> pure True
|
||||
{- 4 -} if wasThrown
|
||||
{- 5 -} then pure ()
|
||||
{- 6 -} else error "expected exception"
|
||||
|
||||
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
|
||||
try_catch_then_fail : Script ()
|
||||
try_catch_then_fail = do
|
||||
{- 1 -} wasThrown <-
|
||||
{- 2 -} try do pure False
|
||||
{- 3 -} catch (_ : AnyException) -> pure True
|
||||
{- 4 -} if wasThrown
|
||||
{- 5 -} then pure ()
|
||||
{- 6 -} else fail "expected exception"
|
||||
|
||||
-- tests that the error from {- 6 -} doesn't get caught in {- 3 -} (#16132)
|
||||
try_catch_then_abort : Script ()
|
||||
try_catch_then_abort = do
|
||||
{- 1 -} wasThrown <-
|
||||
{- 2 -} try do pure False
|
||||
{- 3 -} catch (_ : AnyException) -> pure True
|
||||
{- 4 -} if wasThrown
|
||||
{- 5 -} then pure ()
|
||||
{- 6 -} else abort "expected exception"
|
||||
|
||||
try_catch_recover : Script ()
|
||||
try_catch_recover = do
|
||||
x <-
|
||||
try do error "uh-oh"
|
||||
catch (e : AnyException) -> pure (message e <> "!")
|
||||
x === "uh-oh!"
|
||||
|
@ -7,6 +7,7 @@ import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
|
||||
import com.daml.lf.data.Ref._
|
||||
import com.daml.lf.data.{FrontStack, FrontStackCons, Numeric}
|
||||
import com.daml.lf.engine.script.{ScriptF, StackTrace}
|
||||
import com.daml.lf.engine.script.Runner.InterpretationError
|
||||
import com.daml.lf.speedy.SValue
|
||||
import com.daml.lf.speedy.SValue._
|
||||
import io.grpc.{Status, StatusRuntimeException}
|
||||
@ -331,6 +332,74 @@ abstract class AbstractFuncIT
|
||||
}
|
||||
}
|
||||
}
|
||||
"Exceptions:try_catch_then_error" should {
|
||||
"fail" in {
|
||||
for {
|
||||
clients <- participantClients()
|
||||
exception <- recoverToExceptionIf[InterpretationError](
|
||||
run(
|
||||
clients,
|
||||
QualifiedName.assertFromString("TestExceptions:try_catch_then_error"),
|
||||
dar = devDar,
|
||||
)
|
||||
).map(_.toString)
|
||||
} yield {
|
||||
exception should include("Unhandled Daml exception")
|
||||
exception should include("GeneralError")
|
||||
exception should include("expected exception")
|
||||
}
|
||||
}
|
||||
}
|
||||
"Exceptions:try_catch_then_fail" should {
|
||||
"fail" in {
|
||||
for {
|
||||
clients <- participantClients()
|
||||
exception <- recoverToExceptionIf[InterpretationError](
|
||||
run(
|
||||
clients,
|
||||
QualifiedName.assertFromString("TestExceptions:try_catch_then_fail"),
|
||||
dar = devDar,
|
||||
)
|
||||
).map(_.toString)
|
||||
} yield {
|
||||
exception should include("Unhandled Daml exception")
|
||||
exception should include("GeneralError")
|
||||
exception should include("expected exception")
|
||||
}
|
||||
}
|
||||
}
|
||||
"Exceptions:try_catch_then_abort" should {
|
||||
"fail" in {
|
||||
for {
|
||||
clients <- participantClients()
|
||||
exception <- recoverToExceptionIf[InterpretationError](
|
||||
run(
|
||||
clients,
|
||||
QualifiedName.assertFromString("TestExceptions:try_catch_then_abort"),
|
||||
dar = devDar,
|
||||
)
|
||||
).map(_.toString)
|
||||
} yield {
|
||||
exception should include("Unhandled Daml exception")
|
||||
exception should include("GeneralError")
|
||||
exception should include("expected exception")
|
||||
}
|
||||
}
|
||||
}
|
||||
"Exceptions:try_catch_recover" should {
|
||||
"succeed" in {
|
||||
for {
|
||||
clients <- participantClients()
|
||||
v <- run(
|
||||
clients,
|
||||
QualifiedName.assertFromString("TestExceptions:try_catch_recover"),
|
||||
dar = devDar,
|
||||
)
|
||||
} yield {
|
||||
v shouldBe (SUnit)
|
||||
}
|
||||
}
|
||||
}
|
||||
"Interface:test_queryInterface" should {
|
||||
"succeed" in {
|
||||
for {
|
||||
|
Loading…
Reference in New Issue
Block a user