diff --git a/canton/community/base/src/main/scala/com/digitalasset/canton/lifecycle/FlagCloseable.scala b/canton/community/base/src/main/scala/com/digitalasset/canton/lifecycle/FlagCloseable.scala index 472e2cf4c2..0f0c1a4603 100644 --- a/canton/community/base/src/main/scala/com/digitalasset/canton/lifecycle/FlagCloseable.scala +++ b/canton/community/base/src/main/scala/com/digitalasset/canton/lifecycle/FlagCloseable.scala @@ -6,17 +6,18 @@ package com.digitalasset.canton.lifecycle import cats.data.EitherT import cats.syntax.traverse.* import com.digitalasset.canton.DiscardOps -import com.digitalasset.canton.concurrent.Threading +import com.digitalasset.canton.concurrent.{FutureSupervisor, Threading} import com.digitalasset.canton.config.ProcessingTimeout import com.digitalasset.canton.lifecycle.FlagCloseable.forceShutdownStr -import com.digitalasset.canton.logging.TracedLogger +import com.digitalasset.canton.logging.{ErrorLoggingContext, TracedLogger} import com.digitalasset.canton.tracing.TraceContext import com.digitalasset.canton.util.Thereafter.syntax.* import com.digitalasset.canton.util.{Checked, CheckedT, Thereafter} +import org.slf4j.event.Level import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.collection.immutable.MultiSet -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.duration.{Duration, DurationInt, FiniteDuration} import scala.concurrent.{ExecutionContext, Future} import scala.util.Try import scala.util.control.NonFatal @@ -335,6 +336,37 @@ object CloseContext { } /** Mix-in to obtain a [[CloseContext]] implicit based on the class's [[FlagCloseable]] */ -trait HasCloseContext { self: FlagCloseable => +trait HasCloseContext extends PromiseUnlessShutdownFactory { self: FlagCloseable => implicit val closeContext: CloseContext = CloseContext(self) } + +trait PromiseUnlessShutdownFactory { self: HasCloseContext => + protected def logger: TracedLogger + + /** Use this method to create a PromiseUnlessShutdown that will automatically be cancelled when the close context + * is closed. This allows proper clean up of stray promises when the node is transitioning to a passive state. 
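+    *
+    * A minimal usage sketch (illustrative only, not part of this change; `futureSupervisor` stands for whatever
+    * [[com.digitalasset.canton.concurrent.FutureSupervisor]] is in scope, the description is arbitrary, and an
+    * implicit [[com.digitalasset.canton.logging.ErrorLoggingContext]] and `ExecutionContext` are assumed):
+    * {{{
+    * val resultP = mkPromise[Int]("my-computation", futureSupervisor)
+    * resultP.outcome(42) // completes the promise normally
+    * // if the enclosing FlagCloseable is closed first, the promise is instead completed with AbortedDueToShutdown
+    * }}}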
+ */ + def mkPromise[A]( + description: String, + futureSupervisor: FutureSupervisor, + logAfter: Duration = 10.seconds, + logLevel: Level = Level.DEBUG, + )(implicit elc: ErrorLoggingContext, ec: ExecutionContext): PromiseUnlessShutdown[A] = { + val promise = new PromiseUnlessShutdown[A](description, futureSupervisor, logAfter, logLevel) + + val cancelToken = closeContext.flagCloseable.runOnShutdown(new RunOnShutdown { + override def name: String = s"$description-abort-promise-on-shutdown" + override def done: Boolean = promise.isCompleted + override def run(): Unit = promise.shutdown() + })(elc.traceContext) + + promise.future + .onComplete { _ => + Try(closeContext.flagCloseable.cancelShutdownTask(cancelToken)).failed.foreach(e => + logger.debug(s"Failed to cancel shutdown task for $description", e)(elc.traceContext) + ) + } + + promise + } +} diff --git a/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityChecker.scala b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityChecker.scala new file mode 100644 index 0000000000..ce252788ba --- /dev/null +++ b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityChecker.scala @@ -0,0 +1,126 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.digitalasset.canton.sequencing + +import akka.NotUsed +import akka.stream.scaladsl.Flow +import cats.syntax.functorFilter.* +import com.digitalasset.canton.SequencerCounter +import com.digitalasset.canton.data.CantonTimestamp +import com.digitalasset.canton.logging.{NamedLoggerFactory, NamedLogging} +import com.digitalasset.canton.sequencing.protocol.ClosedEnvelope +import com.digitalasset.canton.tracing.TraceContext +import com.digitalasset.canton.util.AkkaUtil.WithKillSwitch +import com.digitalasset.canton.util.ErrorUtil +import com.google.common.annotations.VisibleForTesting + +/** Checks that the sequenced events' sequencer counters are a gap-free increasing sequencing starting at `firstSequencerCounter` + * and their timestamps increase strictly monotonically. When a violation is detected, an error is logged and + * the processing is aborted. + * + * This is normally ensured by the [[com.digitalasset.canton.sequencing.client.SequencedEventValidator]] for individual sequencer subscriptions. + * However, due to aggregating multiple subscriptions from several sequencers up to a threshold, + * the stream of events emitted by the aggregation may violate monotonicity. This additional monotonicity check + * ensures that we catch such violations before we pass the events downstream. + */ +class SequencedEventMonotonicityChecker( + firstSequencerCounter: SequencerCounter, + firstTimestampLowerBoundInclusive: CantonTimestamp, + override protected val loggerFactory: NamedLoggerFactory, +) extends NamedLogging { + import SequencedEventMonotonicityChecker.* + + /** Akka version of the check. Pulls the kill switch and drains the source when a violation is detected. 
*/ + def flow: Flow[ + WithKillSwitch[OrdinarySerializedEvent], + WithKillSwitch[OrdinarySerializedEvent], + NotUsed, + ] = { + Flow[WithKillSwitch[OrdinarySerializedEvent]] + .statefulMap(() => initialState)( + (state, eventAndKillSwitch) => eventAndKillSwitch.traverse(onNext(state, _)), + _ => None, + ) + .mapConcat { actionAndKillSwitch => + actionAndKillSwitch.traverse { + case Emit(event) => Some(event) + case failure: MonotonicityFailure => + implicit val traceContext: TraceContext = failure.event.traceContext + logger.error(failure.message) + actionAndKillSwitch.killSwitch.shutdown() + None + case Drop => None + } + } + } + + /** [[com.digitalasset.canton.sequencing.ApplicationHandler]] version. + * @throws com.digitalasset.canton.sequencing.SequencedEventMonotonicityChecker.MonotonicityFailureException + * when a monotonicity violation is detected + */ + def handler( + handler: OrdinaryApplicationHandler[ClosedEnvelope] + ): OrdinaryApplicationHandler[ClosedEnvelope] = { + // Application handlers must be called sequentially, so a plain var is good enough here + @SuppressWarnings(Array("org.wartremover.warts.Var")) + var state: State = initialState + handler.replace { tracedEvents => + val filtered = tracedEvents.map(_.mapFilter { event => + val (nextState, action) = onNext(state, event) + state = nextState + action match { + case Emit(_) => Some(event) + case failure: MonotonicityFailure => + implicit val traceContext: TraceContext = event.traceContext + ErrorUtil.internalError(failure.asException) + case Drop => None + } + }) + handler.apply(filtered) + } + } + + private def initialState: State = + GoodState(firstSequencerCounter, firstTimestampLowerBoundInclusive) + + private def onNext(state: State, event: OrdinarySerializedEvent): (State, Action) = state match { + case Failed => (state, Drop) + case GoodState(nextSequencerCounter, lowerBoundTimestamp) => + val monotonic = + event.counter == nextSequencerCounter && event.timestamp >= lowerBoundTimestamp + if (monotonic) { + val nextState = GoodState(event.counter + 1, event.timestamp.immediateSuccessor) + nextState -> Emit(event) + } else { + val error = MonotonicityFailure(nextSequencerCounter, lowerBoundTimestamp, event) + Failed -> error + } + } +} + +object SequencedEventMonotonicityChecker { + + private sealed trait Action extends Product with Serializable + private final case class Emit(event: OrdinarySerializedEvent) extends Action + private case object Drop extends Action + private final case class MonotonicityFailure( + expectedSequencerCounter: SequencerCounter, + timestampLowerBound: CantonTimestamp, + event: OrdinarySerializedEvent, + ) extends Action { + def message: String = + s"Sequencer counters and timestamps do not increase monotonically. 
Expected next counter=$expectedSequencerCounter with timestamp lower bound $timestampLowerBound, but received ${event.signedEvent.content}" + + def asException: Exception = new MonotonicityFailureException(message) + } + @VisibleForTesting + class MonotonicityFailureException(message: String) extends Exception(message) + + private sealed trait State extends Product with Serializable + private case object Failed extends State + private final case class GoodState( + nextSequencerCounter: SequencerCounter, + lowerBoundTimestamp: CantonTimestamp, + ) extends State +} diff --git a/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencerConnections.scala b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencerConnections.scala index 091e3bf82b..950bcb85a2 100644 --- a/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencerConnections.scala +++ b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/SequencerConnections.scala @@ -109,7 +109,8 @@ final case class SequencerConnections private ( ): SequencerConnections = modify(sequencerAlias, _.withCertificates(certificates)) - override def pretty: Pretty[SequencerConnections] = prettyOfParam(_.aliasToConnection.forgetNE) + override def pretty: Pretty[SequencerConnections] = + prettyOfParam(_.aliasToConnection.forgetNE) def toProtoV0: Seq[v0.SequencerConnection] = connections.map(_.toProtoV0) diff --git a/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/client/SequencerClient.scala b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/client/SequencerClient.scala index 989ef3c46f..34146b37b2 100644 --- a/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/client/SequencerClient.scala +++ b/canton/community/base/src/main/scala/com/digitalasset/canton/sequencing/client/SequencerClient.scala @@ -652,7 +652,7 @@ class SequencerClientImpl( fetchCleanTimestamp: PeriodicAcknowledgements.FetchCleanTimestamp, requiresAuthentication: Boolean, )(implicit traceContext: TraceContext): Future[Unit] = { - val eventHandler = ThrottlingApplicationEventHandler.throttle( + val throttledEventHandler = ThrottlingApplicationEventHandler.throttle( config.maximumInFlightEventBatches, nonThrottledEventHandler, metrics, @@ -709,26 +709,36 @@ class SequencerClientImpl( _ = replayEvents.lastOption .orElse(initialPriorEventO) .foreach(event => timeTracker.subscriptionResumesAfter(event.timestamp)) - _ <- eventHandler.subscriptionStartsAt(subscriptionStartsAt, timeTracker) + _ <- throttledEventHandler.subscriptionStartsAt(subscriptionStartsAt, timeTracker) eventBatches = replayEvents.grouped(config.eventInboxSize.unwrap) _ <- FutureUnlessShutdown.outcomeF( MonadUtil - .sequentialTraverse_(eventBatches)(processEventBatch(eventHandler, _)) + .sequentialTraverse_(eventBatches)(processEventBatch(throttledEventHandler, _)) .valueOr(err => throw SequencerClientSubscriptionException(err)) ) } yield { - sequencerTransports.sequencerIdToTransportMap.keySet - .foreach { sequencerId => - createSubscription( - sequencerId, - replayEvents, - initialPriorEventO, - requiresAuthentication, - timeTracker, - eventHandler, - ).discard - } + val preSubscriptionEvent = replayEvents.lastOption.orElse(initialPriorEventO) + // previously seen counter takes precedence over the lower bound + val firstCounter = preSubscriptionEvent.fold(initialCounterLowerBound)(_.counter + 1) + val monotonicityChecker = new SequencedEventMonotonicityChecker( + 
firstCounter, + preSubscriptionEvent.fold(CantonTimestamp.MinValue)(_.timestamp), + loggerFactory, + ) + val eventHandler = monotonicityChecker.handler( + StoreSequencedEvent(sequencedEventStore, domainId, loggerFactory).apply( + timeTracker.wrapHandler(throttledEventHandler) + ) + ) + sequencerTransports.sequencerIdToTransportMap.keySet.foreach { sequencerId => + createSubscription( + sequencerId, + preSubscriptionEvent, + requiresAuthentication, + eventHandler, + ).discard + } // periodically acknowledge that we've successfully processed up to the clean counter // We only need to it setup once; the sequencer client will direct the acknowledgements to the @@ -766,23 +776,15 @@ class SequencerClientImpl( private def createSubscription( sequencerId: SequencerId, - replayEvents: Seq[PossiblyIgnoredSerializedEvent], - initialPriorEventO: Option[PossiblyIgnoredSerializedEvent], + preSubscriptionEvent: Option[PossiblyIgnoredSerializedEvent], requiresAuthentication: Boolean, - timeTracker: DomainTimeTracker, - eventHandler: PossiblyIgnoredApplicationHandler[ClosedEnvelope], - )(implicit traceContext: TraceContext) = { - val lastEvent = replayEvents.lastOption - val preSubscriptionEvent = lastEvent.orElse(initialPriorEventO) - - val nextCounter = - // previously seen counter takes precedence over the lower bound - preSubscriptionEvent.fold(initialCounterLowerBound)(_.counter) - - val eventValidator = eventValidatorFactory.create( - unauthenticated = !requiresAuthentication - ) - + eventHandler: OrdinaryApplicationHandler[ClosedEnvelope], + )(implicit + traceContext: TraceContext + ): ResilientSequencerSubscription[SequencerClientSubscriptionError] = { + // previously seen counter takes precedence over the lower bound + val nextCounter = preSubscriptionEvent.fold(initialCounterLowerBound)(_.counter) + val eventValidator = eventValidatorFactory.create(unauthenticated = !requiresAuthentication) logger.info( s"Starting subscription for alias=$sequencerId at timestamp ${preSubscriptionEvent .map(_.timestamp)}; next counter $nextCounter" @@ -790,8 +792,7 @@ class SequencerClientImpl( val eventDelay: DelaySequencedEvent = { val first = testingConfig.testSequencerClientFor.find(elem => - elem.memberName == member.uid.id.unwrap - && + elem.memberName == member.uid.id.unwrap && elem.domainName == domainId.unwrap.id.unwrap ) @@ -807,9 +808,7 @@ class SequencerClientImpl( } val subscriptionHandler = new SubscriptionHandler( - StoreSequencedEvent(sequencedEventStore, domainId, loggerFactory).apply( - timeTracker.wrapHandler(eventHandler) - ), + eventHandler, eventValidator, eventDelay, preSubscriptionEvent, diff --git a/canton/community/base/src/main/scala/com/digitalasset/canton/util/SimpleExecutionQueue.scala b/canton/community/base/src/main/scala/com/digitalasset/canton/util/SimpleExecutionQueue.scala index 7f250b0282..26cf750c71 100644 --- a/canton/community/base/src/main/scala/com/digitalasset/canton/util/SimpleExecutionQueue.scala +++ b/canton/community/base/src/main/scala/com/digitalasset/canton/util/SimpleExecutionQueue.scala @@ -19,7 +19,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.annotation.tailrec import scala.concurrent.Future import scala.concurrent.duration.Duration -import scala.util.{Failure, Success} +import scala.util.{Failure, Success, Try} /** Functions executed with this class will only run when all previous calls have completed executing. * This can be used when async code should not be run concurrently. 
@@ -66,10 +66,14 @@ class SimpleExecutionQueue( )(implicit loggingContext: ErrorLoggingContext): EitherT[FutureUnlessShutdown, A, B] = EitherT(executeUS(execution.value, description)) - def executeUS[A](execution: => FutureUnlessShutdown[A], description: String)(implicit + def executeUS[A]( + execution: => FutureUnlessShutdown[A], + description: String, + runWhenUnderFailures: => Unit = (), + )(implicit loggingContext: ErrorLoggingContext ): FutureUnlessShutdown[A] = - genExecute(runIfFailed = false, execution, description) + genExecute(runIfFailed = false, execution, description, runWhenUnderFailures) def executeUnderFailures[A](execution: => Future[A], description: String)(implicit loggingContext: ErrorLoggingContext @@ -89,6 +93,7 @@ class SimpleExecutionQueue( runIfFailed: Boolean, execution: => FutureUnlessShutdown[A], description: String, + runWhenUnderFailures: => Unit = (), )(implicit loggingContext: ErrorLoggingContext): FutureUnlessShutdown[A] = { val next = new TaskCell(description, logTaskTiming, futureSupervisor, directExecutionContext) val oldHead = queueHead.getAndSet(next) // linearization point @@ -100,6 +105,7 @@ class SimpleExecutionQueue( directExecutionContext, loggingContext.traceContext, ), + runWhenUnderFailures, ) } @@ -228,6 +234,7 @@ object SimpleExecutionQueue { pred: TaskCell, runIfFailed: Boolean, execution: => FutureUnlessShutdown[A], + runWhenUnderFailures: => Unit, )(implicit loggingContext: ErrorLoggingContext ): FutureUnlessShutdown[A] = { @@ -279,6 +286,13 @@ object SimpleExecutionQueue { s"Not running task ${description.singleQuoted} due to exception after waiting for $waitingDelay" )(loggingContext.traceContext) } + Try(runWhenUnderFailures).failed + .foreach(e => + loggingContext.logger.debug( + s"Failed to run 'runWhenUnderFailures' function for ${description.singleQuoted}", + e, + )(loggingContext.traceContext) + ) FutureUnlessShutdown.failed(ex) } }(directExecutionContext) diff --git a/canton/community/common/src/main/scala/com/digitalasset/canton/data/TaskScheduler.scala b/canton/community/common/src/main/scala/com/digitalasset/canton/data/TaskScheduler.scala index acc079b8f7..c0865159c2 100644 --- a/canton/community/common/src/main/scala/com/digitalasset/canton/data/TaskScheduler.scala +++ b/canton/community/common/src/main/scala/com/digitalasset/canton/data/TaskScheduler.scala @@ -6,7 +6,7 @@ package com.digitalasset.canton.data import com.daml.metrics.api.MetricHandle.Counter import com.daml.metrics.api.MetricHandle.Gauge.CloseableGauge import com.daml.nameof.NameOf.functionFullName -import com.digitalasset.canton.concurrent.FutureSupervisor +import com.digitalasset.canton.concurrent.{DirectExecutionContext, FutureSupervisor} import com.digitalasset.canton.config.ProcessingTimeout import com.digitalasset.canton.data.PeanoQueue.{BeforeHead, InsertedValue, NotInserted} import com.digitalasset.canton.lifecycle.{FlagCloseable, FutureUnlessShutdown, Lifecycle} @@ -22,6 +22,7 @@ import java.util.concurrent.atomic.AtomicReference import scala.annotation.tailrec import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future, Promise, blocking} +import scala.util.control.NonFatal /** The task scheduler manages tasks with associated timestamps and sequencer counters. 
* Tasks may be inserted in any order; they will be executed nevertheless in the correct order @@ -297,10 +298,18 @@ class TaskScheduler[Task <: TaskScheduler.TimedTask]( case Some(tracedTask) => tracedTask.withTraceContext { implicit traceContext => task => FutureUtil.doNotAwait( - // Close the task if the queue is shutdown + // Close the task if the queue is shutdown or if it has failed queue .executeUS(task.perform(), task.toString) - .onShutdown(task.close()), + .onShutdown(task.close()) + .recoverWith { + // If any task fails, none of subsequent tasks will be executed so we might as well close the scheduler + // to force completion of the tasks and signal that the scheduler is not functional + case NonFatal(e) if !this.isClosing => + this.close() + Future.failed(e) + // Use a direct context here to avoid closing the scheduler in a different thread + }(DirectExecutionContext(errorLoggingContext(traceContext).noTracingLogger)), show"A task failed with an exception.\n$task", ) taskQueue.dequeue() diff --git a/canton/community/common/src/main/scala/com/digitalasset/canton/topology/QueueBasedDomainOutboxX.scala b/canton/community/common/src/main/scala/com/digitalasset/canton/topology/QueueBasedDomainOutboxX.scala index c6cbf7d859..a9cb200a20 100644 --- a/canton/community/common/src/main/scala/com/digitalasset/canton/topology/QueueBasedDomainOutboxX.scala +++ b/canton/community/common/src/main/scala/com/digitalasset/canton/topology/QueueBasedDomainOutboxX.scala @@ -211,7 +211,7 @@ class QueueBasedDomainOutboxX( def markDone(delayRetry: Boolean = false): Unit = { val updated = queueState.getAndUpdate(_.done()) // if anything has been pushed in the meantime, we need to kick off a new flush - if (updated.hasPending) { + if (updated.hasPending && !isClosing) { if (delayRetry) { // kick off new flush in the background DelayUtil.delay(functionFullName, 10.seconds, this).map(_ => kickOffFlush()).discard @@ -272,17 +272,15 @@ class QueueBasedDomainOutboxX( queuedApprox = math.max(c.queuedApprox - pending.size, 0) ) }.discard + + domainOutboxQueue.completeCycle() + markDone() } - ret.transform { - case x @ Left(_) => - domainOutboxQueue.requeue() - markDone(delayRetry = true) - x - case x @ Right(_) => - domainOutboxQueue.completeCycle() - markDone() - x - } + + EitherTUtil.onErrorOrFailureUnlessShutdown { () => + domainOutboxQueue.requeue() + markDone(delayRetry = true) + }(ret) } else { markDone() EitherT.rightT(()) diff --git a/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityCheckerTest.scala b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityCheckerTest.scala new file mode 100644 index 0000000000..5b1f426bf6 --- /dev/null +++ b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/SequencedEventMonotonicityCheckerTest.scala @@ -0,0 +1,203 @@ +// Copyright (c) 2023 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +package com.digitalasset.canton.sequencing + +import akka.stream.scaladsl.{Keep, Sink, Source} +import com.digitalasset.canton.data.CantonTimestamp +import com.digitalasset.canton.lifecycle.FutureUnlessShutdown +import com.digitalasset.canton.sequencing.SequencedEventMonotonicityChecker.MonotonicityFailureException +import com.digitalasset.canton.sequencing.client.SequencedEventTestFixture +import com.digitalasset.canton.sequencing.protocol.ClosedEnvelope +import com.digitalasset.canton.time.DomainTimeTracker +import com.digitalasset.canton.tracing.{TraceContext, Traced} +import com.digitalasset.canton.util.AkkaUtil.syntax.* +import com.digitalasset.canton.util.{ErrorUtil, ResourceUtil} +import com.digitalasset.canton.{ + BaseTest, + HasExecutionContext, + ProtocolVersionChecksFixtureAnyWordSpec, + SequencerCounter, +} +import org.scalatest.Outcome +import org.scalatest.wordspec.FixtureAnyWordSpec + +import java.util.concurrent.atomic.AtomicReference + +class SequencedEventMonotonicityCheckerTest + extends FixtureAnyWordSpec + with BaseTest + with HasExecutionContext + with ProtocolVersionChecksFixtureAnyWordSpec { + import SequencedEventMonotonicityCheckerTest.* + + override protected type FixtureParam = SequencedEventTestFixture + + override protected def withFixture(test: OneArgTest): Outcome = + ResourceUtil.withResource( + new SequencedEventTestFixture( + loggerFactory, + testedProtocolVersion, + timeouts, + futureSupervisor, + ) + ) { env => withFixture(test.toNoArgTest(env)) } + + private def mkHandler(): CapturingApplicationHandler = new CapturingApplicationHandler + + "handler" should { + "pass through monotonically increasing events" in { env => + import env.* + + val checker = new SequencedEventMonotonicityChecker( + bobEvents(0).counter, + bobEvents(0).timestamp, + loggerFactory, + ) + val handler = mkHandler() + val checkedHandler = checker.handler(handler) + val (batch1, batch2) = bobEvents.splitAt(2) + + checkedHandler(Traced(batch1)).futureValueUS.unwrap.futureValueUS + checkedHandler(Traced(batch2)).futureValueUS.unwrap.futureValueUS + handler.invocations.get.flatMap(_.value) shouldBe bobEvents + } + + "detect gaps in sequencer counters" in { env => + import env.* + + val checker = new SequencedEventMonotonicityChecker( + bobEvents(0).counter, + bobEvents(0).timestamp, + loggerFactory, + ) + val handler = mkHandler() + val checkedHandler = checker.handler(handler) + val (batch1, batch2) = bobEvents.splitAt(2) + + checkedHandler(Traced(batch1)).futureValueUS.unwrap.futureValueUS + loggerFactory.assertThrowsAndLogs[MonotonicityFailureException]( + checkedHandler(Traced(batch2.drop(1))).futureValueUS.unwrap.futureValueUS, + _.errorMessage should include(ErrorUtil.internalErrorMessage), + ) + } + + "detect non-monotonic timestamps" in { env => + import env.* + + val event1 = createEvent( + timestamp = CantonTimestamp.ofEpochSecond(2), + counter = 2L, + ).futureValue + val event2 = createEvent( + timestamp = CantonTimestamp.ofEpochSecond(2), + counter = 3L, + ).futureValue + + val checker = new SequencedEventMonotonicityChecker( + SequencerCounter(2L), + CantonTimestamp.MinValue, + loggerFactory, + ) + val handler = mkHandler() + val checkedHandler = checker.handler(handler) + + checkedHandler(Traced(Seq(event1))).futureValueUS.unwrap.futureValueUS + loggerFactory.assertThrowsAndLogs[MonotonicityFailureException]( + checkedHandler(Traced(Seq(event2))).futureValueUS.unwrap.futureValueUS, + _.errorMessage should 
include(ErrorUtil.internalErrorMessage), + ) + } + } + + "flow" should { + "pass through monotonically increasing events" in { env => + import env.* + + val checker = new SequencedEventMonotonicityChecker( + bobEvents(0).counter, + bobEvents(0).timestamp, + loggerFactory, + ) + val eventsF = Source(bobEvents) + .withUniqueKillSwitchMat()(Keep.left) + .via(checker.flow) + .toMat(Sink.seq)(Keep.right) + .run() + eventsF.futureValue.map(_.unwrap) shouldBe bobEvents + } + + "kill the stream upon a gap in the counters" in { env => + import env.* + + val checker = new SequencedEventMonotonicityChecker( + bobEvents(0).counter, + bobEvents(0).timestamp, + loggerFactory, + ) + val (batch1, batch2) = bobEvents.splitAt(2) + val eventsF = loggerFactory.assertLogs( + Source(batch1 ++ batch2.drop(1)) + .withUniqueKillSwitchMat()(Keep.left) + .via(checker.flow) + .toMat(Sink.seq)(Keep.right) + .run(), + _.errorMessage should include( + "Sequencer counters and timestamps do not increase monotonically" + ), + ) + eventsF.futureValue.map(_.unwrap) shouldBe batch1 + } + + "detect non-monotonic timestamps" in { env => + import env.* + + val event1 = createEvent( + timestamp = CantonTimestamp.ofEpochSecond(2), + counter = 2L, + ).futureValue + val event2 = createEvent( + timestamp = CantonTimestamp.ofEpochSecond(2), + counter = 3L, + ).futureValue + + val checker = new SequencedEventMonotonicityChecker( + SequencerCounter(2L), + CantonTimestamp.MinValue, + loggerFactory, + ) + val eventsF = loggerFactory.assertLogs( + Source(Seq(event1, event2)) + .withUniqueKillSwitchMat()(Keep.left) + .via(checker.flow) + .toMat(Sink.seq)(Keep.right) + .run(), + _.errorMessage should include( + "Sequencer counters and timestamps do not increase monotonically" + ), + ) + eventsF.futureValue.map(_.unwrap) shouldBe Seq(event1) + } + } +} + +object SequencedEventMonotonicityCheckerTest { + class CapturingApplicationHandler() + extends ApplicationHandler[OrdinaryEnvelopeBox, ClosedEnvelope] { + val invocations = + new AtomicReference[Seq[BoxedEnvelope[OrdinaryEnvelopeBox, ClosedEnvelope]]](Seq.empty) + + override def name: String = "capturing-application-handler" + override def subscriptionStartsAt( + start: SubscriptionStart, + domainTimeTracker: DomainTimeTracker, + )(implicit traceContext: TraceContext): FutureUnlessShutdown[Unit] = FutureUnlessShutdown.unit + + override def apply(boxed: BoxedEnvelope[OrdinaryEnvelopeBox, ClosedEnvelope]): HandlerResult = { + invocations + .getAndUpdate(_ :+ boxed) + .discard[Seq[BoxedEnvelope[OrdinaryEnvelopeBox, ClosedEnvelope]]] + HandlerResult.done + } + } +} diff --git a/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencedEventTestFixture.scala b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencedEventTestFixture.scala index 8e5468a5a4..4a94ded87b 100644 --- a/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencedEventTestFixture.scala +++ b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencedEventTestFixture.scala @@ -77,21 +77,24 @@ class SequencedEventTestFixture( ByteString.copyFromUtf8("signatureCarlos1"), carlos.uid.namespace.fingerprint, ) - lazy val aliceEvents = (1 to 5).map(s => + lazy val aliceEvents: Seq[OrdinarySerializedEvent] = (1 to 5).map(s => createEvent( timestamp = CantonTimestamp.Epoch.plusSeconds(s.toLong), + counter = updatedCounter + s.toLong, signatureOverride = Some(signatureAlice), ).futureValue ) - lazy val 
bobEvents = (1 to 5).map(s => + lazy val bobEvents: Seq[OrdinarySerializedEvent] = (1 to 5).map(s => createEvent( timestamp = CantonTimestamp.Epoch.plusSeconds(s.toLong), + counter = updatedCounter + s.toLong, signatureOverride = Some(signatureBob), ).futureValue ) - lazy val carlosEvents = (1 to 5).map(s => + lazy val carlosEvents: Seq[OrdinarySerializedEvent] = (1 to 5).map(s => createEvent( timestamp = CantonTimestamp.Epoch.plusSeconds(s.toLong), + counter = updatedCounter + s.toLong, signatureOverride = Some(signatureCarlos), ).futureValue ) diff --git a/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencerClientTest.scala b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencerClientTest.scala index 0cf4e78bf5..b9e3b40909 100644 --- a/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencerClientTest.scala +++ b/canton/community/common/src/test/scala/com/digitalasset/canton/sequencing/client/SequencerClientTest.scala @@ -82,9 +82,10 @@ class SequencerClientTest MetricName("SequencerClientTest"), NoOpMetricsFactory, ) + private lazy val firstSequencerCounter = SequencerCounter(42L) private lazy val deliver: Deliver[Nothing] = SequencerTestUtils.mockDeliver( - 42, + firstSequencerCounter.unwrap, CantonTimestamp.Epoch, DefaultTestIdentities.domainId, ) @@ -136,13 +137,13 @@ class SequencerClientTest } yield error).futureValue shouldBe a[RuntimeException] } - "start from genesis if there is no recorded event" in { + "start from the specified sequencer counter if there is no recorded event" in { val counterF = for { - env <- Env.create() + env <- Env.create(initialSequencerCounter = SequencerCounter(5)) _ <- env.subscribeAfter() } yield env.transport.subscriber.value.request.counter - counterF.futureValue shouldBe SequencerCounter.Genesis + counterF.futureValue shouldBe SequencerCounter(5) } "starts subscription at last stored event (for fork verification)" in { @@ -490,6 +491,7 @@ class SequencerClientTest maximumInFlightEventBatches = PositiveInt.tryCreate(5), ), useParallelExecutionContext = true, + initialSequencerCounter = SequencerCounter(1L), ) _ <- env.subscribeAfter( CantonTimestamp.Epoch, @@ -670,7 +672,10 @@ class SequencerClientTest "create second subscription from the same counter as the previous one when there are no events" in { val secondTransport = new MockTransport val testF = for { - env <- Env.create(useParallelExecutionContext = true) + env <- Env.create( + useParallelExecutionContext = true, + initialSequencerCounter = SequencerCounter.Genesis, + ) _ <- env.subscribeAfter() _ <- env.changeTransport(secondTransport) } yield { @@ -697,16 +702,17 @@ class SequencerClientTest env <- Env.create(useParallelExecutionContext = true) _ <- env.subscribeAfter() - _ <- env.transport.subscriber.value.handler(signedDeliver) + _ <- env.transport.subscriber.value.sendToHandler(deliver) + _ <- env.transport.subscriber.value.sendToHandler(nextDeliver) _ <- env.client.flushClean() _ <- env.changeTransport(secondTransport) } yield { val originalSubscriber = env.transport.subscriber.value - originalSubscriber.request.counter shouldBe SequencerCounter.Genesis + originalSubscriber.request.counter shouldBe firstSequencerCounter val newSubscriber = secondTransport.subscriber.value - newSubscriber.request.counter shouldBe deliver.counter + newSubscriber.request.counter shouldBe nextDeliver.counter env.client.completion.isCompleted shouldBe false } @@ -993,6 +999,7 @@ class 
SequencerClientTest eventValidator: SequencedEventValidator = eventAlwaysValid, options: SequencerClientConfig = SequencerClientConfig(), useParallelExecutionContext: Boolean = false, + initialSequencerCounter: SequencerCounter = firstSequencerCounter, )(implicit closeContext: CloseContext): Future[Env] = { // if parallel execution is desired use the UseExecutorService executor service (which is a parallel execution context) // otherwise use the default serial execution context provided by ScalaTest @@ -1082,7 +1089,7 @@ class SequencerClientTest LoggingConfig(), loggerFactory, futureSupervisor, - SequencerCounter.Genesis, + initialSequencerCounter, )(executionContext, tracer) val signedEvents = storedEvents.map(SequencerTestUtils.sign) diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/AbstractMessageProcessor.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/AbstractMessageProcessor.scala index a7184489f3..24a5f9bf88 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/AbstractMessageProcessor.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/AbstractMessageProcessor.scala @@ -9,7 +9,7 @@ import cats.syntax.functor.* import com.daml.nameof.NameOf.functionFullName import com.digitalasset.canton.crypto.{DomainSnapshotSyncCryptoApi, DomainSyncCryptoClient} import com.digitalasset.canton.data.CantonTimestamp -import com.digitalasset.canton.lifecycle.{FlagCloseable, FutureUnlessShutdown} +import com.digitalasset.canton.lifecycle.{FlagCloseable, FutureUnlessShutdown, HasCloseContext} import com.digitalasset.canton.logging.NamedLogging import com.digitalasset.canton.participant.protocol.RequestJournal.RequestState import com.digitalasset.canton.participant.protocol.conflictdetection.ActivenessSet @@ -42,7 +42,8 @@ abstract class AbstractMessageProcessor( protocolVersion: ProtocolVersion, )(implicit ec: ExecutionContext) extends NamedLogging - with FlagCloseable { + with FlagCloseable + with HasCloseContext { protected def terminateRequest( requestCounter: RequestCounter, @@ -190,10 +191,10 @@ abstract class AbstractMessageProcessor( if (!isCleanReplay(requestCounter)) { val timeoutF = requestFutures.timeoutResult.flatMap { timeoutResult => - if (timeoutResult.timedOut) onTimeout - else Future.unit + if (timeoutResult.timedOut) FutureUnlessShutdown.outcomeF(onTimeout) + else FutureUnlessShutdown.unit } - FutureUtil.doNotAwait(timeoutF, "Handling timeout failed") + FutureUtil.doNotAwaitUnlessShutdown(timeoutF, "Handling timeout failed") } } yield () diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/Phase37Synchronizer.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/Phase37Synchronizer.scala index 200c9e3446..4d582d2648 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/Phase37Synchronizer.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/Phase37Synchronizer.scala @@ -3,9 +3,15 @@ package com.digitalasset.canton.participant.protocol -import com.digitalasset.canton.RequestCounter -import com.digitalasset.canton.concurrent.{FutureSupervisor, SupervisedPromise} +import com.digitalasset.canton.concurrent.FutureSupervisor +import com.digitalasset.canton.config.ProcessingTimeout import 
com.digitalasset.canton.data.{CantonTimestamp, ConcurrentHMap} +import com.digitalasset.canton.lifecycle.{ + FlagCloseable, + FutureUnlessShutdown, + HasCloseContext, + PromiseUnlessShutdown, +} import com.digitalasset.canton.logging.{NamedLoggerFactory, NamedLogging} import com.digitalasset.canton.participant.protocol.Phase37Synchronizer.* import com.digitalasset.canton.participant.protocol.ProcessingSteps.{ @@ -16,9 +22,11 @@ import com.digitalasset.canton.participant.protocol.ProtocolProcessor.PendingReq import com.digitalasset.canton.protocol.RequestId import com.digitalasset.canton.tracing.TraceContext import com.digitalasset.canton.util.ErrorUtil +import com.digitalasset.canton.{DiscardOps, RequestCounter} import com.google.common.annotations.VisibleForTesting -import scala.concurrent.{ExecutionContext, Future, Promise, blocking} +import scala.concurrent.{ExecutionContext, Future, blocking} +import scala.util.{Failure, Success} /** Synchronizes the request processing of phases 3 and 7. * At the end of phase 3, every request must signal that it has reached @@ -34,7 +42,10 @@ class Phase37Synchronizer( initRc: RequestCounter, override val loggerFactory: NamedLoggerFactory, futureSupervisor: FutureSupervisor, -) extends NamedLogging { + override val timeouts: ProcessingTimeout, +) extends NamedLogging + with FlagCloseable + with HasCloseContext { /** Maps request timestamps to a promise and a future, which is used to chain each request's evaluation (i.e. filter). * The future completes with either the pending request data, if it's the first valid call, @@ -65,10 +76,10 @@ class Phase37Synchronizer( val ts = CantonTimestampWithRequestType[requestType.type](requestId.unwrap, requestType) implicit val evRequest = ts.pendingRequestRelation - val promise: Promise[Option[ + val promise: PromiseUnlessShutdown[Option[ PendingRequestDataOrReplayData[requestType.PendingRequestData] ]] = - new SupervisedPromise[Option[ + mkPromise[Option[ PendingRequestDataOrReplayData[requestType.PendingRequestData] ]]("phase37sync-register-request-data", futureSupervisor) @@ -77,7 +88,7 @@ class Phase37Synchronizer( blocking(synchronized { val requestRelation: RequestRelation[requestType.PendingRequestData] = RequestRelation( promise.future - .map(_.orElse { + .map(_.onShutdown(None).orElse { blocking(synchronized { pendingRequests.remove_(ts) }) @@ -115,7 +126,7 @@ class Phase37Synchronizer( )(implicit traceContext: TraceContext, ec: ExecutionContext, - ): Future[RequestOutcome[requestType.PendingRequestData]] = { + ): FutureUnlessShutdown[RequestOutcome[requestType.PendingRequestData]] = { val ts = CantonTimestampWithRequestType[requestType.type](requestId.unwrap, requestType) implicit val evRequest = ts.pendingRequestRelation @@ -125,46 +136,52 @@ class Phase37Synchronizer( logger.debug( s"Request ${requestId.unwrap}: Request data is waiting to be validated" ) - val promise: Promise[RequestOutcome[requestType.PendingRequestData]] = - new SupervisedPromise[RequestOutcome[requestType.PendingRequestData]]( + val promise: PromiseUnlessShutdown[RequestOutcome[requestType.PendingRequestData]] = + mkPromise[RequestOutcome[requestType.PendingRequestData]]( "phase37sync-pending-request-data", futureSupervisor, ) - val newFut = fut.flatMap { + val newFut = fut.transformWith { /* either: (1) another call to awaitConfirmed has already received and successfully validated the data (2) the request was marked as a timeout */ - case None => - promise.success(RequestOutcome.AlreadyServedOrTimeout) + case Success(None) => + 
promise.outcome(RequestOutcome.AlreadyServedOrTimeout) Future.successful(None) - case Some(pData) => - filter(pData).map { - case true => + case Success(Some(pData)) => + filter(pData).transform { + case Success(true) => // we need a synchronized block here to avoid conflicts with the outer replace in awaitConfirmed blocking(synchronized { // the entry is removed when the first awaitConfirmed with a satisfied predicate is there pendingRequests.remove_(ts) }) - promise.success(RequestOutcome.Success(pData)) - None - case false => - promise.success(RequestOutcome.Invalid) - Some(pData) + promise.outcome(RequestOutcome.Success(pData)) + Success(None) + case Success(false) => + promise.outcome(RequestOutcome.Invalid) + Success(Some(pData)) + case Failure(exception) => + promise.tryFailure(exception).discard[Boolean] + Failure(exception) } + case Failure(exception) => + promise.tryFailure(exception).discard[Boolean] + Future.failed(exception) } pendingRequests.replace_[ts.type, RequestRelation[requestType.PendingRequestData]]( ts, rr.copy(pendingRequestDataFuture = newFut), ) - promise.future + promise.futureUS case None => logger.debug( s"Request ${requestId.unwrap}: Request data was already returned to another caller" + s" or has timed out" ) - Future.successful(RequestOutcome.AlreadyServedOrTimeout) + FutureUnlessShutdown.pure(RequestOutcome.AlreadyServedOrTimeout) } }) } @@ -212,10 +229,16 @@ object Phase37Synchronizer { ) final class PendingRequestDataHandle[T <: PendingRequestData]( - private val handle: Promise[Option[PendingRequestDataOrReplayData[T]]] + private val handle: PromiseUnlessShutdown[Option[PendingRequestDataOrReplayData[T]]] ) { def complete(pendingData: Option[PendingRequestDataOrReplayData[T]]): Unit = { - handle.success(pendingData) + handle.outcome(pendingData) + } + def failed(exception: Throwable): Unit = { + handle.failure(exception) + } + def shutdown(): Unit = { + handle.shutdown() } } diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/ProtocolProcessor.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/ProtocolProcessor.scala index 5de4fc2d4f..53f6fbddb7 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/ProtocolProcessor.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/ProtocolProcessor.scala @@ -18,11 +18,7 @@ import com.digitalasset.canton.crypto.{ } import com.digitalasset.canton.data.{CantonTimestamp, ViewPosition, ViewTree, ViewType} import com.digitalasset.canton.ledger.api.DeduplicationPeriod -import com.digitalasset.canton.lifecycle.{ - FutureUnlessShutdown, - PromiseUnlessShutdown, - UnlessShutdown, -} +import com.digitalasset.canton.lifecycle.{FutureUnlessShutdown, UnlessShutdown} import com.digitalasset.canton.logging.NamedLoggerFactory import com.digitalasset.canton.logging.pretty.{Pretty, PrettyPrinting} import com.digitalasset.canton.participant.protocol.Phase37Synchronizer.RequestOutcome @@ -63,6 +59,7 @@ import com.digitalasset.canton.tracing.TraceContext import com.digitalasset.canton.util.EitherTUtil.{condUnitET, ifThenET} import com.digitalasset.canton.util.EitherUtil.RichEither import com.digitalasset.canton.util.FutureInstances.* +import com.digitalasset.canton.util.Thereafter.syntax.ThereafterOps import com.digitalasset.canton.util.* import com.digitalasset.canton.version.ProtocolVersion import com.digitalasset.canton.{DiscardOps, LfPartyId, 
RequestCounter, SequencerCounter, checked} @@ -450,7 +447,7 @@ abstract class ProtocolProcessor[ ) // use the send callback and a promise to capture the eventual sequenced event read by the submitter - sendResultP = new PromiseUnlessShutdown[SendResult]( + sendResultP = mkPromise[SendResult]( "sequenced-event-send-result", futureSupervisor, ) @@ -616,7 +613,14 @@ abstract class ProtocolProcessor[ .registerRequest(steps.requestType)(RequestId(ts)) ) .map { handleRequestData => + // If the result is not a success, we still need to complete the request data in some way performRequestProcessing(ts, rc, sc, handleRequestData, batch, freshOwnTimelyTxF) + .thereafter { + case Failure(exception) => handleRequestData.failed(exception) + case Success(UnlessShutdown.Outcome(Left(_))) => handleRequestData.complete(None) + case Success(UnlessShutdown.AbortedDueToShutdown) => handleRequestData.shutdown() + case _ => + } } } toHandlerRequest(ts, processedET) @@ -1045,7 +1049,7 @@ abstract class ProtocolProcessor[ timeoutEvent(), ) ) - _ = EitherTUtil.doNotAwait(timeoutET, "Handling timeout failed") + _ = EitherTUtil.doNotAwaitUS(timeoutET, "Handling timeout failed") signedResponsesTo <- EitherT.right(responsesTo.parTraverse { case (response, recipients) => FutureUnlessShutdown.outcomeF( @@ -1421,7 +1425,7 @@ abstract class ProtocolProcessor[ ephemeral.requestTracker.tick(sc, resultTs) Left(steps.embedResultError(InvalidPendingRequest(requestId))) } - ).mapK(FutureUnlessShutdown.outcomeK).flatMap { pendingRequestDataOrReplayData => + ).flatMap { pendingRequestDataOrReplayData => performResultProcessing3( signedResultBatchE, unsignedResultE, @@ -1709,7 +1713,7 @@ abstract class ProtocolProcessor[ result: TimeoutResult )(implicit traceContext: TraceContext - ): EitherT[Future, steps.ResultError, Unit] = + ): EitherT[FutureUnlessShutdown, steps.ResultError, Unit] = if (result.timedOut) { logger.info( show"${steps.requestKind.unquoted} request at $requestId timed out without a transaction result message." 
@@ -1750,13 +1754,15 @@ abstract class ProtocolProcessor[ // No need to clean up the pending submissions because this is handled (concurrently) by schedulePendingSubmissionRemoval cleanReplay = isCleanReplay(requestCounter, pendingRequestDataOrReplayData) - _ <- EitherT.right[steps.ResultError]( - ephemeral.storedContractManager.deleteIfPending(requestCounter, pendingContracts) - ) + _ <- EitherT + .right[steps.ResultError]( + ephemeral.storedContractManager.deleteIfPending(requestCounter, pendingContracts) + ) + .mapK(FutureUnlessShutdown.outcomeK) - _ <- ifThenET(!cleanReplay)(publishEvent()) + _ <- ifThenET(!cleanReplay)(publishEvent()).mapK(FutureUnlessShutdown.outcomeK) } yield () - } else EitherT.pure[Future, steps.ResultError](()) + } else EitherT.pure[FutureUnlessShutdown, steps.ResultError](()) private[this] def isCleanReplay( requestCounter: RequestCounter, diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/NaiveRequestTracker.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/NaiveRequestTracker.scala index 8886e7e356..3a64b34c9b 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/NaiveRequestTracker.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/NaiveRequestTracker.scala @@ -6,14 +6,18 @@ package com.digitalasset.canton.participant.protocol.conflictdetection import cats.data.{EitherT, NonEmptyChain} import cats.syntax.either.* import com.daml.nameof.NameOf.functionFullName -import com.digitalasset.canton.concurrent.{FutureSupervisor, SupervisedPromise} +import com.digitalasset.canton.concurrent.FutureSupervisor import com.digitalasset.canton.config.ProcessingTimeout import com.digitalasset.canton.data.{CantonTimestamp, TaskScheduler, TaskSchedulerMetrics} +import com.digitalasset.canton.lifecycle.UnlessShutdown.AbortedDueToShutdown import com.digitalasset.canton.lifecycle.{ AsyncOrSyncCloseable, FlagCloseableAsync, FutureUnlessShutdown, + HasCloseContext, PromiseUnlessShutdown, + PromiseUnlessShutdownFactory, + RunOnShutdown, SyncCloseable, UnlessShutdown, } @@ -24,7 +28,7 @@ import com.digitalasset.canton.participant.util.TimeOfChange import com.digitalasset.canton.protocol.LfContractId import com.digitalasset.canton.tracing.TraceContext import com.digitalasset.canton.util.{ErrorUtil, FutureUtil, SingleUseCell} -import com.digitalasset.canton.{RequestCounter, SequencerCounter} +import com.digitalasset.canton.{DiscardOps, RequestCounter, SequencerCounter} import com.google.common.annotations.VisibleForTesting import scala.annotation.nowarn @@ -53,7 +57,8 @@ private[participant] class NaiveRequestTracker( )(implicit executionContext: ExecutionContext) extends RequestTracker with NamedLogging - with FlagCloseableAsync { + with FlagCloseableAsync + with HasCloseContext { self => import NaiveRequestTracker.* import RequestTracker.* @@ -68,6 +73,16 @@ private[participant] class NaiveRequestTracker( futureSupervisor, ) + // The task scheduler can decide to close itself if a task fails to execute + // If that happens, close the tracker as well since we won't be able to make progress without a scheduler + taskScheduler.runOnShutdown_( + new RunOnShutdown { + override def name: String = "close-request-tracker-due-to-scheduler-shutdown" + override def done: Boolean = isClosing + override def run(): Unit = self.close() + 
} + )(TraceContext.empty) + /** Maps request counters to the data associated with a request. * * A request resides in the map from the call to [[RequestTracker!.addRequest]] until some time after @@ -111,7 +126,14 @@ private[participant] class NaiveRequestTracker( ), ) - val data = RequestData.mk(sc, requestTimestamp, decisionTime, activenessSet, futureSupervisor) + val data = RequestData.mk( + sc, + requestTimestamp, + decisionTime, + activenessSet, + this, + futureSupervisor, + ) requests.putIfAbsent(rc, data) match { case None => @@ -134,7 +156,7 @@ private[participant] class NaiveRequestTracker( val f = conflictDetector.registerActivenessSet(rc, activenessSet).map { _ => // Tick the task scheduler only after all states have been prefetched into the conflict detector taskScheduler.addTick(sc, requestTimestamp) - RequestFutures(data.activenessResult.futureUS, data.timeoutResult.future) + RequestFutures(data.activenessResult.futureUS, data.timeoutResult.futureUS) } Right(f) @@ -143,7 +165,7 @@ private[participant] class NaiveRequestTracker( logger.debug(withRC(rc, s"Added a second time to the request tracker")) Right( FutureUnlessShutdown.pure( - RequestFutures(oldData.activenessResult.futureUS, oldData.timeoutResult.future) + RequestFutures(oldData.activenessResult.futureUS, oldData.timeoutResult.futureUS) ) ) } else { @@ -187,7 +209,7 @@ private[participant] class NaiveRequestTracker( rc, sc, requestData.requestTimestamp, - requestData.commitSetPromise.future, + requestData.commitSetPromise.futureUS, commitTime, ) val data = FinalizationData(resultTimestamp, commitTime)(task.finalizationResult) @@ -196,7 +218,7 @@ private[participant] class NaiveRequestTracker( logger.debug( withRC(rc, s"New result at $resultTimestamp signalled to the request tracker") ) - requestData.timeoutResult success NoTimeout + requestData.timeoutResult outcome NoTimeout taskScheduler.scheduleTask(task) taskScheduler.addTick(sc, resultTimestamp) Right(()) @@ -222,7 +244,7 @@ private[participant] class NaiveRequestTracker( ], Unit]] = { def tryAddCommitSet( - commitSetPromise: Promise[CommitSet], + commitSetPromise: PromiseUnlessShutdown[CommitSet], finalizationResult: PromiseUnlessShutdown[ Either[NonEmptyChain[RequestTrackerStoreError], Unit] ], @@ -230,7 +252,9 @@ private[participant] class NaiveRequestTracker( RequestTrackerStoreError ], Unit]] = { // Complete the promise only if we're not shutting down. - performUnlessClosing(functionFullName) { commitSetPromise.tryComplete(commitSet) } match { + performUnlessClosing(functionFullName) { + commitSetPromise.tryComplete(commitSet.map(UnlessShutdown.Outcome(_))) + } match { case UnlessShutdown.AbortedDueToShutdown => // Try to clean up as good as possible even though recovery of the ephemeral state will ultimately // take care of the cleaning up. @@ -246,9 +270,17 @@ private[participant] class NaiveRequestTracker( withRC(rc, s"Completed commit set promise does not contain a value") ) ) - if (oldCommitSet == commitSet) { + if (oldCommitSet == commitSet.map(UnlessShutdown.Outcome(_))) { logger.debug(withRC(rc, s"Commit set added a second time.")) Right(EitherT(finalizationResult.futureUS)) + } else if (oldCommitSet.toEither.contains(AbortedDueToShutdown)) { + logger.debug( + withRC( + rc, + s"Old commit set was aborted due to shutdown. 
New commit set will be ignored.", + ) + ) + Left(CommitSetAlreadyExists(rc)) } else { logger.warn(withRC(rc, s"Commit set with different parameters added a second time.")) Left(CommitSetAlreadyExists(rc)) @@ -336,7 +368,7 @@ private[participant] class NaiveRequestTracker( result.map { actRes => logger.trace(withRC(rc, s"Activeness result $actRes")) } - } + }.tapOnShutdown(activenessResult.shutdown()) override def pretty: Pretty[this.type] = prettyOfClass( param("timestamp", _.timestamp), @@ -356,7 +388,7 @@ private[participant] class NaiveRequestTracker( */ private[this] class TriggerTimeout( val rc: RequestCounter, - timeoutPromise: Promise[TimeoutResult], + timeoutPromise: PromiseUnlessShutdown[TimeoutResult], val requestTimestamp: CantonTimestamp, override val timestamp: CantonTimestamp, override val sequencerCounter: SequencerCounter, @@ -385,7 +417,7 @@ private[participant] class NaiveRequestTracker( * the promise because this would complete the timeout promise too early, as the conflict detector has * not yet released the locks held by the request. */ - timeoutPromise success Timeout + timeoutPromise outcome Timeout () } } else { FutureUnlessShutdown.unit } @@ -398,7 +430,7 @@ private[participant] class NaiveRequestTracker( param("rc", _.rc), ) - override def close(): Unit = () + override def close(): Unit = timeoutPromise.shutdown() } /** The action for finalizing a request by committing and rolling back contract changes. @@ -412,7 +444,7 @@ private[participant] class NaiveRequestTracker( rc: RequestCounter, override val sequencerCounter: SequencerCounter, requestTimestamp: CantonTimestamp, - commitSetFuture: Future[CommitSet], + commitSetFuture: FutureUnlessShutdown[CommitSet], commitTime: CantonTimestamp, )(override implicit val traceContext: TraceContext) extends TimedTask(commitTime, sequencerCounter, Kind.Finalization) { @@ -422,7 +454,7 @@ private[participant] class NaiveRequestTracker( */ val finalizationResult: PromiseUnlessShutdown[ Either[NonEmptyChain[RequestTrackerStoreError], Unit] - ] = new PromiseUnlessShutdown[Either[NonEmptyChain[RequestTrackerStoreError], Unit]]( + ] = mkPromise[Either[NonEmptyChain[RequestTrackerStoreError], Unit]]( "finalization-result", futureSupervisor, ) @@ -434,20 +466,28 @@ private[participant] class NaiveRequestTracker( */ override def perform(): FutureUnlessShutdown[Unit] = performUnlessClosingUSF("finalize-request") { - FutureUnlessShutdown.outcomeF(commitSetFuture).transformWith { + commitSetFuture.transformWith { case Success(UnlessShutdown.Outcome(commitSet)) => logger.debug(withRC(rc, s"Finalizing at $commitTime")) conflictDetector .finalizeRequest(commitSet, TimeOfChange(rc, requestTimestamp)) - .map { storeFuture => - // The finalization is complete when the conflict detection stores have been updated - finalizationResult.completeWith(storeFuture.unwrap) - // Immediately evict the request - evictRequest(rc) + .transform { + case Success(UnlessShutdown.Outcome(storeFuture)) => + // The finalization is complete when the conflict detection stores have been updated + finalizationResult.completeWith(storeFuture) + // Immediately evict the request + Success(UnlessShutdown.Outcome(evictRequest(rc))) + case Success(UnlessShutdown.AbortedDueToShutdown) => + finalizationResult.shutdown() + Success(UnlessShutdown.AbortedDueToShutdown) + case Failure(e) => + finalizationResult.tryFailure(e).discard[Boolean] + Failure(e) } case Success(UnlessShutdown.AbortedDueToShutdown) => logger.debug(withRC(rc, s"Aborted finalizing at $commitTime due to 
shutdown")) + finalizationResult.shutdown() FutureUnlessShutdown.abortedDueToShutdown case Failure(ex) => @@ -562,9 +602,9 @@ private[conflictdetection] object NaiveRequestTracker { activenessSet: ActivenessSet, )( val activenessResult: PromiseUnlessShutdown[ActivenessResult], - val timeoutResult: Promise[TimeoutResult], + val timeoutResult: PromiseUnlessShutdown[TimeoutResult], val finalizationDataCell: SingleUseCell[FinalizationData], - val commitSetPromise: Promise[CommitSet], + val commitSetPromise: PromiseUnlessShutdown[CommitSet], ) private[NaiveRequestTracker] object RequestData { @@ -573,22 +613,19 @@ private[conflictdetection] object NaiveRequestTracker { requestTimestamp: CantonTimestamp, decisionTime: CantonTimestamp, activenessSet: ActivenessSet, + promiseUSFactory: PromiseUnlessShutdownFactory, futureSupervisor: FutureSupervisor, - )(implicit - elc: ErrorLoggingContext, - ec: ExecutionContext, - ): RequestData = + )(implicit elc: ErrorLoggingContext, executionContext: ExecutionContext): RequestData = new RequestData( sequencerCounter = sc, requestTimestamp = requestTimestamp, decisionTime = decisionTime, activenessSet = activenessSet, )( - activenessResult = - new PromiseUnlessShutdown[ActivenessResult]("activeness-result", futureSupervisor), - timeoutResult = Promise[TimeoutResult](), + activenessResult = promiseUSFactory.mkPromise("activeness-result", futureSupervisor), + timeoutResult = promiseUSFactory.mkPromise("timeout-result", futureSupervisor), finalizationDataCell = new SingleUseCell[FinalizationData], - commitSetPromise = new SupervisedPromise[CommitSet]("commit-set", futureSupervisor), + commitSetPromise = promiseUSFactory.mkPromise("commit-set", futureSupervisor), ) } diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTracker.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTracker.scala index 8cbb49dace..438d2cbd24 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTracker.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTracker.scala @@ -308,7 +308,7 @@ object RequestTracker { */ final case class RequestFutures( activenessResult: FutureUnlessShutdown[ActivenessResult], - timeoutResult: Future[TimeoutResult], + timeoutResult: FutureUnlessShutdown[TimeoutResult], ) /** Indicates whether the request has timed out. 
*/ diff --git a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/store/SyncDomainEphemeralState.scala b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/store/SyncDomainEphemeralState.scala index 4be178281c..effc610cb2 100644 --- a/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/store/SyncDomainEphemeralState.scala +++ b/canton/community/participant/src/main/scala/com/digitalasset/canton/participant/store/SyncDomainEphemeralState.scala @@ -147,6 +147,7 @@ class SyncDomainEphemeralState( startingPoints.cleanReplay.nextRequestCounter, loggerFactory, futureSupervisor, + timeouts, ) val observedTimestampTracker = new WatermarkTracker[CantonTimestamp]( @@ -187,6 +188,7 @@ class SyncDomainEphemeralState( requestTracker, recordOrderPublisher, submissionTracker, + phase37Synchronizer, AsyncCloseable( "request-journal-flush", requestJournal.flush(), diff --git a/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/Phase37SynchronizerTest.scala b/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/Phase37SynchronizerTest.scala index 5011b92995..7fb8319fa5 100644 --- a/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/Phase37SynchronizerTest.scala +++ b/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/Phase37SynchronizerTest.scala @@ -22,7 +22,7 @@ import scala.concurrent.Future class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutionContext { private def mk(initRc: RequestCounter = RequestCounter(0)): Phase37Synchronizer = - new Phase37Synchronizer(initRc, loggerFactory, FutureSupervisor.Noop) + new Phase37Synchronizer(initRc, loggerFactory, FutureSupervisor.Noop, timeouts) private val requestId1 = RequestId(CantonTimestamp.ofEpochSecond(1)) private val requestId2 = RequestId(CantonTimestamp.ofEpochSecond(2)) @@ -52,6 +52,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio ) p37s .awaitConfirmed(requestType)(requestId1) + .failOnShutdown .futureValue shouldBe RequestOutcome.Success(pendingRequestData) } @@ -61,6 +62,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio p37s.registerRequest(requestType)(requestId1).complete(None) p37s .awaitConfirmed(requestType)(requestId1) + .failOnShutdown .futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout } @@ -75,7 +77,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio handle.complete( Some(pendingRequestData) ) - f.futureValue shouldBe RequestOutcome.Success(pendingRequestData) + f.failOnShutdown.futureValue shouldBe RequestOutcome.Success(pendingRequestData) } "return only after reaching confirmed (for request timeout)" in { @@ -86,7 +88,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio assert(!f.isCompleted) handle.complete(None) - f.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout + f.failOnShutdown.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout } "return after request is marked as timeout and the memory cleaned" in { @@ -100,6 +102,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio } p37s .awaitConfirmed(requestType)(requestId1) + .failOnShutdown .futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout } @@ -116,8 +119,8 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest 
with HasExecutio val f2 = p37s.awaitConfirmed(requestType)(requestId1) - f1.futureValue shouldBe RequestOutcome.Success(pendingRequestData) - f2.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout + f1.failOnShutdown.futureValue shouldBe RequestOutcome.Success(pendingRequestData) + f2.failOnShutdown.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout } "complain if multiple registers have been called for the same requestID" in { @@ -152,8 +155,10 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio val f3 = p37s.awaitConfirmed(requestType)(requestId1) - f1.futureValue shouldBe RequestOutcome.Success(pendingRequestData) - forAll(Seq(f2, f3))(fut => fut.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout) + f1.failOnShutdown.futureValue shouldBe RequestOutcome.Success(pendingRequestData) + forAll(Seq(f2, f3))(fut => + fut.failOnShutdown.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout + ) } "no valid confirms" in { @@ -184,7 +189,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio handle.complete(Some(pendingRequestData)) - forAll(Seq(f1, f2, f3))(fut => fut.futureValue shouldBe RequestOutcome.Invalid) + forAll(Seq(f1, f2, f3))(fut => fut.failOnShutdown.futureValue shouldBe RequestOutcome.Invalid) } "deal with several calls for the same unconfirmed request with different filters" in { @@ -221,8 +226,10 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio _ => Future.successful(true), ) - f1.futureValue shouldBe RequestOutcome.Success(pendingRequestData) - forAll(Seq(f2, f3, f4))(fut => fut.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout) + f1.failOnShutdown.futureValue shouldBe RequestOutcome.Success(pendingRequestData) + forAll(Seq(f2, f3, f4))(fut => + fut.failOnShutdown.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout + ) } "deal with several calls for the same confirmed request with different filters" in { @@ -238,10 +245,13 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio val f1 = p37s .awaitConfirmed(requestType)(requestId1, _ => Future.successful(true)) + .failOnShutdown val f2 = p37s .awaitConfirmed(requestType)(requestId1, _ => Future.successful(false)) + .failOnShutdown val f3 = p37s .awaitConfirmed(requestType)(requestId1, _ => Future.successful(true)) + .failOnShutdown f1.futureValue shouldBe RequestOutcome.Success(pendingRequestData0) forAll(Seq(f2, f3))(fut => fut.futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout) @@ -253,10 +263,13 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio ) val f4 = p37s .awaitConfirmed(requestType)(requestId2, _ => Future.successful(false)) + .failOnShutdown val f5 = p37s .awaitConfirmed(requestType)(requestId2, _ => Future.successful(true)) + .failOnShutdown val f6 = p37s .awaitConfirmed(requestType)(requestId2, _ => Future.successful(false)) + .failOnShutdown f4.futureValue shouldBe RequestOutcome.Invalid f5.futureValue shouldBe RequestOutcome.Success(pendingRequestData1) @@ -274,6 +287,7 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio ) p37s .awaitConfirmed(requestType)(requestId1) + .failOnShutdown .futureValue shouldBe RequestOutcome.Success(pendingRequestData) p37s @@ -320,17 +334,21 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio requestId1, _ => Future.successful(true), ) + .failOnShutdown true }, ) + .failOnShutdown false }) }, ) + .failOnShutdown 
false }) }, ) + .failOnShutdown eventually() { f1.futureValue shouldBe RequestOutcome.Invalid @@ -353,10 +371,12 @@ class Phase37SynchronizerTest extends AnyWordSpec with BaseTest with HasExecutio .complete(Some(pendingRequestData)) p37s .awaitConfirmed(AnotherTestPendingRequestDataType)(requestId1) + .failOnShutdown .futureValue shouldBe RequestOutcome.AlreadyServedOrTimeout - p37s.awaitConfirmed(requestType)(requestId1).futureValue shouldBe RequestOutcome.Success( - pendingRequestData - ) + p37s.awaitConfirmed(requestType)(requestId1).failOnShutdown.futureValue shouldBe RequestOutcome + .Success( + pendingRequestData + ) } } diff --git a/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTrackerTest.scala b/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTrackerTest.scala index e753d9d813..ca5869d172 100644 --- a/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTrackerTest.scala +++ b/canton/community/participant/src/test/scala/com/digitalasset/canton/participant/protocol/conflictdetection/RequestTrackerTest.scala @@ -684,7 +684,7 @@ private[conflictdetection] trait RequestTrackerTest { ) _ = enterTick(rt, SequencerCounter(0), CantonTimestamp.Epoch) _ = enterTick(rt, SequencerCounter(2), ofEpochMilli(10)) - timeout <- toF + timeout <- toF.failOnShutdown("activeness result") _ = assert(timeout.timedOut) } yield succeed } @@ -1153,7 +1153,7 @@ private[conflictdetection] trait RequestTrackerTest { decisionTime, activenessSet, ).map { case (aR, tR) => - (aR.failOnShutdown("activeness result"), tR) + (aR.failOnShutdown("activeness result"), tR.failOnShutdown("timeout result")) } } @@ -1164,7 +1164,7 @@ private[conflictdetection] trait RequestTrackerTest { confirmationRequestTimestamp: CantonTimestamp, decisionTime: CantonTimestamp, activenessSet: ActivenessSet, - ): Future[(FutureUnlessShutdown[ActivenessResult], Future[TimeoutResult])] = + ): Future[(FutureUnlessShutdown[ActivenessResult], FutureUnlessShutdown[TimeoutResult])] = enterCR_US( rt, rc, @@ -1183,7 +1183,7 @@ private[conflictdetection] trait RequestTrackerTest { activenessTimestamp: CantonTimestamp, decisionTime: CantonTimestamp, activenessSet: ActivenessSet, - ): Future[(FutureUnlessShutdown[ActivenessResult], Future[TimeoutResult])] = { + ): Future[(FutureUnlessShutdown[ActivenessResult], FutureUnlessShutdown[TimeoutResult])] = { val resCR = rt.addRequest( rc, sc,
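
The pattern applied throughout these hunks — replacing `Promise[A]`/`Future[A]` with the shutdown-aware `PromiseUnlessShutdown[A]`/`FutureUnlessShutdown[A]`, obtaining the promises through the `mkPromise` factory so they are aborted on close, and converting back with `failOnShutdown` in the tests — can be summarised with a small self-contained sketch. This is a simplified model under assumed names (`MaybeShutdown`, `ShutdownAwarePromise`, `ShutdownAwarePromiseFactory`, the `IllegalStateException` on abort), not the actual Canton classes, and it uses only the Scala standard library:

import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future, Promise}

// Simplified stand-in for the "outcome or aborted-due-to-shutdown" result type.
sealed trait MaybeShutdown[+A]
final case class Done[A](value: A) extends MaybeShutdown[A]
case object Aborted extends MaybeShutdown[Nothing]

// A promise that is always completed with a MaybeShutdown value.
final class ShutdownAwarePromise[A] {
  private val underlying = Promise[MaybeShutdown[A]]()

  def outcome(value: A): Unit = underlying.trySuccess(Done(value))
  def shutdown(): Unit = underlying.trySuccess(Aborted)
  def isCompleted: Boolean = underlying.isCompleted
  def futureUS: Future[MaybeShutdown[A]] = underlying.future

  // Rough analogue of `failOnShutdown`: expose a plain Future that fails when
  // the promise was aborted, so callers can keep using ordinary Future APIs.
  def failOnShutdown(description: String)(implicit ec: ExecutionContext): Future[A] =
    underlying.future.flatMap {
      case Done(a) => Future.successful(a)
      case Aborted =>
        Future.failed(new IllegalStateException(s"$description aborted due to shutdown"))
    }
}

// A factory in the spirit of mkPromise: every promise it hands out is tracked
// and aborted when the owning component closes, so none is left dangling.
final class ShutdownAwarePromiseFactory {
  private val pending = mutable.Set.empty[ShutdownAwarePromise[_]]
  private var closed = false

  def mkPromise[A](description: String)(implicit ec: ExecutionContext): ShutdownAwarePromise[A] =
    synchronized {
      val p = new ShutdownAwarePromise[A]
      if (closed) p.shutdown() // already closed: abort the new promise right away
      else {
        pending += p
        // Once the promise settles, drop the reference so the set does not grow.
        p.futureUS.onComplete(_ => synchronized { pending -= p })
      }
      p
    }

  def close(): Unit = synchronized {
    closed = true
    pending.foreach(_.shutdown()) // stray promises complete with Aborted instead of hanging
    pending.clear()
  }
}

In this model, closing the factory completes every outstanding promise with the aborted value rather than leaving it pending, which mirrors why the tracker's `close()` now calls `timeoutPromise.shutdown()` and why the finalization path calls `finalizationResult.shutdown()` on abort in the hunks above. The `failOnShutdown` accessor hands an ordinary `Future` back to the caller, which is the same adaptation the tests make when the assertions change to `f.failOnShutdown.futureValue`.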