mirror of
https://github.com/digital-asset/daml.git
synced 2024-09-20 01:07:18 +03:00
[User management] Terminate ongoing streams when user state has changed [DPP-830] (#12437)
CHANGELOG_BEGIN Ledger API Specification: When using user management based authorization streams will now get aborted on authenticated user's rights change. CHANGELOG_END
This commit is contained in:
parent
35eae895e4
commit
c72c27c967
@ -67,6 +67,7 @@ da_scala_library(
|
||||
"@maven//:org_scalatest_scalatest_core",
|
||||
"@maven//:org_scalatest_scalatest_matchers_core",
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
"@maven//:com_typesafe_akka_akka_actor",
|
||||
],
|
||||
deps = [
|
||||
":bindings-rxjava",
|
||||
|
@ -8,6 +8,7 @@ import java.net.{InetSocketAddress, SocketAddress}
|
||||
import java.time.{Clock, Duration}
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import com.daml.ledger.rxjava.grpc._
|
||||
import com.daml.ledger.rxjava.grpc.helpers.TransactionsServiceImpl.LedgerItem
|
||||
import com.daml.ledger.rxjava.{CommandCompletionClient, LedgerConfigurationClient, PackageClient}
|
||||
@ -46,6 +47,7 @@ final class LedgerServices(val ledgerId: String) {
|
||||
|
||||
val executionContext: ExecutionContext = global
|
||||
private val esf: ExecutionSequencerFactory = new SingleThreadExecutionSequencerPool(ledgerId)
|
||||
private val akkaSystem = ActorSystem("LedgerServicesParticipant")
|
||||
private val participantId = "LedgerServicesParticipant"
|
||||
private val authorizer =
|
||||
Authorizer(
|
||||
@ -53,6 +55,10 @@ final class LedgerServices(val ledgerId: String) {
|
||||
ledgerId,
|
||||
participantId,
|
||||
new ErrorCodesVersionSwitcher(enableSelfServiceErrorCodes = true),
|
||||
new InMemoryUserManagementStore(),
|
||||
executionContext,
|
||||
userRightsCheckIntervalInSeconds = 1,
|
||||
akkaScheduler = akkaSystem.scheduler,
|
||||
)
|
||||
|
||||
def newServerBuilder(): NettyServerBuilder = NettyServerBuilder.forAddress(nextAddress())
|
||||
|
@ -4,9 +4,12 @@
|
||||
package com.daml.ledger
|
||||
|
||||
import com.daml.error.ErrorCodesVersionSwitcher
|
||||
|
||||
import java.time.Clock
|
||||
import java.util.UUID
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
|
||||
import scala.concurrent.ExecutionContext
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.ledger.api.auth.{
|
||||
AuthServiceStatic,
|
||||
@ -18,11 +21,14 @@ import com.daml.ledger.api.auth.{
|
||||
ClaimReadAsParty,
|
||||
ClaimSet,
|
||||
}
|
||||
import com.daml.ledger.participant.state.index.impl.inmemory.InMemoryUserManagementStore
|
||||
|
||||
package object rxjava {
|
||||
|
||||
private[rxjava] def untestedEndpoint: Nothing =
|
||||
throw new UnsupportedOperationException("Untested endpoint, implement if needed")
|
||||
private val akkaSystem = ActorSystem("testActorSystem")
|
||||
sys.addShutdownHook(akkaSystem.terminate(): Unit)
|
||||
|
||||
private[rxjava] val authorizer =
|
||||
Authorizer(
|
||||
@ -30,6 +36,10 @@ package object rxjava {
|
||||
"testLedgerId",
|
||||
"testParticipantId",
|
||||
new ErrorCodesVersionSwitcher(enableSelfServiceErrorCodes = true),
|
||||
new InMemoryUserManagementStore(),
|
||||
ExecutionContext.parasitic,
|
||||
userRightsCheckIntervalInSeconds = 1,
|
||||
akkaScheduler = akkaSystem.scheduler,
|
||||
)
|
||||
|
||||
private[rxjava] val emptyToken = "empty"
|
||||
|
@ -65,6 +65,8 @@ da_scala_library(
|
||||
deps = [
|
||||
"//ledger/error",
|
||||
"//ledger/test-common",
|
||||
"//libs-scala/contextualized-logging",
|
||||
"//libs-scala/scala-utils",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:com_google_api_grpc_proto_google_common_protos",
|
||||
|
@ -16,9 +16,10 @@ import com.daml.lf.transaction.GlobalKey
|
||||
import com.daml.lf.value.Value
|
||||
import com.daml.lf.{VersionRange, language}
|
||||
import org.slf4j.event.Level
|
||||
|
||||
import java.time.{Duration, Instant}
|
||||
|
||||
import scala.concurrent.duration._
|
||||
|
||||
@Explanation(
|
||||
"Errors raised by or forwarded by the Ledger API."
|
||||
)
|
||||
@ -287,8 +288,30 @@ object LedgerApiErrors extends LedgerApiErrorGroup {
|
||||
}
|
||||
}
|
||||
|
||||
@Explanation("Authentication errors.")
|
||||
@Explanation("Authentication and authorization errors.")
|
||||
object AuthorizationChecks extends ErrorGroup() {
|
||||
|
||||
@Explanation("""The stream was aborted because the authenticated user's rights changed,
|
||||
|and the user might thus no longer be authorized to this stream.
|
||||
|""")
|
||||
@Resolution(
|
||||
"The application should automatically retry fetching the stream. It will either succeed, or fail with an explicit denial of authentication or permission."
|
||||
)
|
||||
object StaleUserManagementBasedStreamClaims
|
||||
extends ErrorCode(
|
||||
id = "STALE_STREAM_AUTHORIZATION",
|
||||
ErrorCategory.ContentionOnSharedResources,
|
||||
) {
|
||||
case class Reject()(implicit
|
||||
loggingContext: ContextualizedErrorLogger
|
||||
) extends LoggingTransactionErrorImpl("Stale stream authorization. Retry quickly.") {
|
||||
override def retryable: Option[ErrorCategoryRetry] = Some(
|
||||
ErrorCategoryRetry(who = "application", duration = 0.seconds)
|
||||
)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@Explanation(
|
||||
"""This rejection is given if the submitted command does not contain a JWT token on a participant enforcing JWT authentication."""
|
||||
)
|
||||
|
@ -7,6 +7,7 @@ import com.google.protobuf
|
||||
import com.google.rpc.{ErrorInfo, RequestInfo, ResourceInfo, RetryInfo}
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
import scala.concurrent.duration._
|
||||
|
||||
object ErrorDetails {
|
||||
sealed trait ErrorDetail extends Product with Serializable
|
||||
@ -14,7 +15,7 @@ object ErrorDetails {
|
||||
final case class ResourceInfoDetail(name: String, typ: String) extends ErrorDetail
|
||||
final case class ErrorInfoDetail(reason: String, metadata: Map[String, String])
|
||||
extends ErrorDetail
|
||||
final case class RetryInfoDetail(retryDelayInSeconds: Long) extends ErrorDetail
|
||||
final case class RetryInfoDetail(duration: Duration) extends ErrorDetail
|
||||
final case class RequestInfoDetail(requestId: String) extends ErrorDetail
|
||||
|
||||
def from(anys: Seq[protobuf.Any]): Seq[ErrorDetail] = anys.toList.map {
|
||||
@ -28,7 +29,9 @@ object ErrorDetails {
|
||||
|
||||
case any if any.is(classOf[RetryInfo]) =>
|
||||
val v = any.unpack(classOf[RetryInfo])
|
||||
RetryInfoDetail(v.getRetryDelay.getSeconds)
|
||||
val delay = v.getRetryDelay
|
||||
val duration = (delay.getSeconds.seconds + delay.getNanos.nanos).toCoarsest
|
||||
RetryInfoDetail(duration)
|
||||
|
||||
case any if any.is(classOf[RequestInfo]) =>
|
||||
val v = any.unpack(classOf[RequestInfo])
|
||||
|
@ -4,8 +4,10 @@
|
||||
package com.daml.error
|
||||
|
||||
import com.daml.error.utils.ErrorDetails
|
||||
import com.daml.logging.{ContextualizedLogger, LoggingContext}
|
||||
import com.daml.platform.testing.{LogCollector, LogCollectorAssertions}
|
||||
import com.daml.platform.testing.LogCollector.ExpectedLogEntry
|
||||
import com.daml.scalautil.Statement
|
||||
import io.grpc.Status.Code
|
||||
import io.grpc.StatusRuntimeException
|
||||
import io.grpc.protobuf.StatusProto
|
||||
@ -13,9 +15,40 @@ import org.scalatest.matchers.should.Matchers
|
||||
|
||||
import scala.jdk.CollectionConverters._
|
||||
import scala.reflect.ClassTag
|
||||
import org.scalatest.Checkpoints.Checkpoint
|
||||
|
||||
trait ErrorsAssertions {
|
||||
self: Matchers with LogCollectorAssertions =>
|
||||
self: Matchers =>
|
||||
|
||||
private val logger = ContextualizedLogger.get(getClass)
|
||||
private val loggingContext = LoggingContext.ForTesting
|
||||
private val errorLogger = new DamlContextualizedErrorLogger(logger, loggingContext, None)
|
||||
|
||||
def assertError(
|
||||
actual: StatusRuntimeException,
|
||||
expectedF: ContextualizedErrorLogger => StatusRuntimeException,
|
||||
): Unit = {
|
||||
assertError(
|
||||
actual = actual,
|
||||
expected = expectedF(errorLogger),
|
||||
)
|
||||
}
|
||||
|
||||
/** Asserts that the two errors have the same code, message and details.
|
||||
*/
|
||||
def assertError(
|
||||
actual: StatusRuntimeException,
|
||||
expected: StatusRuntimeException,
|
||||
): Unit = {
|
||||
val expectedStatus = StatusProto.fromThrowable(expected)
|
||||
val expectedDetails = expectedStatus.getDetailsList.asScala.toSeq
|
||||
assertError(
|
||||
actual = actual,
|
||||
expectedCode = expected.getStatus.getCode,
|
||||
expectedMessage = expectedStatus.getMessage,
|
||||
expectedDetails = ErrorDetails.from(expectedDetails),
|
||||
)
|
||||
}
|
||||
|
||||
def assertError(
|
||||
actual: StatusRuntimeException,
|
||||
@ -23,9 +56,26 @@ trait ErrorsAssertions {
|
||||
expectedMessage: String,
|
||||
expectedDetails: Seq[ErrorDetails.ErrorDetail],
|
||||
): Unit = {
|
||||
doAssertError(actual, expectedCode, expectedMessage, expectedDetails, None)
|
||||
val actualStatus = StatusProto.fromThrowable(actual)
|
||||
val actualDetails = actualStatus.getDetailsList.asScala.toSeq
|
||||
val cp = new Checkpoint
|
||||
cp { Statement.discard { actual.getStatus.getCode shouldBe expectedCode } }
|
||||
cp { Statement.discard { actualStatus.getMessage shouldBe expectedMessage } }
|
||||
cp {
|
||||
Statement.discard {
|
||||
ErrorDetails.from(actualDetails) should contain theSameElementsAs expectedDetails
|
||||
}
|
||||
}
|
||||
cp.reportAll()
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
trait ErrorAssertionsWithLogCollectorAssertions
|
||||
extends ErrorsAssertions
|
||||
with LogCollectorAssertions {
|
||||
self: Matchers =>
|
||||
|
||||
def assertError[Test, Logger](
|
||||
actual: StatusRuntimeException,
|
||||
expectedCode: Code,
|
||||
@ -36,32 +86,15 @@ trait ErrorsAssertions {
|
||||
test: ClassTag[Test],
|
||||
logger: ClassTag[Logger],
|
||||
): Unit = {
|
||||
doAssertError(actual, expectedCode, expectedMessage, expectedDetails, Some(expectedLogEntry))(
|
||||
test,
|
||||
logger,
|
||||
assertError(
|
||||
actual = actual,
|
||||
expectedCode = expectedCode,
|
||||
expectedMessage = expectedMessage,
|
||||
expectedDetails = expectedDetails,
|
||||
)
|
||||
}
|
||||
|
||||
private def doAssertError[Test, Logger](
|
||||
actual: StatusRuntimeException,
|
||||
expectedCode: Code,
|
||||
expectedMessage: String,
|
||||
expectedDetails: Seq[ErrorDetails.ErrorDetail],
|
||||
expectedLogEntry: Option[ExpectedLogEntry],
|
||||
)(implicit
|
||||
test: ClassTag[Test],
|
||||
logger: ClassTag[Logger],
|
||||
): Unit = {
|
||||
val status = StatusProto.fromThrowable(actual)
|
||||
status.getCode shouldBe expectedCode.value()
|
||||
status.getMessage shouldBe expectedMessage
|
||||
val details = status.getDetailsList.asScala.toSeq
|
||||
val _ = ErrorDetails.from(details) should contain theSameElementsAs expectedDetails
|
||||
if (expectedLogEntry.isDefined) {
|
||||
val actualLogs: Seq[LogCollector.Entry] = LogCollector.readAsEntries(test, logger)
|
||||
actualLogs should have size 1
|
||||
assertLogEntry(actualLogs.head, expectedLogEntry.get)
|
||||
}
|
||||
val actualLogs: Seq[LogCollector.Entry] = LogCollector.readAsEntries(test, logger)
|
||||
actualLogs should have size 1
|
||||
assertLogEntry(actualLogs.head, expectedLogEntry)
|
||||
}
|
||||
|
||||
}
|
||||
|
@ -21,7 +21,7 @@ class ErrorCodeSpec
|
||||
with Matchers
|
||||
with BeforeAndAfter
|
||||
with LogCollectorAssertions
|
||||
with ErrorsAssertions {
|
||||
with ErrorAssertionsWithLogCollectorAssertions {
|
||||
|
||||
implicit private val testLoggingContext: LoggingContext = LoggingContext.ForTesting
|
||||
private val logger = ContextualizedLogger.get(getClass)
|
||||
@ -85,7 +85,7 @@ class ErrorCodeSpec
|
||||
NotSoSeriousError.id,
|
||||
Map("category" -> "1") ++ contextMetadata ++ Map("definite_answer" -> "true"),
|
||||
),
|
||||
ErrorDetails.RetryInfoDetail(TransientServerFailure.retryable.get.duration.toSeconds),
|
||||
ErrorDetails.RetryInfoDetail(TransientServerFailure.retryable.get.duration),
|
||||
ErrorDetails.RequestInfoDetail(correlationId),
|
||||
ErrorDetails.ResourceInfoDetail(error.resources.head._1.asString, error.resources.head._2),
|
||||
),
|
||||
|
@ -15,6 +15,7 @@ da_scala_library(
|
||||
scala_deps = [
|
||||
"@maven//:io_spray_spray_json",
|
||||
"@maven//:org_scalaz_scalaz_core",
|
||||
"@maven//:com_typesafe_akka_akka_actor",
|
||||
],
|
||||
tags = ["maven_coordinates=com.daml:ledger-api-auth:__VERSION__"],
|
||||
visibility = [
|
||||
@ -63,17 +64,28 @@ da_scala_test_suite(
|
||||
"@maven//:org_scalatest_scalatest_shouldmatchers",
|
||||
"@maven//:org_scalatest_scalatest_wordspec",
|
||||
"@maven//:org_scalatestplus_scalacheck_1_15",
|
||||
"@maven//:com_typesafe_akka_akka_actor",
|
||||
"@maven//:com_typesafe_akka_akka_stream",
|
||||
],
|
||||
deps = [
|
||||
":ledger-api-auth",
|
||||
"//daml-lf/data",
|
||||
"//ledger-api/rs-grpc-bridge",
|
||||
"//ledger-api/testing-utils",
|
||||
"//ledger/error",
|
||||
"//ledger/error:error-test-lib",
|
||||
"//ledger/ledger-api-common",
|
||||
"//ledger/ledger-api-domain",
|
||||
"//ledger/participant-state-index",
|
||||
"//ledger/test-common",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/contextualized-logging",
|
||||
"@maven//:com_google_api_grpc_proto_google_common_protos",
|
||||
"@maven//:com_google_protobuf_protobuf_java",
|
||||
"@maven//:io_grpc_grpc_api",
|
||||
"@maven//:io_grpc_grpc_context",
|
||||
"@maven//:io_grpc_grpc_protobuf",
|
||||
"@maven//:io_grpc_grpc_stub",
|
||||
"@maven//:org_mockito_mockito_core",
|
||||
"@maven//:org_scalatest_scalatest_compatible",
|
||||
],
|
||||
|
@ -3,6 +3,10 @@
|
||||
|
||||
package com.daml.ledger.api.auth
|
||||
|
||||
import java.time.Instant
|
||||
|
||||
import akka.actor.Scheduler
|
||||
import com.daml.error.definitions.LedgerApiErrors
|
||||
import com.daml.error.{
|
||||
ContextualizedErrorLogger,
|
||||
DamlContextualizedErrorLogger,
|
||||
@ -10,16 +14,14 @@ import com.daml.error.{
|
||||
}
|
||||
import com.daml.ledger.api.auth.interceptor.AuthorizationInterceptor
|
||||
import com.daml.ledger.api.v1.transaction_filter.TransactionFilter
|
||||
import com.daml.ledger.participant.state.index.v2.UserManagementStore
|
||||
import com.daml.logging.{ContextualizedLogger, LoggingContext}
|
||||
import com.daml.platform.server.api.validation.ErrorFactories
|
||||
import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}
|
||||
import java.time.Instant
|
||||
|
||||
import com.daml.error.definitions.LedgerApiErrors
|
||||
import io.grpc.StatusRuntimeException
|
||||
import io.grpc.stub.{ServerCallStreamObserver, StreamObserver}
|
||||
import scalapb.lenses.Lens
|
||||
|
||||
import scala.concurrent.Future
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
/** A simple helper that allows services to use authorization claims
|
||||
@ -30,6 +32,10 @@ final class Authorizer(
|
||||
ledgerId: String,
|
||||
participantId: String,
|
||||
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
|
||||
userManagementStore: UserManagementStore,
|
||||
ec: ExecutionContext,
|
||||
userRightsCheckIntervalInSeconds: Int,
|
||||
akkaScheduler: Scheduler,
|
||||
)(implicit loggingContext: LoggingContext) {
|
||||
private val logger = ContextualizedLogger.get(this.getClass)
|
||||
private val errorFactories = ErrorFactories(errorCodesVersionSwitcher)
|
||||
@ -226,16 +232,17 @@ final class Authorizer(
|
||||
}
|
||||
|
||||
private def ongoingAuthorization[Res](
|
||||
scso: ServerCallStreamObserver[Res],
|
||||
observer: ServerCallStreamObserver[Res],
|
||||
claims: ClaimSet.Claims,
|
||||
) = new OngoingAuthorizationObserver[Res](
|
||||
scso,
|
||||
claims,
|
||||
_.notExpired(now()),
|
||||
authorizationError => {
|
||||
errorFactories.permissionDenied(authorizationError.reason)
|
||||
},
|
||||
)
|
||||
observer = observer,
|
||||
originalClaims = claims,
|
||||
nowF = now,
|
||||
errorFactories = errorFactories,
|
||||
userManagementStore = userManagementStore,
|
||||
userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds,
|
||||
akkaScheduler = akkaScheduler,
|
||||
)(loggingContext, ec)
|
||||
|
||||
/** Directly access the authenticated claims from the thread-local context.
|
||||
*
|
||||
@ -263,7 +270,7 @@ final class Authorizer(
|
||||
private def authorizeWithReq[Req, Res](call: (Req, ServerCallStreamObserver[Res]) => Unit)(
|
||||
authorized: (ClaimSet.Claims, Req) => Either[StatusRuntimeException, Req]
|
||||
): (Req, StreamObserver[Res]) => Unit = (request, observer) => {
|
||||
val scso = assertServerCall(observer)
|
||||
val serverCallStreamObserver = assertServerCall(observer)
|
||||
authenticatedClaimsFromContext()
|
||||
.fold(
|
||||
ex => {
|
||||
@ -278,10 +285,10 @@ final class Authorizer(
|
||||
case Right(modifiedRequest) =>
|
||||
call(
|
||||
modifiedRequest,
|
||||
if (claims.expiration.isDefined)
|
||||
ongoingAuthorization(scso, claims)
|
||||
if (claims.expiration.isDefined || claims.resolvedFromUser)
|
||||
ongoingAuthorization(serverCallStreamObserver, claims)
|
||||
else
|
||||
scso,
|
||||
serverCallStreamObserver,
|
||||
)
|
||||
case Left(ex) =>
|
||||
observer.onError(ex)
|
||||
@ -324,8 +331,21 @@ object Authorizer {
|
||||
ledgerId: String,
|
||||
participantId: String,
|
||||
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
|
||||
userManagementStore: UserManagementStore,
|
||||
ec: ExecutionContext,
|
||||
userRightsCheckIntervalInSeconds: Int,
|
||||
akkaScheduler: Scheduler,
|
||||
): Authorizer =
|
||||
LoggingContext.newLoggingContext { loggingContext =>
|
||||
new Authorizer(now, ledgerId, participantId, errorCodesVersionSwitcher)(loggingContext)
|
||||
new Authorizer(
|
||||
now = now,
|
||||
ledgerId = ledgerId,
|
||||
participantId = participantId,
|
||||
errorCodesVersionSwitcher = errorCodesVersionSwitcher,
|
||||
userManagementStore = userManagementStore,
|
||||
ec = ec,
|
||||
userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds,
|
||||
akkaScheduler = akkaScheduler,
|
||||
)(loggingContext)
|
||||
}
|
||||
}
|
||||
|
@ -3,14 +3,84 @@
|
||||
|
||||
package com.daml.ledger.api.auth
|
||||
|
||||
import java.time.Instant
|
||||
|
||||
import akka.actor.{Cancellable, Scheduler}
|
||||
import com.daml.error.DamlContextualizedErrorLogger
|
||||
import com.daml.error.definitions.LedgerApiErrors
|
||||
import com.daml.ledger.api.auth.interceptor.AuthorizationInterceptor
|
||||
import com.daml.ledger.participant.state.index.v2.UserManagementStore
|
||||
import com.daml.lf.data.Ref
|
||||
import com.daml.logging.{ContextualizedLogger, LoggingContext}
|
||||
import com.daml.platform.server.api.validation.ErrorFactories
|
||||
import io.grpc.StatusRuntimeException
|
||||
import io.grpc.stub.ServerCallStreamObserver
|
||||
|
||||
import scala.concurrent.ExecutionContext
|
||||
import scala.concurrent.duration._
|
||||
import scala.util.{Failure, Success}
|
||||
|
||||
/** @param userRightsCheckIntervalInSeconds - determines the interval at which to check whether user rights state has changed.
|
||||
* Also, double of this value serves as timeout value for subsequent user rights state checks.
|
||||
*/
|
||||
private[auth] final class OngoingAuthorizationObserver[A](
|
||||
observer: ServerCallStreamObserver[A],
|
||||
claims: ClaimSet.Claims,
|
||||
authorized: ClaimSet.Claims => Either[AuthorizationError, Unit],
|
||||
throwOnFailure: AuthorizationError => Throwable,
|
||||
) extends ServerCallStreamObserver[A] {
|
||||
originalClaims: ClaimSet.Claims,
|
||||
nowF: () => Instant,
|
||||
errorFactories: ErrorFactories,
|
||||
userManagementStore: UserManagementStore,
|
||||
userRightsCheckIntervalInSeconds: Int,
|
||||
akkaScheduler: Scheduler,
|
||||
)(implicit loggingContext: LoggingContext, ec: ExecutionContext)
|
||||
extends ServerCallStreamObserver[A] {
|
||||
|
||||
private val logger = ContextualizedLogger.get(getClass)
|
||||
private val errorLogger = new DamlContextualizedErrorLogger(logger, loggingContext, None)
|
||||
|
||||
// Guards against propagating calls to delegate observer after either
|
||||
// [[onComplete]] or [[onError]] has already been called once.
|
||||
// We need this because [[onError]] can be invoked two concurrent sources:
|
||||
// 1) scheduled user rights state change task (see [[cancellableO]]),
|
||||
// 2) upstream component that is translating upstream Akka stream into [[onNext]] and other signals.
|
||||
private var afterCompletionOrError = false
|
||||
|
||||
@volatile private var lastUserRightsCheckTime = nowF()
|
||||
|
||||
private lazy val userId = originalClaims.applicationId.fold[Ref.UserId](
|
||||
throw new RuntimeException(
|
||||
"Claims were resolved from a user but userId (applicationId) is missing in the claims."
|
||||
)
|
||||
)(Ref.UserId.assertFromString)
|
||||
|
||||
// Scheduling a task that periodically checks
|
||||
// whether user rights state has changed.
|
||||
// If user rights state has changed it aborts the stream by calling [[onError]]
|
||||
private val cancellableO: Option[Cancellable] = {
|
||||
if (originalClaims.resolvedFromUser) {
|
||||
val delay = userRightsCheckIntervalInSeconds.seconds
|
||||
// Note: https://doc.akka.io/docs/akka/2.6.13/scheduler.html states that:
|
||||
// "All scheduled task will be executed when the ActorSystem is terminated, i.e. the task may execute before its timeout."
|
||||
val c = akkaScheduler.scheduleWithFixedDelay(initialDelay = delay, delay = delay)(runnable =
|
||||
checkUserRights _
|
||||
)
|
||||
Some(c)
|
||||
} else None
|
||||
}
|
||||
|
||||
private def checkUserRights(): Unit = {
|
||||
userManagementStore
|
||||
.listUserRights(userId)
|
||||
.onComplete {
|
||||
case Failure(_) | Success(Left(_)) =>
|
||||
onError(staleStreamAuthError)
|
||||
case Success(Right(userRights)) =>
|
||||
val updatedClaims = AuthorizationInterceptor.convertUserRightsToClaims(userRights)
|
||||
if (updatedClaims.toSet != originalClaims.claims.toSet) {
|
||||
onError(staleStreamAuthError)
|
||||
}
|
||||
lastUserRightsCheckTime = nowF()
|
||||
}
|
||||
}
|
||||
|
||||
override def isCancelled: Boolean = observer.isCancelled
|
||||
|
||||
@ -26,15 +96,81 @@ private[auth] final class OngoingAuthorizationObserver[A](
|
||||
|
||||
override def request(i: Int): Unit = observer.request(i)
|
||||
|
||||
override def setMessageCompression(b: Boolean): Unit = observer.setMessageCompression(b)
|
||||
override def setMessageCompression(b: Boolean): Unit = synchronized(
|
||||
observer.setMessageCompression(b)
|
||||
)
|
||||
|
||||
override def onNext(v: A): Unit =
|
||||
authorized(claims) match {
|
||||
case Right(_) => observer.onNext(v)
|
||||
case Left(authorizationError) => observer.onError(throwOnFailure(authorizationError))
|
||||
override def onNext(v: A): Unit = synchronized {
|
||||
if (!afterCompletionOrError) {
|
||||
val now = nowF()
|
||||
(for {
|
||||
_ <- checkClaimsExpiry(now)
|
||||
_ <- checkUserRightsRefreshTimeout(now)
|
||||
} yield {
|
||||
()
|
||||
}) match {
|
||||
case Right(_) => observer.onNext(v)
|
||||
case Left(e) =>
|
||||
onError(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def onError(throwable: Throwable): Unit = observer.onError(throwable)
|
||||
override def onError(throwable: Throwable): Unit = synchronized {
|
||||
if (!afterCompletionOrError) {
|
||||
afterCompletionOrError = true
|
||||
cancelUserRightsCheckTask()
|
||||
observer.onError(throwable)
|
||||
}
|
||||
}
|
||||
|
||||
override def onCompleted(): Unit = synchronized {
|
||||
if (!afterCompletionOrError) {
|
||||
afterCompletionOrError = true
|
||||
cancelUserRightsCheckTask()
|
||||
observer.onCompleted()
|
||||
}
|
||||
}
|
||||
|
||||
private def checkUserRightsRefreshTimeout(now: Instant): Either[StatusRuntimeException, Unit] = {
|
||||
|
||||
// Safety switch to abort the stream if the user-rights-state-check task
|
||||
// fails to refresh within 2*[[userRightsCheckIntervalInSeconds]] seconds.
|
||||
// In normal conditions we expected the refresh delay to be about [[userRightsCheckIntervalInSeconds]] seconds.
|
||||
if (
|
||||
originalClaims.resolvedFromUser &&
|
||||
lastUserRightsCheckTime.isBefore(
|
||||
now.minusSeconds(2 * userRightsCheckIntervalInSeconds.toLong)
|
||||
)
|
||||
) {
|
||||
Left(staleStreamAuthError)
|
||||
} else Right(())
|
||||
}
|
||||
|
||||
private def checkClaimsExpiry(now: Instant): Either[StatusRuntimeException, Unit] = {
|
||||
originalClaims
|
||||
.notExpired(now)
|
||||
.left
|
||||
.map(authorizationError =>
|
||||
errorFactories.permissionDenied(authorizationError.reason)(errorLogger)
|
||||
)
|
||||
}
|
||||
|
||||
private def staleStreamAuthError: StatusRuntimeException = {
|
||||
// Terminate the stream, so that clients will restart their streams
|
||||
// and claims will be rechecked precisely.
|
||||
LedgerApiErrors.AuthorizationChecks.StaleUserManagementBasedStreamClaims
|
||||
.Reject()(errorLogger)
|
||||
.asGrpcError
|
||||
}
|
||||
|
||||
private def cancelUserRightsCheckTask(): Unit = {
|
||||
cancellableO.foreach { cancellable =>
|
||||
cancellable.cancel()
|
||||
if (!cancellable.isCancelled) {
|
||||
logger.debug(s"Failed to cancel stream authorization task")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def onCompleted(): Unit = observer.onCompleted()
|
||||
}
|
||||
|
@ -91,10 +91,10 @@ final class AuthorizationInterceptor(
|
||||
s"Could not resolve rights for user '$userId' due to '$msg'"
|
||||
)(errorLogger)
|
||||
)
|
||||
case Right(userClaims) =>
|
||||
case Right(userRights: Set[UserRight]) =>
|
||||
Future.successful(
|
||||
ClaimSet.Claims(
|
||||
claims = userClaims.view.map(userRightToClaim).toList.prepended(ClaimPublic),
|
||||
claims = AuthorizationInterceptor.convertUserRightsToClaims(userRights),
|
||||
ledgerId = None,
|
||||
participantId = participantId,
|
||||
applicationId = Some(userId),
|
||||
@ -133,11 +133,6 @@ final class AuthorizationInterceptor(
|
||||
Future.successful(userId)
|
||||
}
|
||||
|
||||
private[this] def userRightToClaim(r: UserRight): Claim = r match {
|
||||
case UserRight.CanActAs(p) => ClaimActAsParty(Ref.Party.assertFromString(p))
|
||||
case UserRight.CanReadAs(p) => ClaimReadAsParty(Ref.Party.assertFromString(p))
|
||||
case UserRight.ParticipantAdmin => ClaimAdmin
|
||||
}
|
||||
}
|
||||
|
||||
object AuthorizationInterceptor {
|
||||
@ -165,4 +160,14 @@ object AuthorizationInterceptor {
|
||||
LoggingContext.newLoggingContext { implicit loggingContext: LoggingContext =>
|
||||
new AuthorizationInterceptor(authService, userManagementStoreO, ec, errorCodesStatusSwitcher)
|
||||
}
|
||||
|
||||
def convertUserRightsToClaims(userRights: Set[UserRight]): Seq[Claim] = {
|
||||
userRights.view.map(userRightToClaim).toList.prepended(ClaimPublic)
|
||||
}
|
||||
|
||||
private[this] def userRightToClaim(r: UserRight): Claim = r match {
|
||||
case UserRight.CanActAs(p) => ClaimActAsParty(Ref.Party.assertFromString(p))
|
||||
case UserRight.CanReadAs(p) => ClaimReadAsParty(Ref.Party.assertFromString(p))
|
||||
case UserRight.ParticipantAdmin => ClaimAdmin
|
||||
}
|
||||
}
|
||||
|
@ -9,12 +9,20 @@ import io.grpc.{Status, StatusRuntimeException}
|
||||
import org.scalatest.Assertion
|
||||
import org.scalatest.flatspec.AsyncFlatSpec
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
|
||||
import java.time.Instant
|
||||
import scala.concurrent.Future
|
||||
|
||||
import com.daml.ledger.api.testing.utils.AkkaBeforeAndAfterAll
|
||||
import com.daml.ledger.participant.state.index.v2.UserManagementStore
|
||||
import org.mockito.MockitoSugar
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.util.{Failure, Success, Try}
|
||||
|
||||
class AuthorizerSpec extends AsyncFlatSpec with Matchers {
|
||||
class AuthorizerSpec
|
||||
extends AsyncFlatSpec
|
||||
with Matchers
|
||||
with MockitoSugar
|
||||
with AkkaBeforeAndAfterAll {
|
||||
private val className = classOf[Authorizer].getSimpleName
|
||||
private val dummyRequest = 1337L
|
||||
private val expectedSuccessfulResponse = "expectedSuccessfulResponse"
|
||||
@ -77,5 +85,9 @@ class AuthorizerSpec extends AsyncFlatSpec with Matchers {
|
||||
"some-ledger-id",
|
||||
"participant-id",
|
||||
new ErrorCodesVersionSwitcher(selfServiceErrorCodes),
|
||||
mock[UserManagementStore],
|
||||
mock[ExecutionContext],
|
||||
userRightsCheckIntervalInSeconds = 1,
|
||||
akkaScheduler = system.scheduler,
|
||||
)
|
||||
}
|
||||
|
@ -0,0 +1,92 @@
|
||||
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.ledger.api.auth
|
||||
|
||||
import java.time.{Clock, Duration, Instant, ZoneId}
|
||||
|
||||
import akka.actor.{Cancellable, Scheduler}
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.error.ErrorsAssertions
|
||||
import com.daml.error.definitions.LedgerApiErrors
|
||||
import com.daml.ledger.participant.state.index.v2.UserManagementStore
|
||||
import com.daml.logging.LoggingContext
|
||||
import com.daml.platform.server.api.validation.ErrorFactories
|
||||
import io.grpc.StatusRuntimeException
|
||||
import io.grpc.stub.ServerCallStreamObserver
|
||||
import org.mockito.{ArgumentCaptor, ArgumentMatchersSugar, MockitoSugar}
|
||||
import org.scalatest.flatspec.AsyncFlatSpec
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
|
||||
import scala.concurrent.ExecutionContext
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
class OngoingAuthorizationObserverSpec
|
||||
extends AsyncFlatSpec
|
||||
with Matchers
|
||||
with MockitoSugar
|
||||
with ArgumentMatchersSugar
|
||||
with ErrorsAssertions {
|
||||
|
||||
private val loggingContext = LoggingContext.ForTesting
|
||||
|
||||
it should "signal onError aborting the stream when user rights state hasn't been refreshed in a timely manner" in {
|
||||
val clock = AdjustableClock(
|
||||
baseClock = Clock.fixed(Instant.now(), ZoneId.systemDefault()),
|
||||
offset = Duration.ZERO,
|
||||
)
|
||||
val delegate = mock[ServerCallStreamObserver[Int]]
|
||||
val mockScheduler = mock[Scheduler]
|
||||
// Set scheduler to do nothing
|
||||
val cancellableMock = mock[Cancellable]
|
||||
when(
|
||||
mockScheduler.scheduleWithFixedDelay(any[FiniteDuration], any[FiniteDuration])(any[Runnable])(
|
||||
any[ExecutionContext]
|
||||
)
|
||||
).thenReturn(cancellableMock)
|
||||
val userRightsCheckIntervalInSeconds = 10
|
||||
val tested = new OngoingAuthorizationObserver(
|
||||
observer = delegate,
|
||||
originalClaims = ClaimSet.Claims.Empty.copy(resolvedFromUser = true),
|
||||
nowF = clock.instant,
|
||||
errorFactories = mock[ErrorFactories],
|
||||
userManagementStore = mock[UserManagementStore],
|
||||
// This is also the user rights state refresh timeout
|
||||
userRightsCheckIntervalInSeconds = userRightsCheckIntervalInSeconds,
|
||||
akkaScheduler = mockScheduler,
|
||||
)(loggingContext, executionContext)
|
||||
|
||||
// After 20 seconds pass we expect onError to be called due to lack of user rights state refresh task inactivity
|
||||
tested.onNext(1)
|
||||
clock.fastForward(Duration.ofSeconds(2.toLong * userRightsCheckIntervalInSeconds - 1))
|
||||
tested.onNext(2)
|
||||
clock.fastForward(Duration.ofSeconds(2))
|
||||
// Next onNext detects the user rights state refresh task inactivity
|
||||
tested.onNext(3)
|
||||
|
||||
val captor = ArgumentCaptor.forClass(classOf[StatusRuntimeException])
|
||||
val order = inOrder(delegate)
|
||||
order.verify(delegate, times(1)).onNext(1)
|
||||
order.verify(delegate, times(1)).onNext(2)
|
||||
order.verify(delegate, times(1)).onError(captor.capture())
|
||||
order.verifyNoMoreInteractions()
|
||||
// Scheduled task is cancelled
|
||||
verify(cancellableMock, times(1)).cancel()
|
||||
assertError(
|
||||
actual = captor.getValue,
|
||||
expectedF = LedgerApiErrors.AuthorizationChecks.StaleUserManagementBasedStreamClaims
|
||||
.Reject()(_)
|
||||
.asGrpcError,
|
||||
)
|
||||
|
||||
// onError has already been called by tested implementation so subsequent onNext, onError and onComplete
|
||||
// must not be forwarded to the delegate observer
|
||||
tested.onNext(4)
|
||||
tested.onError(new RuntimeException)
|
||||
tested.onCompleted()
|
||||
verifyNoMoreInteractions(delegate)
|
||||
|
||||
succeed
|
||||
}
|
||||
|
||||
}
|
@ -13,8 +13,8 @@ import com.daml.error.utils.ErrorDetails
|
||||
import com.daml.error.{
|
||||
ContextualizedErrorLogger,
|
||||
DamlContextualizedErrorLogger,
|
||||
ErrorAssertionsWithLogCollectorAssertions,
|
||||
ErrorCodesVersionSwitcher,
|
||||
ErrorsAssertions,
|
||||
}
|
||||
import com.daml.ledger.api.domain.LedgerId
|
||||
import com.daml.lf.data.Ref
|
||||
@ -31,6 +31,7 @@ import org.scalatest.BeforeAndAfter
|
||||
import org.scalatest.matchers.should.Matchers
|
||||
import org.scalatest.prop.TableDrivenPropertyChecks
|
||||
import org.scalatest.wordspec.AnyWordSpec
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import scala.annotation.nowarn
|
||||
import scala.jdk.CollectionConverters._
|
||||
@ -43,7 +44,7 @@ class ErrorFactoriesSpec
|
||||
with MockitoSugar
|
||||
with BeforeAndAfter
|
||||
with LogCollectorAssertions
|
||||
with ErrorsAssertions {
|
||||
with ErrorAssertionsWithLogCollectorAssertions {
|
||||
|
||||
private val logger = ContextualizedLogger.get(getClass)
|
||||
private val loggingContext = LoggingContext.ForTesting
|
||||
@ -83,7 +84,7 @@ class ErrorFactoriesSpec
|
||||
expectedMessage = msg,
|
||||
expectedDetails = Seq[ErrorDetails.ErrorDetail](
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
ErrorDetails.ErrorInfoDetail(
|
||||
"INDEX_DB_SQL_TRANSIENT_ERROR",
|
||||
Map("category" -> "1", "definite_answer" -> "false"),
|
||||
@ -161,7 +162,7 @@ class ErrorFactoriesSpec
|
||||
),
|
||||
),
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
),
|
||||
v2_logEntry = ExpectedLogEntry(
|
||||
Level.WARN,
|
||||
@ -194,7 +195,7 @@ class ErrorFactoriesSpec
|
||||
),
|
||||
),
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
),
|
||||
v2_logEntry = ExpectedLogEntry(
|
||||
Level.INFO,
|
||||
@ -223,7 +224,7 @@ class ErrorFactoriesSpec
|
||||
Map("category" -> "3", "definite_answer" -> "false"),
|
||||
),
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
),
|
||||
v2_logEntry = ExpectedLogEntry(
|
||||
Level.INFO,
|
||||
@ -399,7 +400,7 @@ class ErrorFactoriesSpec
|
||||
Map("category" -> "3", "definite_answer" -> "false"),
|
||||
),
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
),
|
||||
v2_logEntry = ExpectedLogEntry(
|
||||
Level.INFO,
|
||||
@ -742,7 +743,7 @@ class ErrorFactoriesSpec
|
||||
Map("category" -> "1", "definite_answer" -> "false", "service_name" -> serviceName),
|
||||
),
|
||||
expectedCorrelationIdRequestInfo,
|
||||
ErrorDetails.RetryInfoDetail(1),
|
||||
ErrorDetails.RetryInfoDetail(1.second),
|
||||
),
|
||||
v2_logEntry = ExpectedLogEntry(
|
||||
Level.INFO,
|
||||
|
@ -107,12 +107,19 @@ da_scala_binary(
|
||||
srcs = suites_sources(lf_version),
|
||||
scala_deps = [
|
||||
"@maven//:com_chuusai_shapeless",
|
||||
"@maven//:com_typesafe_akka_akka_actor",
|
||||
"@maven//:com_typesafe_akka_akka_stream",
|
||||
],
|
||||
scaladoc = False,
|
||||
visibility = [
|
||||
"//:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
"//ledger-api/rs-grpc-bridge",
|
||||
"//ledger-service/jwt",
|
||||
"//ledger/ledger-api-client",
|
||||
"//ledger/ledger-api-domain",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
":ledger-api-test-tool-%s-lib" % lf_version,
|
||||
"//daml-lf/data",
|
||||
"//language-support/scala/bindings",
|
||||
|
@ -187,15 +187,15 @@ object Assertions {
|
||||
)
|
||||
)
|
||||
val expectedErrorId = expectedErrorCode.id
|
||||
val expectedRetryabilitySeconds = expectedErrorCode.category.retryable.map(_.duration.toSeconds)
|
||||
val expectedRetryability = expectedErrorCode.category.retryable.map(_.duration)
|
||||
|
||||
val actualStatusCode = status.getCode
|
||||
val actualErrorDetails = ErrorDetails.from(status.getDetailsList.asScala.toSeq)
|
||||
val actualErrorId = actualErrorDetails
|
||||
.collectFirst { case err: ErrorDetails.ErrorInfoDetail => err.reason }
|
||||
.getOrElse(fail("Actual error id is not defined"))
|
||||
val actualRetryabilitySeconds = actualErrorDetails
|
||||
.collectFirst { case err: ErrorDetails.RetryInfoDetail => err.retryDelayInSeconds }
|
||||
val actualRetryability = actualErrorDetails
|
||||
.collectFirst { case err: ErrorDetails.RetryInfoDetail => err.duration }
|
||||
|
||||
if (actualErrorId != expectedErrorId)
|
||||
fail(s"Actual error id ($actualErrorId) does not match expected error id ($expectedErrorId}")
|
||||
@ -208,8 +208,8 @@ object Assertions {
|
||||
|
||||
Assertions.assertEquals(
|
||||
s"Error retryability details mismatch",
|
||||
actualRetryabilitySeconds,
|
||||
expectedRetryabilitySeconds,
|
||||
actualRetryability,
|
||||
expectedRetryability,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -27,6 +27,7 @@ import com.daml.platform.configuration.{
|
||||
SubmissionConfiguration,
|
||||
}
|
||||
import com.daml.platform.services.time.TimeProviderType
|
||||
import com.daml.platform.usermanagement.UserManagementConfig
|
||||
import com.daml.ports.{Port, PortFiles}
|
||||
import com.daml.telemetry.TelemetryContext
|
||||
import io.grpc.{BindableService, ServerInterceptor}
|
||||
@ -59,6 +60,7 @@ object StandaloneApiServer {
|
||||
checkOverloaded: TelemetryContext => Option[state.SubmissionResult] =
|
||||
_ => None, // Used for Canton rate-limiting,
|
||||
ledgerFeatures: LedgerFeatures,
|
||||
userManagementConfig: UserManagementConfig,
|
||||
)(implicit
|
||||
actorSystem: ActorSystem,
|
||||
materializer: Materializer,
|
||||
@ -86,6 +88,10 @@ object StandaloneApiServer {
|
||||
ledgerId,
|
||||
participantId,
|
||||
errorCodesVersionSwitcher,
|
||||
userManagementStore,
|
||||
servicesExecutionContext,
|
||||
userRightsCheckIntervalInSeconds = userManagementConfig.cacheExpiryAfterWriteInSeconds,
|
||||
akkaScheduler = actorSystem.scheduler,
|
||||
)
|
||||
val healthChecksWithIndexService = healthChecks + ("index" -> indexService)
|
||||
|
||||
|
@ -23,7 +23,7 @@ import scalaz.std.list._
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
|
||||
private[apiserver] final class ApiUserManagementService(
|
||||
userManagementService: UserManagementStore,
|
||||
userManagementStore: UserManagementStore,
|
||||
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
|
||||
)(implicit
|
||||
executionContext: ExecutionContext,
|
||||
@ -53,7 +53,7 @@ private[apiserver] final class ApiUserManagementService(
|
||||
pRights <- fromProtoRights(request.rights)
|
||||
} yield (User(pUserId, pOptPrimaryParty), pRights)
|
||||
} { case (user, pRights) =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.createUser(
|
||||
user = user,
|
||||
rights = pRights,
|
||||
@ -66,7 +66,7 @@ private[apiserver] final class ApiUserManagementService(
|
||||
withValidation(
|
||||
requireUserId(request.userId, "user_id")
|
||||
)(userId =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.getUser(userId)
|
||||
.flatMap(handleResult("getting user"))
|
||||
.map(toProtoUser)
|
||||
@ -76,14 +76,14 @@ private[apiserver] final class ApiUserManagementService(
|
||||
withValidation(
|
||||
requireUserId(request.userId, "user_id")
|
||||
)(userId =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.deleteUser(userId)
|
||||
.flatMap(handleResult("deleting user"))
|
||||
.map(_ => proto.DeleteUserResponse())
|
||||
)
|
||||
|
||||
override def listUsers(request: proto.ListUsersRequest): Future[proto.ListUsersResponse] =
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.listUsers()
|
||||
.flatMap(handleResult("listing users"))
|
||||
.map(
|
||||
@ -100,7 +100,7 @@ private[apiserver] final class ApiUserManagementService(
|
||||
rights <- fromProtoRights(request.rights)
|
||||
} yield (userId, rights)
|
||||
) { case (userId, rights) =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.grantRights(
|
||||
id = userId,
|
||||
rights = rights,
|
||||
@ -119,7 +119,7 @@ private[apiserver] final class ApiUserManagementService(
|
||||
rights <- fromProtoRights(request.rights)
|
||||
} yield (userId, rights)
|
||||
) { case (userId, rights) =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.revokeRights(
|
||||
id = userId,
|
||||
rights = rights,
|
||||
@ -135,7 +135,7 @@ private[apiserver] final class ApiUserManagementService(
|
||||
withValidation(
|
||||
requireUserId(request.userId, "user_id")
|
||||
)(userId =>
|
||||
userManagementService
|
||||
userManagementStore
|
||||
.listUserRights(userId)
|
||||
.flatMap(handleResult("list user rights"))
|
||||
.map(_.view.map(toProtoRight).toList)
|
||||
|
@ -654,8 +654,8 @@ object Config {
|
||||
.optional()
|
||||
.text(
|
||||
s"Defaults to ${UserManagementConfig.DefaultCacheExpiryAfterWriteInSeconds} seconds. " +
|
||||
// TODO participant user management: Update max delay to 2x the configured value when made use of in throttled stream authorization.
|
||||
"Determines the maximum delay for propagating user management state changes."
|
||||
"Used to set expiry time for user management cache. " +
|
||||
"Also determines the maximum delay for propagating user management state changes which is double its value."
|
||||
)
|
||||
.action((value, config: Config[Extra]) =>
|
||||
config.withUserManagementConfig(_.copy(cacheExpiryAfterWriteInSeconds = value))
|
||||
|
@ -258,6 +258,7 @@ final class Runner[T <: ReadWriteService, Extra](
|
||||
v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED
|
||||
),
|
||||
),
|
||||
userManagementConfig = config.userManagementConfig,
|
||||
).acquire()
|
||||
} yield Some(apiServer.port)
|
||||
case ParticipantRunMode.Indexer =>
|
||||
|
@ -204,6 +204,7 @@ test_deps = [
|
||||
"//libs-scala/postgresql-testing",
|
||||
"//libs-scala/resources",
|
||||
"//libs-scala/timer-utils",
|
||||
"//ledger/error:error-test-lib",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:ch_qos_logback_logback_core",
|
||||
"@maven//:com_typesafe_config",
|
||||
|
@ -51,16 +51,16 @@ trait ExpiringStreamServiceCallAuthTests[T]
|
||||
toHeader(expiringIn(Duration.ofSeconds(5), readOnlyToken(mainActor)))
|
||||
|
||||
it should "break a stream in flight upon read-only token expiration" in {
|
||||
val _ = Delayed.Future.by(10.seconds)(submitAndWait())
|
||||
val _ = Delayed.Future.by(10.seconds)(submitAndWaitAsMainActor())
|
||||
expectExpiration(canReadAsMainActorExpiresInFiveSeconds).map(_ => succeed)
|
||||
}
|
||||
|
||||
it should "break a stream in flight upon read/write token expiration" in {
|
||||
val _ = Delayed.Future.by(10.seconds)(submitAndWait())
|
||||
val _ = Delayed.Future.by(10.seconds)(submitAndWaitAsMainActor())
|
||||
expectExpiration(canActAsMainActorExpiresInFiveSeconds).map(_ => succeed)
|
||||
}
|
||||
|
||||
override def serviceCallWithToken(token: Option[String]): Future[Any] =
|
||||
submitAndWait().flatMap(_ => new StreamConsumer[T](stream(token)).first())
|
||||
submitAndWaitAsMainActor().flatMap(_ => new StreamConsumer[T](stream(token)).first())
|
||||
|
||||
}
|
||||
|
@ -6,20 +6,33 @@ package com.daml.platform.sandbox.services
|
||||
import java.util.UUID
|
||||
|
||||
import com.daml.ledger.api.v1.command_service.{CommandServiceGrpc, SubmitAndWaitRequest}
|
||||
import com.daml.platform.sandbox.auth.ServiceCallWithMainActorAuthTests
|
||||
import com.daml.platform.sandbox.auth.{ServiceCallAuthTests, ServiceCallWithMainActorAuthTests}
|
||||
import com.google.protobuf.empty.Empty
|
||||
|
||||
import scala.concurrent.Future
|
||||
|
||||
trait SubmitAndWaitDummyCommand extends TestCommands { self: ServiceCallWithMainActorAuthTests =>
|
||||
trait SubmitAndWaitDummyCommand extends TestCommands with SubmitAndWaitDummyCommandHelpers {
|
||||
self: ServiceCallWithMainActorAuthTests =>
|
||||
|
||||
protected def submitAndWait(): Future[Empty] =
|
||||
submitAndWait(Option(toHeader(readWriteToken(mainActor))))
|
||||
protected def submitAndWaitAsMainActor(): Future[Empty] =
|
||||
submitAndWait(
|
||||
Option(toHeader(readWriteToken(mainActor))),
|
||||
applicationId = serviceCallName,
|
||||
party = mainActor,
|
||||
)
|
||||
|
||||
protected def dummySubmitAndWaitRequest(applicationId: String): SubmitAndWaitRequest =
|
||||
}
|
||||
|
||||
trait SubmitAndWaitDummyCommandHelpers extends TestCommands {
|
||||
self: ServiceCallAuthTests =>
|
||||
|
||||
protected def dummySubmitAndWaitRequest(
|
||||
applicationId: String,
|
||||
party: String,
|
||||
): SubmitAndWaitRequest =
|
||||
SubmitAndWaitRequest(
|
||||
dummyCommands(wrappedLedgerId, s"$serviceCallName-${UUID.randomUUID}", mainActor)
|
||||
.update(_.commands.applicationId := applicationId, _.commands.party := mainActor)
|
||||
dummyCommands(wrappedLedgerId, s"$serviceCallName-${UUID.randomUUID}", party = party)
|
||||
.update(_.commands.applicationId := applicationId, _.commands.party := party)
|
||||
.commands
|
||||
)
|
||||
|
||||
@ -29,31 +42,35 @@ trait SubmitAndWaitDummyCommand extends TestCommands { self: ServiceCallWithMain
|
||||
protected def submitAndWait(
|
||||
token: Option[String],
|
||||
applicationId: String = serviceCallName,
|
||||
party: String,
|
||||
): Future[Empty] =
|
||||
service(token).submitAndWait(dummySubmitAndWaitRequest(applicationId))
|
||||
service(token).submitAndWait(dummySubmitAndWaitRequest(applicationId, party = party))
|
||||
|
||||
protected def submitAndWaitForTransaction(
|
||||
token: Option[String],
|
||||
applicationId: String = serviceCallName,
|
||||
party: String,
|
||||
): Future[Empty] =
|
||||
service(token)
|
||||
.submitAndWaitForTransaction(dummySubmitAndWaitRequest(applicationId))
|
||||
.submitAndWaitForTransaction(dummySubmitAndWaitRequest(applicationId, party = party))
|
||||
.map(_ => Empty())
|
||||
|
||||
protected def submitAndWaitForTransactionId(
|
||||
token: Option[String],
|
||||
applicationId: String = serviceCallName,
|
||||
party: String,
|
||||
): Future[Empty] =
|
||||
service(token)
|
||||
.submitAndWaitForTransactionId(dummySubmitAndWaitRequest(applicationId))
|
||||
.submitAndWaitForTransactionId(dummySubmitAndWaitRequest(applicationId, party = party))
|
||||
.map(_ => Empty())
|
||||
|
||||
protected def submitAndWaitForTransactionTree(
|
||||
token: Option[String],
|
||||
applicationId: String = serviceCallName,
|
||||
party: String,
|
||||
): Future[Empty] =
|
||||
service(token)
|
||||
.submitAndWaitForTransactionTree(dummySubmitAndWaitRequest(applicationId))
|
||||
.submitAndWaitForTransactionTree(dummySubmitAndWaitRequest(applicationId, party = party))
|
||||
.map(_ => Empty())
|
||||
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ final class CompletionStreamAuthIT
|
||||
|
||||
override protected def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] =
|
||||
// Note: the token must allow actAs mainActor for this call to work.
|
||||
submitAndWait(token, "").flatMap(_ =>
|
||||
submitAndWait(token, "", party = mainActor).flatMap(_ =>
|
||||
new StreamConsumer[CompletionStreamResponse](streamFor("")(token)).first()
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,146 @@
|
||||
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
package com.daml.platform.sandbox.auth
|
||||
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.util.{Timer, TimerTask, UUID}
|
||||
|
||||
import com.daml.error.ErrorsAssertions
|
||||
import com.daml.error.utils.ErrorDetails
|
||||
import com.daml.ledger.api.v1.admin.user_management_service.Right
|
||||
import com.daml.ledger.api.v1.transaction_filter.{Filters, TransactionFilter}
|
||||
import com.daml.ledger.api.v1.transaction_service.{
|
||||
GetTransactionsRequest,
|
||||
GetTransactionsResponse,
|
||||
TransactionServiceGrpc,
|
||||
}
|
||||
import com.daml.platform.sandbox.config.SandboxConfig
|
||||
import com.daml.platform.sandbox.services.SubmitAndWaitDummyCommandHelpers
|
||||
import io.grpc.stub.StreamObserver
|
||||
import io.grpc.{Status, StatusRuntimeException}
|
||||
import com.daml.ledger.api.v1.admin.{user_management_service => user_management_service_proto}
|
||||
import scala.concurrent.duration._
|
||||
|
||||
import scala.concurrent.{Future, Promise}
|
||||
|
||||
final class OngoingStreamAuthIT
|
||||
extends ServiceCallAuthTests
|
||||
with SubmitAndWaitDummyCommandHelpers
|
||||
with ErrorsAssertions {
|
||||
|
||||
private val UserManagementCacheExpiryInSeconds = 1
|
||||
|
||||
override protected def config: SandboxConfig = super.config.withUserManagementConfig(
|
||||
_.copy(cacheExpiryAfterWriteInSeconds = UserManagementCacheExpiryInSeconds)
|
||||
)
|
||||
|
||||
override def serviceCallName: String = ""
|
||||
|
||||
override protected def serviceCallWithToken(token: Option[String]): Future[Any] = ???
|
||||
|
||||
private val testId = UUID.randomUUID().toString
|
||||
|
||||
it should "abort an ongoing stream after user state has changed" in {
|
||||
val partyAlice = "alice-party"
|
||||
val userIdAlice = testId + "-alice"
|
||||
val receivedTransactionsCount = new AtomicInteger(0)
|
||||
val transactionStreamAbortedPromise = Promise[Throwable]()
|
||||
|
||||
def observeTransactionsStream(
|
||||
token: Option[String],
|
||||
party: String,
|
||||
): StreamObserver[GetTransactionsResponse] = {
|
||||
val observer = new StreamObserver[GetTransactionsResponse] {
|
||||
override def onNext(value: GetTransactionsResponse): Unit = {
|
||||
val _ = receivedTransactionsCount.incrementAndGet()
|
||||
}
|
||||
|
||||
override def onError(t: Throwable): Unit = {
|
||||
val _ = transactionStreamAbortedPromise.trySuccess(t)
|
||||
}
|
||||
|
||||
override def onCompleted(): Unit = ()
|
||||
}
|
||||
val request = new GetTransactionsRequest(
|
||||
begin = Option(ledgerBegin),
|
||||
end = None,
|
||||
filter = Some(
|
||||
new TransactionFilter(
|
||||
Map(party -> new Filters)
|
||||
)
|
||||
),
|
||||
)
|
||||
stub(TransactionServiceGrpc.stub(channel), token)
|
||||
.getTransactions(request, observer)
|
||||
observer
|
||||
}
|
||||
|
||||
val canActAsAlice = Right(Right.Kind.CanActAs(Right.CanActAs(partyAlice)))
|
||||
for {
|
||||
(userAlice, tokenAlice) <- createUserByAdmin(
|
||||
userId = userIdAlice,
|
||||
rights = Vector(canActAsAlice),
|
||||
)
|
||||
applicationId = userAlice.id
|
||||
submitAndWaitF = () =>
|
||||
submitAndWait(token = tokenAlice, party = partyAlice, applicationId = applicationId)
|
||||
_ <- submitAndWaitF()
|
||||
streamObserver = observeTransactionsStream(tokenAlice, partyAlice)
|
||||
_ <- submitAndWaitF()
|
||||
// Making a change to the user Alice
|
||||
_ <- grantUserRightsByAdmin(
|
||||
userId = userIdAlice,
|
||||
Right(Right.Kind.CanActAs(Right.CanActAs(UUID.randomUUID().toString))),
|
||||
)
|
||||
_ = Thread.sleep(UserManagementCacheExpiryInSeconds.toLong * 1000)
|
||||
//
|
||||
// Timer that finishes the stream in case it isn't aborted as expected
|
||||
timerTask = new TimerTask {
|
||||
override def run(): Unit = streamObserver.onError(
|
||||
new AssertionError("Timed-out waiting while waiting for stream to abort")
|
||||
)
|
||||
}
|
||||
_ = new Timer(true).schedule(timerTask, 100)
|
||||
t <- transactionStreamAbortedPromise.future
|
||||
} yield {
|
||||
timerTask.cancel()
|
||||
t match {
|
||||
case sre: StatusRuntimeException =>
|
||||
assertError(
|
||||
actual = sre,
|
||||
expectedCode = Status.Code.ABORTED,
|
||||
expectedMessage =
|
||||
"STALE_STREAM_AUTHORIZATION(2,0): Stale stream authorization. Retry quickly.",
|
||||
expectedDetails = List(
|
||||
ErrorDetails.ErrorInfoDetail(
|
||||
"STALE_STREAM_AUTHORIZATION",
|
||||
Map(
|
||||
"participantId" -> "'sandbox-participant'",
|
||||
"category" -> "2",
|
||||
"definite_answer" -> "false",
|
||||
),
|
||||
),
|
||||
ErrorDetails.RetryInfoDetail(0.seconds),
|
||||
),
|
||||
)
|
||||
case _ => fail("Unexpected error", t)
|
||||
}
|
||||
assert(receivedTransactionsCount.get() >= 2)
|
||||
}
|
||||
}
|
||||
|
||||
private def grantUserRightsByAdmin(
|
||||
userId: String,
|
||||
right: user_management_service_proto.Right,
|
||||
): Future[Unit] = {
|
||||
val req = user_management_service_proto.GrantUserRightsRequest(userId, Seq(right))
|
||||
stub(
|
||||
user_management_service_proto.UserManagementServiceGrpc.stub(channel),
|
||||
canReadAsAdminStandardJWT,
|
||||
)
|
||||
.grantUserRights(req)
|
||||
.map(_ => ())
|
||||
}
|
||||
|
||||
}
|
@ -14,8 +14,8 @@ final class SubmitAndWaitAuthIT
|
||||
override def serviceCallName: String = "CommandService#SubmitAndWait"
|
||||
|
||||
override def serviceCallWithToken(token: Option[String]): Future[Any] =
|
||||
submitAndWait(token)
|
||||
submitAndWait(token, party = mainActor)
|
||||
|
||||
override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] =
|
||||
submitAndWait(token, "")
|
||||
submitAndWait(token, "", party = mainActor)
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionAuthIT
|
||||
override def serviceCallName: String = "CommandService#SubmitAndWaitForTransaction"
|
||||
|
||||
override def serviceCallWithToken(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransaction(token)
|
||||
submitAndWaitForTransaction(token, party = mainActor)
|
||||
|
||||
override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransaction(token, "")
|
||||
submitAndWaitForTransaction(token, "", party = mainActor)
|
||||
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionIdAuthIT
|
||||
override def serviceCallName: String = "CommandService#SubmitAndWaitForTransactionId"
|
||||
|
||||
override def serviceCallWithToken(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransactionId(token)
|
||||
submitAndWaitForTransactionId(token, party = mainActor)
|
||||
|
||||
override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransactionId(token, "")
|
||||
submitAndWaitForTransactionId(token, "", party = mainActor)
|
||||
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ final class SubmitAndWaitForTransactionTreeAuthIT
|
||||
override def serviceCallName: String = "CommandService#SubmitAndWaitForTransactionTree"
|
||||
|
||||
override def serviceCallWithToken(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransactionTree(token)
|
||||
submitAndWaitForTransactionTree(token, party = mainActor)
|
||||
|
||||
override def serviceCallWithoutApplicationId(token: Option[String]): Future[Any] =
|
||||
submitAndWaitForTransactionTree(token, "")
|
||||
submitAndWaitForTransactionTree(token, "", party = mainActor)
|
||||
|
||||
}
|
||||
|
@ -406,8 +406,8 @@ class CommonCliBase(name: LedgerName) {
|
||||
.optional()
|
||||
.text(
|
||||
s"Defaults to ${UserManagementConfig.DefaultCacheExpiryAfterWriteInSeconds} seconds. " +
|
||||
// TODO participant user management: Update max delay to 2x the configured value when made use of in throttled stream authorization.
|
||||
"Determines the maximum delay for propagating user management state changes."
|
||||
"Used to set expiry time for user management cache. " +
|
||||
"Also determines the maximum delay for propagating user management state changes which is double its value."
|
||||
)
|
||||
.action((value, config: SandboxConfig) =>
|
||||
config.withUserManagementConfig(_.copy(cacheExpiryAfterWriteInSeconds = value))
|
||||
|
@ -264,6 +264,7 @@ object SandboxOnXRunner {
|
||||
v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED
|
||||
),
|
||||
),
|
||||
userManagementConfig = config.userManagementConfig,
|
||||
)
|
||||
|
||||
private def buildIndexerServer(
|
||||
|
@ -291,6 +291,7 @@ class Runner(config: SandboxConfig) extends ResourceOwner[Port] {
|
||||
v1 = ExperimentalContractIds.ContractIdV1Support.NON_SUFFIXED
|
||||
),
|
||||
),
|
||||
userManagementConfig = config.userManagementConfig,
|
||||
)
|
||||
_ = apiServerServicesClosed.completeWith(apiServer.servicesClosed())
|
||||
} yield {
|
||||
|
Loading…
Reference in New Issue
Block a user