diff --git a/language-support/java/bindings-rxjava/BUILD.bazel b/language-support/java/bindings-rxjava/BUILD.bazel index 748180d7b4..72178aa141 100644 --- a/language-support/java/bindings-rxjava/BUILD.bazel +++ b/language-support/java/bindings-rxjava/BUILD.bazel @@ -67,6 +67,7 @@ da_scala_library( "@maven//:org_scalatest_scalatest_core", "@maven//:org_scalatest_scalatest_matchers_core", "@maven//:org_scalatest_scalatest_shouldmatchers", + "@maven//:com_typesafe_akka_akka_actor", ], deps = [ ":bindings-rxjava", diff --git a/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/grpc/helpers/LedgerServices.scala b/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/grpc/helpers/LedgerServices.scala index 15cfe274c8..405a488cb3 100644 --- a/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/grpc/helpers/LedgerServices.scala +++ b/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/grpc/helpers/LedgerServices.scala @@ -8,6 +8,7 @@ import java.net.{InetSocketAddress, SocketAddress} import java.time.{Clock, Duration} import java.util.concurrent.TimeUnit +import akka.actor.ActorSystem import com.daml.ledger.rxjava.grpc._ import com.daml.ledger.rxjava.grpc.helpers.TransactionsServiceImpl.LedgerItem import com.daml.ledger.rxjava.{CommandCompletionClient, LedgerConfigurationClient, PackageClient} @@ -46,6 +47,7 @@ final class LedgerServices(val ledgerId: String) { val executionContext: ExecutionContext = global private val esf: ExecutionSequencerFactory = new SingleThreadExecutionSequencerPool(ledgerId) + private val akkaSystem = ActorSystem("LedgerServicesParticipant") private val participantId = "LedgerServicesParticipant" private val authorizer = Authorizer( @@ -53,6 +55,10 @@ final class LedgerServices(val ledgerId: String) { ledgerId, participantId, new ErrorCodesVersionSwitcher(enableSelfServiceErrorCodes = true), + new InMemoryUserManagementStore(), + executionContext, + userRightsCheckIntervalInSeconds = 1, + akkaScheduler = akkaSystem.scheduler, ) def newServerBuilder(): NettyServerBuilder = NettyServerBuilder.forAddress(nextAddress()) diff --git a/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/package.scala b/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/package.scala index b4ed477f61..4393ffebe5 100644 --- a/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/package.scala +++ b/language-support/java/bindings-rxjava/src/test/scala/com/daml/ledger/rxjava/package.scala @@ -4,9 +4,12 @@ package com.daml.ledger import com.daml.error.ErrorCodesVersionSwitcher - import java.time.Clock import java.util.UUID + +import akka.actor.ActorSystem + +import scala.concurrent.ExecutionContext import com.daml.lf.data.Ref import com.daml.ledger.api.auth.{ AuthServiceStatic, @@ -18,11 +21,14 @@ import com.daml.ledger.api.auth.{ ClaimReadAsParty, ClaimSet, } +import com.daml.ledger.participant.state.index.impl.inmemory.InMemoryUserManagementStore package object rxjava { private[rxjava] def untestedEndpoint: Nothing = throw new UnsupportedOperationException("Untested endpoint, implement if needed") + private val akkaSystem = ActorSystem("testActorSystem") + sys.addShutdownHook(akkaSystem.terminate(): Unit) private[rxjava] val authorizer = Authorizer( @@ -30,6 +36,10 @@ package object rxjava { "testLedgerId", "testParticipantId", new ErrorCodesVersionSwitcher(enableSelfServiceErrorCodes = true), + new InMemoryUserManagementStore(), + ExecutionContext.parasitic, + userRightsCheckIntervalInSeconds = 1, + akkaScheduler = akkaSystem.scheduler, ) private[rxjava] val emptyToken = "empty" diff --git a/ledger/error/BUILD.bazel b/ledger/error/BUILD.bazel index 9d250dd3f5..8fb58f4b15 100644 --- a/ledger/error/BUILD.bazel +++ b/ledger/error/BUILD.bazel @@ -65,6 +65,8 @@ da_scala_library( deps = [ "//ledger/error", "//ledger/test-common", + "//libs-scala/contextualized-logging", + "//libs-scala/scala-utils", "@maven//:ch_qos_logback_logback_classic", "@maven//:ch_qos_logback_logback_core", "@maven//:com_google_api_grpc_proto_google_common_protos", diff --git a/ledger/error/src/main/scala/com/daml/error/definitions/LedgerApiErrors.scala b/ledger/error/src/main/scala/com/daml/error/definitions/LedgerApiErrors.scala index 57066df9da..9f212217fe 100644 --- a/ledger/error/src/main/scala/com/daml/error/definitions/LedgerApiErrors.scala +++ b/ledger/error/src/main/scala/com/daml/error/definitions/LedgerApiErrors.scala @@ -16,9 +16,10 @@ import com.daml.lf.transaction.GlobalKey import com.daml.lf.value.Value import com.daml.lf.{VersionRange, language} import org.slf4j.event.Level - import java.time.{Duration, Instant} +import scala.concurrent.duration._ + @Explanation( "Errors raised by or forwarded by the Ledger API." ) @@ -287,8 +288,30 @@ object LedgerApiErrors extends LedgerApiErrorGroup { } } - @Explanation("Authentication errors.") + @Explanation("Authentication and authorization errors.") object AuthorizationChecks extends ErrorGroup() { + + @Explanation("""The stream was aborted because the authenticated user's rights changed, + |and the user might thus no longer be authorized to this stream. + |""") + @Resolution( + "The application should automatically retry fetching the stream. It will either succeed, or fail with an explicit denial of authentication or permission." + ) + object StaleUserManagementBasedStreamClaims + extends ErrorCode( + id = "STALE_STREAM_AUTHORIZATION", + ErrorCategory.ContentionOnSharedResources, + ) { + case class Reject()(implicit + loggingContext: ContextualizedErrorLogger + ) extends LoggingTransactionErrorImpl("Stale stream authorization. Retry quickly.") { + override def retryable: Option[ErrorCategoryRetry] = Some( + ErrorCategoryRetry(who = "application", duration = 0.seconds) + ) + } + + } + @Explanation( """This rejection is given if the submitted command does not contain a JWT token on a participant enforcing JWT authentication.""" ) diff --git a/ledger/error/src/main/scala/com/daml/error/utils/ErrorDetails.scala b/ledger/error/src/main/scala/com/daml/error/utils/ErrorDetails.scala index 34635a6a2f..db3d34597c 100644 --- a/ledger/error/src/main/scala/com/daml/error/utils/ErrorDetails.scala +++ b/ledger/error/src/main/scala/com/daml/error/utils/ErrorDetails.scala @@ -7,6 +7,7 @@ import com.google.protobuf import com.google.rpc.{ErrorInfo, RequestInfo, ResourceInfo, RetryInfo} import scala.jdk.CollectionConverters._ +import scala.concurrent.duration._ object ErrorDetails { sealed trait ErrorDetail extends Product with Serializable @@ -14,7 +15,7 @@ object ErrorDetails { final case class ResourceInfoDetail(name: String, typ: String) extends ErrorDetail final case class ErrorInfoDetail(reason: String, metadata: Map[String, String]) extends ErrorDetail - final case class RetryInfoDetail(retryDelayInSeconds: Long) extends ErrorDetail + final case class RetryInfoDetail(duration: Duration) extends ErrorDetail final case class RequestInfoDetail(requestId: String) extends ErrorDetail def from(anys: Seq[protobuf.Any]): Seq[ErrorDetail] = anys.toList.map { @@ -28,7 +29,9 @@ object ErrorDetails { case any if any.is(classOf[RetryInfo]) => val v = any.unpack(classOf[RetryInfo]) - RetryInfoDetail(v.getRetryDelay.getSeconds) + val delay = v.getRetryDelay + val duration = (delay.getSeconds.seconds + delay.getNanos.nanos).toCoarsest + RetryInfoDetail(duration) case any if any.is(classOf[RequestInfo]) => val v = any.unpack(classOf[RequestInfo]) diff --git a/ledger/error/src/test/lib/scala/com/daml/error/ErrorsAssertions.scala b/ledger/error/src/test/lib/scala/com/daml/error/ErrorsAssertions.scala index abc53c7399..411c6b575a 100644 --- a/ledger/error/src/test/lib/scala/com/daml/error/ErrorsAssertions.scala +++ b/ledger/error/src/test/lib/scala/com/daml/error/ErrorsAssertions.scala @@ -4,8 +4,10 @@ package com.daml.error import com.daml.error.utils.ErrorDetails +import com.daml.logging.{ContextualizedLogger, LoggingContext} import com.daml.platform.testing.{LogCollector, LogCollectorAssertions} import com.daml.platform.testing.LogCollector.ExpectedLogEntry +import com.daml.scalautil.Statement import io.grpc.Status.Code import io.grpc.StatusRuntimeException import io.grpc.protobuf.StatusProto @@ -13,9 +15,40 @@ import org.scalatest.matchers.should.Matchers import scala.jdk.CollectionConverters._ import scala.reflect.ClassTag +import org.scalatest.Checkpoints.Checkpoint trait ErrorsAssertions { - self: Matchers with LogCollectorAssertions => + self: Matchers => + + private val logger = ContextualizedLogger.get(getClass) + private val loggingContext = LoggingContext.ForTesting + private val errorLogger = new DamlContextualizedErrorLogger(logger, loggingContext, None) + + def assertError( + actual: StatusRuntimeException, + expectedF: ContextualizedErrorLogger => StatusRuntimeException, + ): Unit = { + assertError( + actual = actual, + expected = expectedF(errorLogger), + ) + } + + /** Asserts that the two errors have the same code, message and details. + */ + def assertError( + actual: StatusRuntimeException, + expected: StatusRuntimeException, + ): Unit = { + val expectedStatus = StatusProto.fromThrowable(expected) + val expectedDetails = expectedStatus.getDetailsList.asScala.toSeq + assertError( + actual = actual, + expectedCode = expected.getStatus.getCode, + expectedMessage = expectedStatus.getMessage, + expectedDetails = ErrorDetails.from(expectedDetails), + ) + } def assertError( actual: StatusRuntimeException, @@ -23,9 +56,26 @@ trait ErrorsAssertions { expectedMessage: String, expectedDetails: Seq[ErrorDetails.ErrorDetail], ): Unit = { - doAssertError(actual, expectedCode, expectedMessage, expectedDetails, None) + val actualStatus = StatusProto.fromThrowable(actual) + val actualDetails = actualStatus.getDetailsList.asScala.toSeq + val cp = new Checkpoint + cp { Statement.discard { actual.getStatus.getCode shouldBe expectedCode } } + cp { Statement.discard { actualStatus.getMessage shouldBe expectedMessage } } + cp { + Statement.discard { + ErrorDetails.from(actualDetails) should contain theSameElementsAs expectedDetails + } + } + cp.reportAll() } +} + +trait ErrorAssertionsWithLogCollectorAssertions + extends ErrorsAssertions + with LogCollectorAssertions { + self: Matchers => + def assertError[Test, Logger]( actual: StatusRuntimeException, expectedCode: Code, @@ -36,32 +86,15 @@ trait ErrorsAssertions { test: ClassTag[Test], logger: ClassTag[Logger], ): Unit = { - doAssertError(actual, expectedCode, expectedMessage, expectedDetails, Some(expectedLogEntry))( - test, - logger, + assertError( + actual = actual, + expectedCode = expectedCode, + expectedMessage = expectedMessage, + expectedDetails = expectedDetails, ) - } - - private def doAssertError[Test, Logger]( - actual: StatusRuntimeException, - expectedCode: Code, - expectedMessage: String, - expectedDetails: Seq[ErrorDetails.ErrorDetail], - expectedLogEntry: Option[ExpectedLogEntry], - )(implicit - test: ClassTag[Test], - logger: ClassTag[Logger], - ): Unit = { - val status = StatusProto.fromThrowable(actual) - status.getCode shouldBe expectedCode.value() - status.getMessage shouldBe expectedMessage - val details = status.getDetailsList.asScala.toSeq - val _ = ErrorDetails.from(details) should contain theSameElementsAs expectedDetails - if (expectedLogEntry.isDefined) { - val actualLogs: Seq[LogCollector.Entry] = LogCollector.readAsEntries(test, logger) - actualLogs should have size 1 - assertLogEntry(actualLogs.head, expectedLogEntry.get) - } + val actualLogs: Seq[LogCollector.Entry] = LogCollector.readAsEntries(test, logger) + actualLogs should have size 1 + assertLogEntry(actualLogs.head, expectedLogEntry) } } diff --git a/ledger/error/src/test/suite/scala/com/daml/error/ErrorCodeSpec.scala b/ledger/error/src/test/suite/scala/com/daml/error/ErrorCodeSpec.scala index ae21bfa0dd..c3a0de7998 100644 --- a/ledger/error/src/test/suite/scala/com/daml/error/ErrorCodeSpec.scala +++ b/ledger/error/src/test/suite/scala/com/daml/error/ErrorCodeSpec.scala @@ -21,7 +21,7 @@ class ErrorCodeSpec with Matchers with BeforeAndAfter with LogCollectorAssertions - with ErrorsAssertions { + with ErrorAssertionsWithLogCollectorAssertions { implicit private val testLoggingContext: LoggingContext = LoggingContext.ForTesting private val logger = ContextualizedLogger.get(getClass) @@ -85,7 +85,7 @@ class ErrorCodeSpec NotSoSeriousError.id, Map("category" -> "1") ++ contextMetadata ++ Map("definite_answer" -> "true"), ), - ErrorDetails.RetryInfoDetail(TransientServerFailure.retryable.get.duration.toSeconds), + ErrorDetails.RetryInfoDetail(TransientServerFailure.retryable.get.duration), ErrorDetails.RequestInfoDetail(correlationId), ErrorDetails.ResourceInfoDetail(error.resources.head._1.asString, error.resources.head._2), ), diff --git a/ledger/ledger-api-auth/BUILD.bazel b/ledger/ledger-api-auth/BUILD.bazel index d61a5a063c..f8d0f70cbd 100644 --- a/ledger/ledger-api-auth/BUILD.bazel +++ b/ledger/ledger-api-auth/BUILD.bazel @@ -15,6 +15,7 @@ da_scala_library( scala_deps = [ "@maven//:io_spray_spray_json", "@maven//:org_scalaz_scalaz_core", + "@maven//:com_typesafe_akka_akka_actor", ], tags = ["maven_coordinates=com.daml:ledger-api-auth:__VERSION__"], visibility = [ @@ -63,17 +64,28 @@ da_scala_test_suite( "@maven//:org_scalatest_scalatest_shouldmatchers", "@maven//:org_scalatest_scalatest_wordspec", "@maven//:org_scalatestplus_scalacheck_1_15", + "@maven//:com_typesafe_akka_akka_actor", + "@maven//:com_typesafe_akka_akka_stream", ], deps = [ ":ledger-api-auth", + "//daml-lf/data", + "//ledger-api/rs-grpc-bridge", + "//ledger-api/testing-utils", "//ledger/error", + "//ledger/error:error-test-lib", + "//ledger/ledger-api-common", + "//ledger/ledger-api-domain", "//ledger/participant-state-index", "//ledger/test-common", + "//libs-scala/adjustable-clock", + "//libs-scala/contextualized-logging", "@maven//:com_google_api_grpc_proto_google_common_protos", "@maven//:com_google_protobuf_protobuf_java", "@maven//:io_grpc_grpc_api", "@maven//:io_grpc_grpc_context", "@maven//:io_grpc_grpc_protobuf", + "@maven//:io_grpc_grpc_stub", "@maven//:org_mockito_mockito_core", "@maven//:org_scalatest_scalatest_compatible", ], diff --git a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/Authorizer.scala b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/Authorizer.scala index 3a6153808e..9834b57c71 100644 --- a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/Authorizer.scala +++ b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/Authorizer.scala @@ -3,6 +3,10 @@ package com.daml.ledger.api.auth +import java.time.Instant + +import akka.actor.Scheduler +import com.daml.error.definitions.LedgerApiErrors import com.daml.error.{ ContextualizedErrorLogger, DamlContextualizedErrorLogger, @@ -10,16 +14,14 @@ import com.daml.error.{ } import com.daml.ledger.api.auth.interceptor.AuthorizationInterceptor import com.daml.ledger.api.v1.transaction_filter.TransactionFilter +import com.daml.ledger.participant.state.index.v2.UserManagementStore import com.daml.logging.{ContextualizedLogger, LoggingContext} import com.daml.platform.server.api.validation.ErrorFactories -import io.grpc.stub.{ServerCallStreamObserver, StreamObserver} -import java.time.Instant - -import com.daml.error.definitions.LedgerApiErrors import io.grpc.StatusRuntimeException +import io.grpc.stub.{ServerCallStreamObserver, StreamObserver} import scalapb.lenses.Lens -import scala.concurrent.Future +import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} /** A simple helper that allows services to use authorization claims @@ -30,6 +32,10 @@ final class Authorizer( ledgerId: String, participantId: String, errorCodesVersionSwitcher: ErrorCodesVersionSwitcher, + userManagementStore: UserManagementStore, + ec: ExecutionContext, + userRightsCheckIntervalInSeconds: Int, + akkaScheduler: Scheduler, )(implicit loggingContext: LoggingContext) { private val logger = ContextualizedLogger.get(this.getClass) private val errorFactories = ErrorFactories(errorCodesVersionSwitcher) @@ -226,16 +232,17 @@ final class Authorizer( } private def ongoingAuthorization[Res]( - scso: ServerCallStreamObserver[Res], + observer: ServerCallStreamObserver[Res], claims: ClaimSet.Claims, ) = new OngoingAuthorizationObserver[Res]( - scso, - claims, - _.notExpired(now()), - authorizationError => { - errorFactories.permissionDenied(authorizationError.reason) - }, - ) + observer = observer, + originalClaims = claims, + nowF = now, + errorFactories = errorFactories, + userManagementStore = userManagementStore, + userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds, + akkaScheduler = akkaScheduler, + )(loggingContext, ec) /** Directly access the authenticated claims from the thread-local context. * @@ -263,7 +270,7 @@ final class Authorizer( private def authorizeWithReq[Req, Res](call: (Req, ServerCallStreamObserver[Res]) => Unit)( authorized: (ClaimSet.Claims, Req) => Either[StatusRuntimeException, Req] ): (Req, StreamObserver[Res]) => Unit = (request, observer) => { - val scso = assertServerCall(observer) + val serverCallStreamObserver = assertServerCall(observer) authenticatedClaimsFromContext() .fold( ex => { @@ -278,10 +285,10 @@ final class Authorizer( case Right(modifiedRequest) => call( modifiedRequest, - if (claims.expiration.isDefined) - ongoingAuthorization(scso, claims) + if (claims.expiration.isDefined || claims.resolvedFromUser) + ongoingAuthorization(serverCallStreamObserver, claims) else - scso, + serverCallStreamObserver, ) case Left(ex) => observer.onError(ex) @@ -324,8 +331,21 @@ object Authorizer { ledgerId: String, participantId: String, errorCodesVersionSwitcher: ErrorCodesVersionSwitcher, + userManagementStore: UserManagementStore, + ec: ExecutionContext, + userRightsCheckIntervalInSeconds: Int, + akkaScheduler: Scheduler, ): Authorizer = LoggingContext.newLoggingContext { loggingContext => - new Authorizer(now, ledgerId, participantId, errorCodesVersionSwitcher)(loggingContext) + new Authorizer( + now = now, + ledgerId = ledgerId, + participantId = participantId, + errorCodesVersionSwitcher = errorCodesVersionSwitcher, + userManagementStore = userManagementStore, + ec = ec, + userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds, + akkaScheduler = akkaScheduler, + )(loggingContext) } } diff --git a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserver.scala b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserver.scala index 90773cd42d..7764242d21 100644 --- a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserver.scala +++ b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserver.scala @@ -3,14 +3,84 @@ package com.daml.ledger.api.auth +import java.time.Instant + +import akka.actor.{Cancellable, Scheduler} +import com.daml.error.DamlContextualizedErrorLogger +import com.daml.error.definitions.LedgerApiErrors +import com.daml.ledger.api.auth.interceptor.AuthorizationInterceptor +import com.daml.ledger.participant.state.index.v2.UserManagementStore +import com.daml.lf.data.Ref +import com.daml.logging.{ContextualizedLogger, LoggingContext} +import com.daml.platform.server.api.validation.ErrorFactories +import io.grpc.StatusRuntimeException import io.grpc.stub.ServerCallStreamObserver +import scala.concurrent.ExecutionContext +import scala.concurrent.duration._ +import scala.util.{Failure, Success} + +/** @param userRightsCheckIntervalInSeconds - determines the interval at which to check whether user rights state has changed. + * Also, double of this value serves as timeout value for subsequent user rights state checks. + */ private[auth] final class OngoingAuthorizationObserver[A]( observer: ServerCallStreamObserver[A], - claims: ClaimSet.Claims, - authorized: ClaimSet.Claims => Either[AuthorizationError, Unit], - throwOnFailure: AuthorizationError => Throwable, -) extends ServerCallStreamObserver[A] { + originalClaims: ClaimSet.Claims, + nowF: () => Instant, + errorFactories: ErrorFactories, + userManagementStore: UserManagementStore, + userRightsCheckIntervalInSeconds: Int, + akkaScheduler: Scheduler, +)(implicit loggingContext: LoggingContext, ec: ExecutionContext) + extends ServerCallStreamObserver[A] { + + private val logger = ContextualizedLogger.get(getClass) + private val errorLogger = new DamlContextualizedErrorLogger(logger, loggingContext, None) + + // Guards against propagating calls to delegate observer after either + // [[onComplete]] or [[onError]] has already been called once. + // We need this because [[onError]] can be invoked two concurrent sources: + // 1) scheduled user rights state change task (see [[cancellableO]]), + // 2) upstream component that is translating upstream Akka stream into [[onNext]] and other signals. + private var afterCompletionOrError = false + + @volatile private var lastUserRightsCheckTime = nowF() + + private lazy val userId = originalClaims.applicationId.fold[Ref.UserId]( + throw new RuntimeException( + "Claims were resolved from a user but userId (applicationId) is missing in the claims." + ) + )(Ref.UserId.assertFromString) + + // Scheduling a task that periodically checks + // whether user rights state has changed. + // If user rights state has changed it aborts the stream by calling [[onError]] + private val cancellableO: Option[Cancellable] = { + if (originalClaims.resolvedFromUser) { + val delay = userRightsCheckIntervalInSeconds.seconds + // Note: https://doc.akka.io/docs/akka/2.6.13/scheduler.html states that: + // "All scheduled task will be executed when the ActorSystem is terminated, i.e. the task may execute before its timeout." + val c = akkaScheduler.scheduleWithFixedDelay(initialDelay = delay, delay = delay)(runnable = + checkUserRights _ + ) + Some(c) + } else None + } + + private def checkUserRights(): Unit = { + userManagementStore + .listUserRights(userId) + .onComplete { + case Failure(_) | Success(Left(_)) => + onError(staleStreamAuthError) + case Success(Right(userRights)) => + val updatedClaims = AuthorizationInterceptor.convertUserRightsToClaims(userRights) + if (updatedClaims.toSet != originalClaims.claims.toSet) { + onError(staleStreamAuthError) + } + lastUserRightsCheckTime = nowF() + } + } override def isCancelled: Boolean = observer.isCancelled @@ -26,15 +96,81 @@ private[auth] final class OngoingAuthorizationObserver[A]( override def request(i: Int): Unit = observer.request(i) - override def setMessageCompression(b: Boolean): Unit = observer.setMessageCompression(b) + override def setMessageCompression(b: Boolean): Unit = synchronized( + observer.setMessageCompression(b) + ) - override def onNext(v: A): Unit = - authorized(claims) match { - case Right(_) => observer.onNext(v) - case Left(authorizationError) => observer.onError(throwOnFailure(authorizationError)) + override def onNext(v: A): Unit = synchronized { + if (!afterCompletionOrError) { + val now = nowF() + (for { + _ <- checkClaimsExpiry(now) + _ <- checkUserRightsRefreshTimeout(now) + } yield { + () + }) match { + case Right(_) => observer.onNext(v) + case Left(e) => + onError(e) + } } + } - override def onError(throwable: Throwable): Unit = observer.onError(throwable) + override def onError(throwable: Throwable): Unit = synchronized { + if (!afterCompletionOrError) { + afterCompletionOrError = true + cancelUserRightsCheckTask() + observer.onError(throwable) + } + } + + override def onCompleted(): Unit = synchronized { + if (!afterCompletionOrError) { + afterCompletionOrError = true + cancelUserRightsCheckTask() + observer.onCompleted() + } + } + + private def checkUserRightsRefreshTimeout(now: Instant): Either[StatusRuntimeException, Unit] = { + + // Safety switch to abort the stream if the user-rights-state-check task + // fails to refresh within 2*[[userRightsCheckIntervalInSeconds]] seconds. + // In normal conditions we expected the refresh delay to be about [[userRightsCheckIntervalInSeconds]] seconds. + if ( + originalClaims.resolvedFromUser && + lastUserRightsCheckTime.isBefore( + now.minusSeconds(2 * userRightsCheckIntervalInSeconds.toLong) + ) + ) { + Left(staleStreamAuthError) + } else Right(()) + } + + private def checkClaimsExpiry(now: Instant): Either[StatusRuntimeException, Unit] = { + originalClaims + .notExpired(now) + .left + .map(authorizationError => + errorFactories.permissionDenied(authorizationError.reason)(errorLogger) + ) + } + + private def staleStreamAuthError: StatusRuntimeException = { + // Terminate the stream, so that clients will restart their streams + // and claims will be rechecked precisely. + LedgerApiErrors.AuthorizationChecks.StaleUserManagementBasedStreamClaims + .Reject()(errorLogger) + .asGrpcError + } + + private def cancelUserRightsCheckTask(): Unit = { + cancellableO.foreach { cancellable => + cancellable.cancel() + if (!cancellable.isCancelled) { + logger.debug(s"Failed to cancel stream authorization task") + } + } + } - override def onCompleted(): Unit = observer.onCompleted() } diff --git a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/interceptor/AuthorizationInterceptor.scala b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/interceptor/AuthorizationInterceptor.scala index 07535bc04a..c6333d194d 100644 --- a/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/interceptor/AuthorizationInterceptor.scala +++ b/ledger/ledger-api-auth/src/main/scala/com/digitalasset/ledger/api/auth/interceptor/AuthorizationInterceptor.scala @@ -91,10 +91,10 @@ final class AuthorizationInterceptor( s"Could not resolve rights for user '$userId' due to '$msg'" )(errorLogger) ) - case Right(userClaims) => + case Right(userRights: Set[UserRight]) => Future.successful( ClaimSet.Claims( - claims = userClaims.view.map(userRightToClaim).toList.prepended(ClaimPublic), + claims = AuthorizationInterceptor.convertUserRightsToClaims(userRights), ledgerId = None, participantId = participantId, applicationId = Some(userId), @@ -133,11 +133,6 @@ final class AuthorizationInterceptor( Future.successful(userId) } - private[this] def userRightToClaim(r: UserRight): Claim = r match { - case UserRight.CanActAs(p) => ClaimActAsParty(Ref.Party.assertFromString(p)) - case UserRight.CanReadAs(p) => ClaimReadAsParty(Ref.Party.assertFromString(p)) - case UserRight.ParticipantAdmin => ClaimAdmin - } } object AuthorizationInterceptor { @@ -165,4 +160,14 @@ object AuthorizationInterceptor { LoggingContext.newLoggingContext { implicit loggingContext: LoggingContext => new AuthorizationInterceptor(authService, userManagementStoreO, ec, errorCodesStatusSwitcher) } + + def convertUserRightsToClaims(userRights: Set[UserRight]): Seq[Claim] = { + userRights.view.map(userRightToClaim).toList.prepended(ClaimPublic) + } + + private[this] def userRightToClaim(r: UserRight): Claim = r match { + case UserRight.CanActAs(p) => ClaimActAsParty(Ref.Party.assertFromString(p)) + case UserRight.CanReadAs(p) => ClaimReadAsParty(Ref.Party.assertFromString(p)) + case UserRight.ParticipantAdmin => ClaimAdmin + } } diff --git a/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/AuthorizerSpec.scala b/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/AuthorizerSpec.scala index e8e354b631..27ef580366 100644 --- a/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/AuthorizerSpec.scala +++ b/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/AuthorizerSpec.scala @@ -9,12 +9,20 @@ import io.grpc.{Status, StatusRuntimeException} import org.scalatest.Assertion import org.scalatest.flatspec.AsyncFlatSpec import org.scalatest.matchers.should.Matchers - import java.time.Instant -import scala.concurrent.Future + +import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll +import com.daml.ledger.participant.state.index.v2.UserManagementStore +import org.mockito.MockitoSugar + +import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success, Try} -class AuthorizerSpec extends AsyncFlatSpec with Matchers { +class AuthorizerSpec + extends AsyncFlatSpec + with Matchers + with MockitoSugar + with AkkaBeforeAndAfterAll { private val className = classOf[Authorizer].getSimpleName private val dummyRequest = 1337L private val expectedSuccessfulResponse = "expectedSuccessfulResponse" @@ -77,5 +85,9 @@ class AuthorizerSpec extends AsyncFlatSpec with Matchers { "some-ledger-id", "participant-id", new ErrorCodesVersionSwitcher(selfServiceErrorCodes), + mock[UserManagementStore], + mock[ExecutionContext], + userRightsCheckIntervalInSeconds = 1, + akkaScheduler = system.scheduler, ) } diff --git a/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserverSpec.scala b/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserverSpec.scala new file mode 100644 index 0000000000..1c2e2e0242 --- /dev/null +++ b/ledger/ledger-api-auth/src/test/suite/scala/com/digitalasset/ledger/api/auth/OngoingAuthorizationObserverSpec.scala @@ -0,0 +1,92 @@ +// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.ledger.api.auth + +import java.time.{Clock, Duration, Instant, ZoneId} + +import akka.actor.{Cancellable, Scheduler} +import com.daml.clock.AdjustableClock +import com.daml.error.ErrorsAssertions +import com.daml.error.definitions.LedgerApiErrors +import com.daml.ledger.participant.state.index.v2.UserManagementStore +import com.daml.logging.LoggingContext +import com.daml.platform.server.api.validation.ErrorFactories +import io.grpc.StatusRuntimeException +import io.grpc.stub.ServerCallStreamObserver +import org.mockito.{ArgumentCaptor, ArgumentMatchersSugar, MockitoSugar} +import org.scalatest.flatspec.AsyncFlatSpec +import org.scalatest.matchers.should.Matchers + +import scala.concurrent.ExecutionContext +import scala.concurrent.duration.FiniteDuration + +class OngoingAuthorizationObserverSpec + extends AsyncFlatSpec + with Matchers + with MockitoSugar + with ArgumentMatchersSugar + with ErrorsAssertions { + + private val loggingContext = LoggingContext.ForTesting + + it should "signal onError aborting the stream when user rights state hasn't been refreshed in a timely manner" in { + val clock = AdjustableClock( + baseClock = Clock.fixed(Instant.now(), ZoneId.systemDefault()), + offset = Duration.ZERO, + ) + val delegate = mock[ServerCallStreamObserver[Int]] + val mockScheduler = mock[Scheduler] + // Set scheduler to do nothing + val cancellableMock = mock[Cancellable] + when( + mockScheduler.scheduleWithFixedDelay(any[FiniteDuration], any[FiniteDuration])(any[Runnable])( + any[ExecutionContext] + ) + ).thenReturn(cancellableMock) + val userRightsCheckIntervalInSeconds = 10 + val tested = new OngoingAuthorizationObserver( + observer = delegate, + originalClaims = ClaimSet.Claims.Empty.copy(resolvedFromUser = true), + nowF = clock.instant, + errorFactories = mock[ErrorFactories], + userManagementStore = mock[UserManagementStore], + // This is also the user rights state refresh timeout + userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds, + akkaScheduler = mockScheduler, + )(loggingContext, executionContext) + + // After 20 seconds pass we expect onError to be called due to lack of user rights state refresh task inactivity + tested.onNext(1) + clock.fastForward(Duration.ofSeconds(2.toLong * userRightsCheckIntervalInSeconds - 1)) + tested.onNext(2) + clock.fastForward(Duration.ofSeconds(2)) + // Next onNext detects the user rights state refresh task inactivity + tested.onNext(3) + + val captor = ArgumentCaptor.forClass(classOf[StatusRuntimeException]) + val order = inOrder(delegate) + order.verify(delegate, times(1)).onNext(1) + order.verify(delegate, times(1)).onNext(2) + order.verify(delegate, times(1)).onError(captor.capture()) + order.verifyNoMoreInteractions() + // Scheduled task is cancelled + verify(cancellableMock, times(1)).cancel() + assertError( + actual = captor.getValue, + expectedF = LedgerApiErrors.AuthorizationChecks.StaleUserManagementBasedStreamClaims + .Reject()(_) + .asGrpcError, + ) + + // onError has already been called by tested implementation so subsequent onNext, onError and onComplete + // must not be forwarded to the delegate observer + tested.onNext(4) + tested.onError(new RuntimeException) + tested.onCompleted() + verifyNoMoreInteractions(delegate) + + succeed + } + +} diff --git a/ledger/ledger-api-common/src/test/suite/scala/com/digitalasset/platform/server/api/validation/ErrorFactoriesSpec.scala b/ledger/ledger-api-common/src/test/suite/scala/com/digitalasset/platform/server/api/validation/ErrorFactoriesSpec.scala index 62b93b5ef5..eef8a800cf 100644 --- a/ledger/ledger-api-common/src/test/suite/scala/com/digitalasset/platform/server/api/validation/ErrorFactoriesSpec.scala +++ b/ledger/ledger-api-common/src/test/suite/scala/com/digitalasset/platform/server/api/validation/ErrorFactoriesSpec.scala @@ -13,8 +13,8 @@ import com.daml.error.utils.ErrorDetails import com.daml.error.{ ContextualizedErrorLogger, DamlContextualizedErrorLogger, + ErrorAssertionsWithLogCollectorAssertions, ErrorCodesVersionSwitcher, - ErrorsAssertions, } import com.daml.ledger.api.domain.LedgerId import com.daml.lf.data.Ref @@ -31,6 +31,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.matchers.should.Matchers import org.scalatest.prop.TableDrivenPropertyChecks import org.scalatest.wordspec.AnyWordSpec +import scala.concurrent.duration._ import scala.annotation.nowarn import scala.jdk.CollectionConverters._ @@ -43,7 +44,7 @@ class ErrorFactoriesSpec with MockitoSugar with BeforeAndAfter with LogCollectorAssertions - with ErrorsAssertions { + with ErrorAssertionsWithLogCollectorAssertions { private val logger = ContextualizedLogger.get(getClass) private val loggingContext = LoggingContext.ForTesting @@ -83,7 +84,7 @@ class ErrorFactoriesSpec expectedMessage = msg, expectedDetails = Seq[ErrorDetails.ErrorDetail]( expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ErrorDetails.ErrorInfoDetail( "INDEX_DB_SQL_TRANSIENT_ERROR", Map("category" -> "1", "definite_answer" -> "false"), @@ -161,7 +162,7 @@ class ErrorFactoriesSpec ), ), expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ), v2_logEntry = ExpectedLogEntry( Level.WARN, @@ -194,7 +195,7 @@ class ErrorFactoriesSpec ), ), expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ), v2_logEntry = ExpectedLogEntry( Level.INFO, @@ -223,7 +224,7 @@ class ErrorFactoriesSpec Map("category" -> "3", "definite_answer" -> "false"), ), expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ), v2_logEntry = ExpectedLogEntry( Level.INFO, @@ -399,7 +400,7 @@ class ErrorFactoriesSpec Map("category" -> "3", "definite_answer" -> "false"), ), expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ), v2_logEntry = ExpectedLogEntry( Level.INFO, @@ -742,7 +743,7 @@ class ErrorFactoriesSpec Map("category" -> "1", "definite_answer" -> "false", "service_name" -> serviceName), ), expectedCorrelationIdRequestInfo, - ErrorDetails.RetryInfoDetail(1), + ErrorDetails.RetryInfoDetail(1.second), ), v2_logEntry = ExpectedLogEntry( Level.INFO, diff --git a/ledger/ledger-api-test-tool/BUILD.bazel b/ledger/ledger-api-test-tool/BUILD.bazel index 7aaaa94c4d..52da18f6c6 100644 --- a/ledger/ledger-api-test-tool/BUILD.bazel +++ b/ledger/ledger-api-test-tool/BUILD.bazel @@ -107,12 +107,19 @@ da_scala_binary( srcs = suites_sources(lf_version), scala_deps = [ "@maven//:com_chuusai_shapeless", + "@maven//:com_typesafe_akka_akka_actor", + "@maven//:com_typesafe_akka_akka_stream", ], scaladoc = False, visibility = [ "//:__subpackages__", ], deps = [ + "//ledger-api/rs-grpc-bridge", + "//ledger-service/jwt", + "//ledger/ledger-api-client", + "//ledger/ledger-api-domain", + "@maven//:com_auth0_java_jwt", ":ledger-api-test-tool-%s-lib" % lf_version, "//daml-lf/data", "//language-support/scala/bindings", diff --git a/ledger/ledger-api-test-tool/src/main/scala/com/daml/ledger/api/testtool/infrastructure/Assertions.scala b/ledger/ledger-api-test-tool/src/main/scala/com/daml/ledger/api/testtool/infrastructure/Assertions.scala index 18b7b5e57f..b09d75b2ff 100644 --- a/ledger/ledger-api-test-tool/src/main/scala/com/daml/ledger/api/testtool/infrastructure/Assertions.scala +++ b/ledger/ledger-api-test-tool/src/main/scala/com/daml/ledger/api/testtool/infrastructure/Assertions.scala @@ -187,15 +187,15 @@ object Assertions { ) ) val expectedErrorId = expectedErrorCode.id - val expectedRetryabilitySeconds = expectedErrorCode.category.retryable.map(_.duration.toSeconds) + val expectedRetryability = expectedErrorCode.category.retryable.map(_.duration) val actualStatusCode = status.getCode val actualErrorDetails = ErrorDetails.from(status.getDetailsList.asScala.toSeq) val actualErrorId = actualErrorDetails .collectFirst { case err: ErrorDetails.ErrorInfoDetail => err.reason } .getOrElse(fail("Actual error id is not defined")) - val actualRetryabilitySeconds = actualErrorDetails - .collectFirst { case err: ErrorDetails.RetryInfoDetail => err.retryDelayInSeconds } + val actualRetryability = actualErrorDetails + .collectFirst { case err: ErrorDetails.RetryInfoDetail => err.duration } if (actualErrorId != expectedErrorId) fail(s"Actual error id ($actualErrorId) does not match expected error id ($expectedErrorId}") @@ -208,8 +208,8 @@ object Assertions { Assertions.assertEquals( s"Error retryability details mismatch", - actualRetryabilitySeconds, - expectedRetryabilitySeconds, + actualRetryability, + expectedRetryability, ) } diff --git a/ledger/participant-integration-api/src/main/scala/platform/apiserver/StandaloneApiServer.scala b/ledger/participant-integration-api/src/main/scala/platform/apiserver/StandaloneApiServer.scala index d63e45b478..1adc49bc25 100644 --- a/ledger/participant-integration-api/src/main/scala/platform/apiserver/StandaloneApiServer.scala +++ b/ledger/participant-integration-api/src/main/scala/platform/apiserver/StandaloneApiServer.scala @@ -27,6 +27,7 @@ import com.daml.platform.configuration.{ SubmissionConfiguration, } import com.daml.platform.services.time.TimeProviderType +import com.daml.platform.usermanagement.UserManagementConfig import com.daml.ports.{Port, PortFiles} import com.daml.telemetry.TelemetryContext import io.grpc.{BindableService, ServerInterceptor} @@ -59,6 +60,7 @@ object StandaloneApiServer { checkOverloaded: TelemetryContext => Option[state.SubmissionResult] = _ => None, // Used for Canton rate-limiting, ledgerFeatures: LedgerFeatures, + userManagementConfig: UserManagementConfig, )(implicit actorSystem: ActorSystem, materializer: Materializer, @@ -86,6 +88,10 @@ object StandaloneApiServer { ledgerId, participantId, errorCodesVersionSwitcher, + userManagementStore, + servicesExecutionContext, + userRightsCheckIntervalInSeconds = userManagementConfig.cacheExpiryAfterWriteInSeconds, + akkaScheduler = actorSystem.scheduler, ) val healthChecksWithIndexService = healthChecks + ("index" -> indexService) diff --git a/ledger/participant-integration-api/src/main/scala/platform/apiserver/services/admin/ApiUserManagementService.scala b/ledger/participant-integration-api/src/main/scala/platform/apiserver/services/admin/ApiUserManagementService.scala index 9e3e258df5..5b46e9e80d 100644 --- a/ledger/participant-integration-api/src/main/scala/platform/apiserver/services/admin/ApiUserManagementService.scala +++ b/ledger/participant-integration-api/src/main/scala/platform/apiserver/services/admin/ApiUserManagementService.scala @@ -23,7 +23,7 @@ import scalaz.std.list._ import scala.concurrent.{ExecutionContext, Future} private[apiserver] final class ApiUserManagementService( - userManagementService: UserManagementStore, + userManagementStore: UserManagementStore, errorCodesVersionSwitcher: ErrorCodesVersionSwitcher, )(implicit executionContext: ExecutionContext, @@ -53,7 +53,7 @@ private[apiserver] final class ApiUserManagementService( pRights <- fromProtoRights(request.rights) } yield (User(pUserId, pOptPrimaryParty), pRights) } { case (user, pRights) => - userManagementService + userManagementStore .createUser( user = user, rights = pRights, @@ -66,7 +66,7 @@ private[apiserver] final class ApiUserManagementService( withValidation( requireUserId(request.userId, "user_id") )(userId => - userManagementService + userManagementStore .getUser(userId) .flatMap(handleResult("getting user")) .map(toProtoUser) @@ -76,14 +76,14 @@ private[apiserver] final class ApiUserManagementService( withValidation( requireUserId(request.userId, "user_id") )(userId => - userManagementService + userManagementStore .deleteUser(userId) .flatMap(handleResult("deleting user")) .map(_ => proto.DeleteUserResponse()) ) override def listUsers(request: proto.ListUsersRequest): Future[proto.ListUsersResponse] = - userManagementService + userManagementStore .listUsers() .flatMap(handleResult("listing users")) .map( @@ -100,7 +100,7 @@ private[apiserver] final class ApiUserManagementService( rights <- fromProtoRights(request.rights) } yield (userId, rights) ) { case (userId, rights) => - userManagementService + userManagementStore .grantRights( id = userId, rights = rights, @@ -119,7 +119,7 @@ private[apiserver] final class ApiUserManagementService( rights <- fromProtoRights(request.rights) } yield (userId, rights) ) { case (userId, rights) => - userManagementService + userManagementStore .revokeRights( id = userId, rights = rights, @@ -135,7 +135,7 @@ private[apiserver] final class ApiUserManagementService( withValidation( requireUserId(request.userId, "user_id") )(userId => - userManagementService + userManagementStore .listUserRights(userId) .flatMap(handleResult("list user rights")) .map(_.view.map(toProtoRight).toList) diff --git a/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Config.scala b/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Config.scala index 5e3c7a0cb5..cc529d4b27 100644 --- a/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Config.scala +++ b/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Config.scala @@ -654,8 +654,8 @@ object Config { .optional() .text( s"Defaults to ${UserManagementConfig.DefaultCacheExpiryAfterWriteInSeconds} seconds. " + - // TODO participant user management: Update max delay to 2x the configured value when made use of in throttled stream authorization. - "Determines the maximum delay for propagating user management state changes." + "Used to set expiry time for user management cache. " + + "Also determines the maximum delay for propagating user management state changes which is double its value." ) .action((value, config: Config[Extra]) => config.withUserManagementConfig(_.copy(cacheExpiryAfterWriteInSeconds = value)) diff --git a/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Runner.scala b/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Runner.scala index e8d1eac8e1..d6b026a8e4 100644 --- a/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Runner.scala +++ b/ledger/participant-state/kvutils/app/src/main/scala/com/daml/ledger/participant/state/kvutils/app/Runner.scala @@ -258,6 +258,7 @@ final class Runner[T <: ReadWriteService, Extra]( v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED ), ), + userManagementConfig = config.userManagementConfig, ).acquire() } yield Some(apiServer.port) case ParticipantRunMode.Indexer => diff --git a/ledger/sandbox-classic/BUILD.bazel b/ledger/sandbox-classic/BUILD.bazel index 84230cc1e5..c2ae8276a5 100644 --- a/ledger/sandbox-classic/BUILD.bazel +++ b/ledger/sandbox-classic/BUILD.bazel @@ -204,6 +204,7 @@ test_deps = [ "//libs-scala/postgresql-testing", "//libs-scala/resources", "//libs-scala/timer-utils", + "//ledger/error:error-test-lib", "@maven//:ch_qos_logback_logback_classic", "@maven//:ch_qos_logback_logback_core", "@maven//:com_typesafe_config", diff --git a/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/auth/ExpiringStreamServiceCallAuthTests.scala b/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/auth/ExpiringStreamServiceCallAuthTests.scala index c83d03a37d..da50b1e0d2 100644 --- a/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/auth/ExpiringStreamServiceCallAuthTests.scala +++ b/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/auth/ExpiringStreamServiceCallAuthTests.scala @@ -51,16 +51,16 @@ trait ExpiringStreamServiceCallAuthTests[T] toHeader(expiringIn(Duration.ofSeconds(5), readOnlyToken(mainActor))) it should "break a stream in flight upon read-only token expiration" in { - val _ = Delayed.Future.by(10.seconds)(submitAndWait()) + val _ = Delayed.Future.by(10.seconds)(submitAndWaitAsMainActor()) expectExpiration(canReadAsMainActorExpiresInFiveSeconds).map(_ => succeed) } it should "break a stream in flight upon read/write token expiration" in { - val _ = Delayed.Future.by(10.seconds)(submitAndWait()) + val _ = Delayed.Future.by(10.seconds)(submitAndWaitAsMainActor()) expectExpiration(canActAsMainActorExpiresInFiveSeconds).map(_ => succeed) } override def serviceCallWithToken(token: Option[String]): Future[Any] = - submitAndWait().flatMap(_ => new StreamConsumer[T](stream(token)).first()) + submitAndWaitAsMainActor().flatMap(_ => new StreamConsumer[T](stream(token)).first()) } diff --git a/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/services/SubmitAndWaitDummyCommand.scala b/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/services/SubmitAndWaitDummyCommand.scala index 9acaa0dc73..5d57cc4b0e 100644 --- a/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/services/SubmitAndWaitDummyCommand.scala +++ b/ledger/sandbox-classic/src/test/lib/scala/platform/sandbox/services/SubmitAndWaitDummyCommand.scala @@ -6,20 +6,33 @@ package com.daml.platform.sandbox.services import java.util.UUID import com.daml.ledger.api.v1.command_service.{CommandServiceGrpc, SubmitAndWaitRequest} -import com.daml.platform.sandbox.auth.ServiceCallWithMainActorAuthTests +import com.daml.platform.sandbox.auth.{ServiceCallAuthTests, ServiceCallWithMainActorAuthTests} import com.google.protobuf.empty.Empty import scala.concurrent.Future -trait SubmitAndWaitDummyCommand extends TestCommands { self: ServiceCallWithMainActorAuthTests => +trait SubmitAndWaitDummyCommand extends TestCommands with SubmitAndWaitDummyCommandHelpers { + self: ServiceCallWithMainActorAuthTests => - protected def submitAndWait(): Future[Empty] = - submitAndWait(Option(toHeader(readWriteToken(mainActor)))) + protected def submitAndWaitAsMainActor(): Future[Empty] = + submitAndWait( + Option(toHeader(readWriteToken(mainActor))), + applicationId = serviceCallName, + party = mainActor, + ) - protected def dummySubmitAndWaitRequest(applicationId: String): SubmitAndWaitRequest = +} + +trait SubmitAndWaitDummyCommandHelpers extends TestCommands { + self: ServiceCallAuthTests => + + protected def dummySubmitAndWaitRequest( + applicationId: String, + party: String, + ): SubmitAndWaitRequest = SubmitAndWaitRequest( - dummyCommands(wrappedLedgerId, s"$serviceCallName-${UUID.randomUUID}", mainActor) - .update(_.commands.applicationId := applicationId, _.commands.party := mainActor) + dummyCommands(wrappedLedgerId, s"$serviceCallName-${UUID.randomUUID}", party = party) + .update(_.commands.applicationId := applicationId, _.commands.party := party) .commands ) @@ -29,31 +42,35 @@ trait SubmitAndWaitDummyCommand extends TestCommands { self: ServiceCallWithMain protected def submitAndWait( token: Option[String], applicationId: String = serviceCallName, + party: String, ): Future[Empty] = - service(token).submitAndWait(dummySubmitAndWaitRequest(applicationId)) + service(token).submitAndWait(dummySubmitAndWaitRequest(applicationId, party = party)) protected def submitAndWaitForTransaction( token: Option[String], applicationId: String = serviceCallName, + party: String, ): Future[Empty] = service(token) - .submitAndWaitForTransaction(dummySubmitAndWaitRequest(applicationId)) + .submitAndWaitForTransaction(dummySubmitAndWaitRequest(applicationId, party = party)) .map(_ => Empty()) protected def submitAndWaitForTransactionId( token: Option[String], applicationId: String = serviceCallName, + party: String, ): Future[Empty] = service(token) - .submitAndWaitForTransactionId(dummySubmitAndWaitRequest(applicationId)) + .submitAndWaitForTransactionId(dummySubmitAndWaitRequest(applicationId, party = party)) .map(_ => Empty()) protected def submitAndWaitForTransactionTree( token: Option[String], applicationId: String = serviceCallName, + party: String, ): Future[Empty] = service(token) - .submitAndWaitForTransactionTree(dummySubmitAndWaitRequest(applicationId)) + .submitAndWaitForTransactionTree(dummySubmitAndWaitRequest(applicationId, party = party)) .map(_ => Empty()) } diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/CompletionStreamAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/CompletionStreamAuthIT.scala index c6391fc66c..16778a3521 100644 --- a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/CompletionStreamAuthIT.scala +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/CompletionStreamAuthIT.scala @@ -42,7 +42,7 @@ final class CompletionStreamAuthIT override protected def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] = // Note: the token must allow actAs mainActor for this call to work. - submitAndWait(token, "").flatMap(_ => + submitAndWait(token, "", party = mainActor).flatMap(_ => new StreamConsumer[CompletionStreamResponse](streamFor("")(token)).first() ) diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/OngoingStreamAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/OngoingStreamAuthIT.scala new file mode 100644 index 0000000000..f455b0972f --- /dev/null +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/OngoingStreamAuthIT.scala @@ -0,0 +1,146 @@ +// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.platform.sandbox.auth + +import java.util.concurrent.atomic.AtomicInteger +import java.util.{Timer, TimerTask, UUID} + +import com.daml.error.ErrorsAssertions +import com.daml.error.utils.ErrorDetails +import com.daml.ledger.api.v1.admin.user_management_service.Right +import com.daml.ledger.api.v1.transaction_filter.{Filters, TransactionFilter} +import com.daml.ledger.api.v1.transaction_service.{ + GetTransactionsRequest, + GetTransactionsResponse, + TransactionServiceGrpc, +} +import com.daml.platform.sandbox.config.SandboxConfig +import com.daml.platform.sandbox.services.SubmitAndWaitDummyCommandHelpers +import io.grpc.stub.StreamObserver +import io.grpc.{Status, StatusRuntimeException} +import com.daml.ledger.api.v1.admin.{user_management_service => user_management_service_proto} +import scala.concurrent.duration._ + +import scala.concurrent.{Future, Promise} + +final class OngoingStreamAuthIT + extends ServiceCallAuthTests + with SubmitAndWaitDummyCommandHelpers + with ErrorsAssertions { + + private val UserManagementCacheExpiryInSeconds = 1 + + override protected def config: SandboxConfig = super.config.withUserManagementConfig( + _.copy(cacheExpiryAfterWriteInSeconds = UserManagementCacheExpiryInSeconds) + ) + + override def serviceCallName: String = "" + + override protected def serviceCallWithToken(token: Option[String]): Future[Any] = ??? + + private val testId = UUID.randomUUID().toString + + it should "abort an ongoing stream after user state has changed" in { + val partyAlice = "alice-party" + val userIdAlice = testId + "-alice" + val receivedTransactionsCount = new AtomicInteger(0) + val transactionStreamAbortedPromise = Promise[Throwable]() + + def observeTransactionsStream( + token: Option[String], + party: String, + ): StreamObserver[GetTransactionsResponse] = { + val observer = new StreamObserver[GetTransactionsResponse] { + override def onNext(value: GetTransactionsResponse): Unit = { + val _ = receivedTransactionsCount.incrementAndGet() + } + + override def onError(t: Throwable): Unit = { + val _ = transactionStreamAbortedPromise.trySuccess(t) + } + + override def onCompleted(): Unit = () + } + val request = new GetTransactionsRequest( + begin = Option(ledgerBegin), + end = None, + filter = Some( + new TransactionFilter( + Map(party -> new Filters) + ) + ), + ) + stub(TransactionServiceGrpc.stub(channel), token) + .getTransactions(request, observer) + observer + } + + val canActAsAlice = Right(Right.Kind.CanActAs(Right.CanActAs(partyAlice))) + for { + (userAlice, tokenAlice) <- createUserByAdmin( + userId = userIdAlice, + rights = Vector(canActAsAlice), + ) + applicationId = userAlice.id + submitAndWaitF = () => + submitAndWait(token = tokenAlice, party = partyAlice, applicationId = applicationId) + _ <- submitAndWaitF() + streamObserver = observeTransactionsStream(tokenAlice, partyAlice) + _ <- submitAndWaitF() + // Making a change to the user Alice + _ <- grantUserRightsByAdmin( + userId = userIdAlice, + Right(Right.Kind.CanActAs(Right.CanActAs(UUID.randomUUID().toString))), + ) + _ = Thread.sleep(UserManagementCacheExpiryInSeconds.toLong * 1000) + // + // Timer that finishes the stream in case it isn't aborted as expected + timerTask = new TimerTask { + override def run(): Unit = streamObserver.onError( + new AssertionError("Timed-out waiting while waiting for stream to abort") + ) + } + _ = new Timer(true).schedule(timerTask, 100) + t <- transactionStreamAbortedPromise.future + } yield { + timerTask.cancel() + t match { + case sre: StatusRuntimeException => + assertError( + actual = sre, + expectedCode = Status.Code.ABORTED, + expectedMessage = + "STALE_STREAM_AUTHORIZATION(2,0): Stale stream authorization. Retry quickly.", + expectedDetails = List( + ErrorDetails.ErrorInfoDetail( + "STALE_STREAM_AUTHORIZATION", + Map( + "participantId" -> "'sandbox-participant'", + "category" -> "2", + "definite_answer" -> "false", + ), + ), + ErrorDetails.RetryInfoDetail(0.seconds), + ), + ) + case _ => fail("Unexpected error", t) + } + assert(receivedTransactionsCount.get() >= 2) + } + } + + private def grantUserRightsByAdmin( + userId: String, + right: user_management_service_proto.Right, + ): Future[Unit] = { + val req = user_management_service_proto.GrantUserRightsRequest(userId, Seq(right)) + stub( + user_management_service_proto.UserManagementServiceGrpc.stub(channel), + canReadAsAdminStandardJWT, + ) + .grantUserRights(req) + .map(_ => ()) + } + +} diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitAuthIT.scala index e1ee313bde..b6cee7632f 100644 --- a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitAuthIT.scala +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitAuthIT.scala @@ -14,8 +14,8 @@ final class SubmitAndWaitAuthIT override def serviceCallName: String = "CommandService#SubmitAndWait" override def serviceCallWithToken(token: Option[String]): Future[Any] = - submitAndWait(token) + submitAndWait(token, party = mainActor) override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] = - submitAndWait(token, "") + submitAndWait(token, "", party = mainActor) } diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionAuthIT.scala index 3c3d2c5d0b..9078485ca0 100644 --- a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionAuthIT.scala +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionAuthIT.scala @@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionAuthIT override def serviceCallName: String = "CommandService#SubmitAndWaitForTransaction" override def serviceCallWithToken(token: Option[String]): Future[Any] = - submitAndWaitForTransaction(token) + submitAndWaitForTransaction(token, party = mainActor) override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] = - submitAndWaitForTransaction(token, "") + submitAndWaitForTransaction(token, "", party = mainActor) } diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionIdAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionIdAuthIT.scala index 4ee76df002..d533307f87 100644 --- a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionIdAuthIT.scala +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionIdAuthIT.scala @@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionIdAuthIT override def serviceCallName: String = "CommandService#SubmitAndWaitForTransactionId" override def serviceCallWithToken(token: Option[String]): Future[Any] = - submitAndWaitForTransactionId(token) + submitAndWaitForTransactionId(token, party = mainActor) override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] = - submitAndWaitForTransactionId(token, "") + submitAndWaitForTransactionId(token, "", party = mainActor) } diff --git a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionTreeAuthIT.scala b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionTreeAuthIT.scala index deb45a110a..069b666b64 100644 --- a/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionTreeAuthIT.scala +++ b/ledger/sandbox-classic/src/test/suite/scala/platform/sandbox/auth/SubmitAndWaitForTransactionTreeAuthIT.scala @@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionTreeAuthIT override def serviceCallName: String = "CommandService#SubmitAndWaitForTransactionTree" override def serviceCallWithToken(token: Option[String]): Future[Any] = - submitAndWaitForTransactionTree(token) + submitAndWaitForTransactionTree(token, party = mainActor) override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] = - submitAndWaitForTransactionTree(token, "") + submitAndWaitForTransactionTree(token, "", party = mainActor) } diff --git a/ledger/sandbox-common/src/main/scala/platform/sandbox/cli/CommonCliBase.scala b/ledger/sandbox-common/src/main/scala/platform/sandbox/cli/CommonCliBase.scala index 3b0a09502e..1d7d458e07 100644 --- a/ledger/sandbox-common/src/main/scala/platform/sandbox/cli/CommonCliBase.scala +++ b/ledger/sandbox-common/src/main/scala/platform/sandbox/cli/CommonCliBase.scala @@ -406,8 +406,8 @@ class CommonCliBase(name: LedgerName) { .optional() .text( s"Defaults to ${UserManagementConfig.DefaultCacheExpiryAfterWriteInSeconds} seconds. " + - // TODO participant user management: Update max delay to 2x the configured value when made use of in throttled stream authorization. - "Determines the maximum delay for propagating user management state changes." + "Used to set expiry time for user management cache. " + + "Also determines the maximum delay for propagating user management state changes which is double its value." ) .action((value, config: SandboxConfig) => config.withUserManagementConfig(_.copy(cacheExpiryAfterWriteInSeconds = value)) diff --git a/ledger/sandbox-on-x/src/main/scala/com/daml/ledger/sandbox/SandboxOnXRunner.scala b/ledger/sandbox-on-x/src/main/scala/com/daml/ledger/sandbox/SandboxOnXRunner.scala index 132ce69e3d..33908eeb8c 100644 --- a/ledger/sandbox-on-x/src/main/scala/com/daml/ledger/sandbox/SandboxOnXRunner.scala +++ b/ledger/sandbox-on-x/src/main/scala/com/daml/ledger/sandbox/SandboxOnXRunner.scala @@ -264,6 +264,7 @@ object SandboxOnXRunner { v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED ), ), + userManagementConfig = config.userManagementConfig, ) private def buildIndexerServer( diff --git a/ledger/sandbox/src/main/scala/platform/sandboxnext/Runner.scala b/ledger/sandbox/src/main/scala/platform/sandboxnext/Runner.scala index 49dcea21e2..a54d421d63 100644 --- a/ledger/sandbox/src/main/scala/platform/sandboxnext/Runner.scala +++ b/ledger/sandbox/src/main/scala/platform/sandboxnext/Runner.scala @@ -291,6 +291,7 @@ class Runner(config: SandboxConfig) extends ResourceOwner[Port] { v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED ), ), + userManagementConfig = config.userManagementConfig, ) _ = apiServerServicesClosed.completeWith(apiServer.servicesClosed()) } yield {