Implement PoC of user management for Ledger API server (fix #12014) (#12063)

CHANGELOG_BEGIN
- [User Management]: add support for managing participant node users and authenticating
  requests as these users using standard JWT tokens.
CHANGELOG_END

Co-authored-by: Marton Nagy <marton.nagy@digitalasset.com>
Co-authored-by: Adriaan Moors <90182053+adriaanm-da@users.noreply.github.com>
This commit is contained in:
Simon Meier 2021-12-13 17:58:30 +01:00 committed by GitHub
parent 787dccb3d5
commit f223528bfb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
50 changed files with 1443 additions and 215 deletions

View File

@ -63,8 +63,8 @@ ScriptTest:testQueryContractId SUCCESS
ScriptTest:testQueryContractKey SUCCESS
ScriptTest:testSetTime SUCCESS
ScriptTest:testStack SUCCESS
ScriptTest:testUserManagement FAILURE (com.daml.lf.engine.script.ScriptF$FailedCmd: Command listUsers failed: UNIMPLEMENTED: Method not found: com.daml.ledger.api.v1.admin.UserManagementService/ListUsers
ScriptTest:testUserRightManagement FAILURE (com.daml.lf.engine.script.ScriptF$FailedCmd: Command createUser failed: UNIMPLEMENTED: Method not found: com.daml.ledger.api.v1.admin.UserManagementService/CreateUser
ScriptTest:testUserManagement SUCCESS
ScriptTest:testUserRightManagement SUCCESS
ScriptTest:traceOrder SUCCESS
ScriptTest:tree SUCCESS
ScriptTest:tupleKey SUCCESS

View File

@ -108,7 +108,7 @@ final class AuthSpec
}
it should "succeed if the proper token is provided" in {
setToken(toHeader(operatorPayload))
setToken(customTokenToHeader(operatorPayload))
extractor(withAuth).run().map(_ => succeed)
}
@ -140,7 +140,7 @@ final class AuthSpec
Future.successful(Option(lastOffset.get()))
}
)
setToken(toHeader(expiringIn(Duration.ofSeconds(5), operatorPayload)))
setToken(customTokenToHeader(expiringIn(Duration.ofSeconds(5), operatorPayload)))
val _ = process.run()
val expectedTxs = ListBuffer.empty[String]
Delayed.Future
@ -148,7 +148,7 @@ final class AuthSpec
newSyncClient
.submitAndWaitForTransactionId(
SubmitAndWaitRequest(commands = dummyRequest.commands),
Option(toHeader(operatorPayload)),
Option(customTokenToHeader(operatorPayload)),
)
.map(_.transactionId)
}
@ -156,13 +156,13 @@ final class AuthSpec
case Success(tx) => val _ = expectedTxs += tx
case Failure(_) => () // do nothing, the test will fail
}
Delayed.by(15.seconds)(setToken(toHeader(operatorPayload)))
Delayed.by(15.seconds)(setToken(customTokenToHeader(operatorPayload)))
Delayed.Future
.by(20.seconds) {
newSyncClient
.submitAndWaitForTransactionId(
SubmitAndWaitRequest(commands = dummyRequest.commands),
Option(toHeader(operatorPayload)),
Option(customTokenToHeader(operatorPayload)),
)
.map(_.transactionId)
}

View File

@ -76,6 +76,7 @@ da_scala_library(
"//ledger/error",
"//ledger/ledger-api-auth",
"//ledger/ledger-api-common",
"//ledger/participant-state-index",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:io_grpc_grpc_core",
"@maven//:io_grpc_grpc_netty",

View File

@ -4,10 +4,10 @@
package com.daml.ledger.rxjava.grpc.helpers
import com.daml.error.ErrorCodesVersionSwitcher
import java.net.{InetSocketAddress, SocketAddress}
import java.time.{Clock, Duration}
import java.util.concurrent.TimeUnit
import com.daml.ledger.rxjava.grpc._
import com.daml.ledger.rxjava.grpc.helpers.TransactionsServiceImpl.LedgerItem
import com.daml.ledger.rxjava.{CommandCompletionClient, LedgerConfigurationClient, PackageClient}
@ -31,6 +31,7 @@ import com.daml.ledger.api.v1.package_service.{
ListPackagesResponse,
}
import com.daml.ledger.api.v1.testing.time_service.GetTimeResponse
import com.daml.ledger.participant.state.index.impl.inmemory.InMemoryUserManagementStore
import com.google.protobuf.empty.Empty
import io.grpc._
import io.grpc.netty.NettyServerBuilder
@ -93,6 +94,7 @@ final class LedgerServices(val ledgerId: String) {
): Server = {
val authorizationInterceptor = AuthorizationInterceptor(
authService,
new InMemoryUserManagementStore(),
executionContext,
new ErrorCodesVersionSwitcher(enableSelfServiceErrorCodes = true),
)

View File

@ -47,22 +47,24 @@ package object rxjava {
private[rxjava] val mockedAuthService =
AuthServiceStatic {
case `emptyToken` => ClaimSet.Unauthenticated
case `publicToken` => ClaimSet.Claims(Seq[Claim](ClaimPublic))
case `adminToken` => ClaimSet.Claims(Seq[Claim](ClaimAdmin))
case `publicToken` => ClaimSet.Claims.Empty.copy(claims = Seq[Claim](ClaimPublic))
case `adminToken` => ClaimSet.Claims.Empty.copy(claims = Seq[Claim](ClaimAdmin))
case `somePartyReadToken` =>
ClaimSet.Claims(
Seq[Claim](ClaimPublic, ClaimReadAsParty(Ref.Party.assertFromString(someParty)))
ClaimSet.Claims.Empty.copy(
claims = Seq[Claim](ClaimPublic, ClaimReadAsParty(Ref.Party.assertFromString(someParty)))
)
case `somePartyReadWriteToken` =>
ClaimSet.Claims(
Seq[Claim](ClaimPublic, ClaimActAsParty(Ref.Party.assertFromString(someParty)))
ClaimSet.Claims.Empty.copy(
claims = Seq[Claim](ClaimPublic, ClaimActAsParty(Ref.Party.assertFromString(someParty)))
)
case `someOtherPartyReadToken` =>
ClaimSet.Claims(
ClaimSet.Claims.Empty.copy(
claims =
Seq[Claim](ClaimPublic, ClaimReadAsParty(Ref.Party.assertFromString(someOtherParty)))
)
case `someOtherPartyReadWriteToken` =>
ClaimSet.Claims(
ClaimSet.Claims.Empty.copy(
claims =
Seq[Claim](ClaimPublic, ClaimActAsParty(Ref.Party.assertFromString(someOtherParty)))
)
}

View File

@ -9,8 +9,6 @@ option java_outer_classname = "UserManagementServiceOuterClass";
option java_package = "com.daml.ledger.api.v1.admin";
option csharp_namespace = "Com.Daml.Ledger.Api.V1.Admin";
import "google/protobuf/empty.proto";
// Experimental API to manage users and their rights for interacting with the Ledger API
// served by a participant node.
@ -135,7 +133,7 @@ message DeleteUserResponse {
// Required authorization: ``HasRight(ParticipantAdmin)``
message ListUsersRequest {
// TODO: add pagination, cf. https://cloud.google.com/apis/design/design_patterns#list_pagination
// TODO (i12052): add pagination following https://cloud.google.com/apis/design/design_patterns#list_pagination
}
message ListUsersResponse {
@ -176,7 +174,7 @@ message ListUserRightsRequest {
// If set to empty string (the default), then the rights for the authenticated user will be listed.
string user_id = 1;
// TODO: add pagination, cf. https://cloud.google.com/apis/design/design_patterns#list_pagination
// TODO (i12052): add pagination following https://cloud.google.com/apis/design/design_patterns#list_pagination
}
message ListUserRightsResponse {

View File

@ -46,7 +46,7 @@ final class AuthorizationTest
private val publicToken = "public"
private val emptyToken = "empty"
private val mockedAuthService = Option(AuthServiceStatic {
case `publicToken` => ClaimSet.Claims(Seq[Claim](ClaimPublic))
case `publicToken` => ClaimSet.Claims.Empty.copy(claims = Seq[Claim](ClaimPublic))
case `emptyToken` => ClaimSet.Unauthenticated
})

View File

@ -40,4 +40,7 @@ object ErrorResource {
object Party extends ErrorResource {
def asString: String = "PARTY"
}
object User extends ErrorResource {
def asString: String = "USER"
}
}

View File

@ -15,6 +15,7 @@ object ErrorGroups {
abstract class LedgerApiErrorGroup extends ErrorGroup() {
abstract class CommandExecutionErrorGroup extends ErrorGroup()
abstract class PackageServiceErrorGroup extends ErrorGroup()
abstract class UserManagementServiceErrorGroup extends ErrorGroup()
}
}
}

View File

@ -677,6 +677,53 @@ object LedgerApiErrors extends LedgerApiErrorGroup {
cause = _message
)
}
@Explanation("""The user referred to by the request was not found, which may be due to:
|
|1. Connecting to the wrong participant node, as users are a participant local concept.
|2. The user-id being misspelled.
|3. The user not yet having been created.
|4. The user having been deleted.
|""")
@Resolution(
"""Check that you are connecting to the right participant node and the user-id is spelled correctly,
|if yes, create the user.
|"""
)
object UserNotFound
extends ErrorCode(
id = "USER_NOT_FOUND",
ErrorCategory.InvalidGivenCurrentSystemStateResourceMissing,
) {
case class Reject(_operation: String, userId: String)(implicit
loggingContext: ContextualizedErrorLogger
) extends LoggingTransactionErrorImpl(
cause = s"cannot ${_operation} for unknown user \"${userId}\"."
// TODO (i12053): also output participantId
) {
override def resources: Seq[(ErrorResource, String)] = Seq(
ErrorResource.User -> userId
)
}
}
@Explanation("There already exists another user with the same user-id.")
@Resolution("Choose a different user-id or use the user that already exists.")
object UserAlreadyExists
extends ErrorCode(
id = "USER_ALREADY_EXISTS",
ErrorCategory.InvalidGivenCurrentSystemStateResourceExists,
) {
case class Reject(_operation: String, userId: String)(implicit
loggingContext: ContextualizedErrorLogger
) extends LoggingTransactionErrorImpl(
cause = s"cannot ${_operation}, as user \"${userId}\" already exists."
// TODO (i12053): also output participantId
) {
override def resources: Seq[(ErrorResource, String)] = Seq(
ErrorResource.User -> userId
)
}
}
}
@Explanation(

View File

@ -29,6 +29,8 @@ da_scala_library(
"//ledger-service/jwt",
"//ledger/error",
"//ledger/ledger-api-common",
"//ledger/ledger-api-domain",
"//ledger/participant-state-index",
"//libs-scala/contextualized-logging",
"@maven//:com_auth0_java_jwt",
"@maven//:io_grpc_grpc_api",
@ -67,6 +69,7 @@ da_scala_test_suite(
deps = [
":ledger-api-auth",
"//ledger/error",
"//ledger/participant-state-index",
"//ledger/test-common",
"@maven//:com_google_api_grpc_proto_google_common_protos",
"@maven//:com_google_protobuf_protobuf_java",

View File

@ -42,14 +42,14 @@ class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
token => payloadToClaims(token),
)
private[this] def parsePayload(jwtPayload: String): Either[Error, AuthServiceJWTPayload] = {
import AuthServiceJWTCodec.JsonImplicits._
Try(JsonParser(jwtPayload).convertTo[AuthServiceJWTPayload]).toEither.left.map(t =>
private[this] def parsePayload(jwtPayload: String): Either[Error, SupportedJWTPayload] = {
import SupportedJWTCodec.JsonImplicits._
Try(JsonParser(jwtPayload).convertTo[SupportedJWTPayload]).toEither.left.map(t =>
Error("Could not parse JWT token: " + t.getMessage)
)
}
private[this] def parseJWTPayload(header: String): Either[Error, AuthServiceJWTPayload] = {
private[this] def parseJWTPayload(header: String): Either[Error, SupportedJWTPayload] = {
val BearerTokenRegex = "Bearer (.*)".r
for {
@ -66,7 +66,8 @@ class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
} yield parsed
}
private[this] def payloadToClaims(payload: AuthServiceJWTPayload): ClaimSet.Claims = {
private[this] def payloadToClaims(payload: SupportedJWTPayload): ClaimSet = payload match {
case CustomDamlJWTPayload(payload) =>
val claims = ListBuffer[Claim]()
// Any valid token authorizes the user to use public services
@ -87,6 +88,14 @@ class AuthServiceJWT(verifier: JwtVerifierBase) extends AuthService {
participantId = payload.participantId,
applicationId = payload.applicationId,
expiration = payload.exp,
resolvedFromUser = false,
)
case StandardJWTPayload(payload) =>
ClaimSet.AuthenticatedUser(
participantId = payload.participantId,
userId = payload.applicationId.get,
expiration = payload.exp,
)
}
}

View File

@ -111,6 +111,13 @@ object AuthServiceJWTCodec {
propExp -> writeOptionalInstant(v.exp),
)
def writeStandardTokenPayload(v: AuthServiceJWTPayload): JsValue =
JsObject(
"aud" -> writeOptionalString(v.participantId),
"sub" -> writeOptionalString(v.applicationId),
"exp" -> writeOptionalInstant(v.exp),
)
/** Writes the given payload to a compact JSON string */
def compactPrint(v: AuthServiceJWTPayload): String = writePayload(v).compactPrint
@ -127,10 +134,9 @@ object AuthServiceJWTCodec {
// Decoding
// ------------------------------------------------------------------------------------------------------------------
def readFromString(value: String): Try[AuthServiceJWTPayload] = {
import AuthServiceJWTCodec.JsonImplicits._
for {
json <- Try(value.parseJson)
parsed <- Try(json.convertTo[AuthServiceJWTPayload])
parsed <- Try(readPayload(json))
} yield parsed
}
@ -176,6 +182,25 @@ object AuthServiceJWTCodec {
)
}
def readStandardTokenPayload(jsValue: JsValue): Option[AuthServiceJWTPayload] = jsValue match {
// NOTE: there is the corner-case of a legacy Daml token containing a "sub" field.
// We accept that risk.
case JsObject(fields) if !fields.contains(oidcNamespace) && fields.contains("sub") =>
Some(
AuthServiceJWTPayload(
ledgerId = None,
// TODO (i12054): allow for an array of audiences
participantId = readOptionalString("aud", fields),
applicationId = readOptionalString("sub", fields),
exp = readInstant("exp", fields),
admin = false,
actAs = List.empty,
readAs = List.empty,
)
)
case _ => None
}
private[this] def readOptionalString(name: String, fields: Map[String, JsValue]): Option[String] =
fields.get(name) match {
case None => None

View File

@ -192,14 +192,28 @@ final class Authorizer(
},
)
private def authenticatedClaimsFromContext(): Try[ClaimSet.Claims] =
/** Directly access the authenticated claims from the thread-local context.
*
* Prefer to use the more specialized methods of [[Authorizer]] instead of this
* method to avoid skipping required authorization checks.
*/
def authenticatedClaimsFromContext(): Try[ClaimSet.Claims] =
AuthorizationInterceptor
.extractClaimSetFromContext()
.fold[Try[ClaimSet.Claims]](Failure(errorFactories.unauthenticatedMissingJwtToken())) {
.flatMap({
case ClaimSet.Unauthenticated =>
Failure(errorFactories.unauthenticatedMissingJwtToken())
case authenticatedUser: ClaimSet.AuthenticatedUser =>
Failure(
errorFactories.internalAuthenticationError(
s"Unexpected unresolved authenticated user claim",
new RuntimeException(
s"Unexpected unresolved authenticated user claim for user '${authenticatedUser.userId}"
),
)
)
case claims: ClaimSet.Claims => Success(claims)
}
})
private def authorize[Req, Res](call: (Req, ServerCallStreamObserver[Res]) => Unit)(
authorized: ClaimSet.Claims => Either[AuthorizationError, Unit]
@ -235,24 +249,17 @@ final class Authorizer(
private[auth] def authorize[Req, Res](call: Req => Future[Res])(
authorized: ClaimSet.Claims => Either[AuthorizationError, Unit]
): Req => Future[Res] = request =>
authenticatedClaimsFromContext()
.fold(
ex => {
// TODO error codes: Remove once fully relying on self-service error codes with logging on creation
logger.debug(
s"No authenticated claims found in the request context. Returning UNAUTHENTICATED"
)
Future.failed(ex)
},
claims =>
authenticatedClaimsFromContext() match {
case Failure(ex) => Future.failed(ex)
case Success(claims) =>
authorized(claims) match {
case Right(_) => call(request)
case Left(authorizationError) =>
Future.failed(
errorFactories.permissionDenied(authorizationError.reason)
)
},
)
}
}
}
object Authorizer {

View File

@ -86,13 +86,15 @@ object ClaimSet {
* @param participantId If set, the claims will only be valid on the given participant identifier.
* @param applicationId If set, the claims will only be valid on the given application identifier.
* @param expiration If set, the claims will cease to be valid at the given time.
* @param resolvedFromUser If set, then the claims were resolved from a user in the user management service.
*/
final case class Claims(
claims: Seq[Claim],
ledgerId: Option[String] = None,
participantId: Option[String] = None,
applicationId: Option[String] = None,
expiration: Option[Instant] = None,
ledgerId: Option[String],
participantId: Option[String],
applicationId: Option[String],
expiration: Option[Instant],
resolvedFromUser: Boolean,
) extends ClaimSet {
def validForLedger(id: String): Either[AuthorizationError, Unit] =
Either.cond(ledgerId.forall(_ == id), (), AuthorizationError.InvalidLedger(ledgerId.get, id))
@ -155,6 +157,13 @@ object ClaimSet {
}
}
/** The representation of a user that was authenticated, but whose [[Claims]] have not yet been resolved. */
final case class AuthenticatedUser(
userId: String, // TODO (i12049): use Ref.UserId here
participantId: Option[String],
expiration: Option[Instant],
) extends ClaimSet
object Claims {
/** A set of [[Claims]] that does not have any authorization */
@ -164,6 +173,7 @@ object ClaimSet {
participantId = None,
applicationId = None,
expiration = None,
resolvedFromUser = false,
)
/** A set of [[Claims]] that has all possible authorizations */

View File

@ -0,0 +1,47 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.api.auth
import com.daml.ledger.api.auth.AuthServiceJWTCodec.{
readPayload,
readStandardTokenPayload,
writePayload,
writeStandardTokenPayload,
}
import spray.json.{DefaultJsonProtocol, JsValue, RootJsonFormat}
/** A wrapper class to add support for standard JWT tokens alongside the existing custom tokens in a minimally invasive way.
*
* The main problem is that the [[AuthServiceJWTPayload]] class is used by other applications
* like the JSON-API to parse tokens; and we don't want to meddle with that code as part of the PoC.
*/
// TODO (i12049): clarify naming and inline case classes where possible.
sealed trait SupportedJWTPayload
final case class CustomDamlJWTPayload(payload: AuthServiceJWTPayload) extends SupportedJWTPayload
final case class StandardJWTPayload(payload: AuthServiceJWTPayload) extends SupportedJWTPayload
/** JSON codecs for all the formats of JWTs supported by the [[AuthServiceJWT]]. */
object SupportedJWTCodec {
/** Writes the given payload to a compact JSON string */
def compactPrint(v: SupportedJWTPayload): String =
JsonImplicits.SupportedJWTPayloadFormat.write(v).compactPrint
// ------------------------------------------------------------------------------------------------------------------
// Implicits that can be imported to write JSON
// ------------------------------------------------------------------------------------------------------------------
object JsonImplicits extends DefaultJsonProtocol {
implicit object SupportedJWTPayloadFormat extends RootJsonFormat[SupportedJWTPayload] {
override def write(v: SupportedJWTPayload): JsValue = v match {
case CustomDamlJWTPayload(payload) => writePayload(payload)
case StandardJWTPayload(payload) => writeStandardTokenPayload(payload)
}
override def read(json: JsValue): SupportedJWTPayload =
readStandardTokenPayload(json) match {
case Some(payload) => StandardJWTPayload(payload)
case None => CustomDamlJWTPayload(readPayload(json))
}
}
}
}

View File

@ -4,21 +4,25 @@
package com.daml.ledger.api.auth.interceptor
import com.daml.error.{DamlContextualizedErrorLogger, ErrorCodesVersionSwitcher}
import com.daml.ledger.api.auth.{AuthService, ClaimSet}
import com.daml.ledger.api.auth._
import com.daml.ledger.api.domain.UserRight
import com.daml.ledger.participant.state.index.v2.UserManagementStore
import com.daml.lf.data.Ref
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.platform.server.api.validation.ErrorFactories
import io.grpc._
import scala.compat.java8.FutureConverters
import scala.concurrent.ExecutionContext
import scala.util.{Failure, Success}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success, Try}
/** This interceptor uses the given [[AuthService]] to get [[Claims]] for the current request,
* and then stores them in the current [[Context]].
*/
final class AuthorizationInterceptor(
protected val authService: AuthService,
ec: ExecutionContext,
authService: AuthService,
userManagementService: UserManagementStore,
implicit val ec: ExecutionContext,
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
)(implicit loggingContext: LoggingContext)
extends ServerInterceptor {
@ -42,6 +46,7 @@ final class AuthorizationInterceptor(
new AsyncForwardingListener[ReqT] {
FutureConverters
.toScala(authService.decodeMetadata(headers))
.flatMap(resolveAuthenticatedUserRights)
.onComplete {
case Failure(exception) =>
val error = errorFactories.internalAuthenticationError(
@ -58,24 +63,64 @@ final class AuthorizationInterceptor(
Contexts.interceptCall(nextCtx, call, headers, nextListener)
setNextListener(nextListenerWithContext)
nextListenerWithContext
}(ec)
}
}
}
private[this] def resolveAuthenticatedUserRights(claimSet: ClaimSet): Future[ClaimSet] =
claimSet match {
case ClaimSet.AuthenticatedUser(userId, participantId, expiration) =>
userManagementService
.listUserRights(Ref.UserId.assertFromString(userId))
.map {
case Left(msg) =>
logger.warn(
s"Authorization error: cannot resolve rights for user '$userId' due to $msg."
)
ClaimSet.Unauthenticated
case Right(userClaims) =>
ClaimSet.Claims(
claims = userClaims.view.map(userRightToClaim).toList.prepended(ClaimPublic),
ledgerId = None,
participantId = participantId,
applicationId = Some(userId),
expiration = expiration,
resolvedFromUser = true,
)
}
case _ => Future.successful(claimSet)
}
private[this] def userRightToClaim(r: UserRight): Claim = r match {
case UserRight.CanActAs(p) => ClaimActAsParty(Ref.Party.assertFromString(p))
case UserRight.CanReadAs(p) => ClaimReadAsParty(Ref.Party.assertFromString(p))
case UserRight.ParticipantAdmin => ClaimAdmin
}
}
object AuthorizationInterceptor {
private[auth] val contextKeyClaimSet = Context.key[ClaimSet]("AuthServiceDecodedClaim")
def extractClaimSetFromContext(): Option[ClaimSet] =
Option(contextKeyClaimSet.get())
def extractClaimSetFromContext(): Try[ClaimSet] = {
val claimSet = contextKeyClaimSet.get()
if (claimSet == null)
Failure(
new RuntimeException(
"Thread local context unexpectedly does not store authorization claims. Perhaps a Future was used in some intermediate computation and changed the executing thread?"
)
)
else
Success(claimSet)
}
def apply(
authService: AuthService,
userManagementService: UserManagementStore,
ec: ExecutionContext,
errorCodesStatusSwitcher: ErrorCodesVersionSwitcher,
): AuthorizationInterceptor =
LoggingContext.newLoggingContext { implicit loggingContext: LoggingContext =>
new AuthorizationInterceptor(authService, ec, errorCodesStatusSwitcher)
new AuthorizationInterceptor(authService, userManagementService, ec, errorCodesStatusSwitcher)
}
}

View File

@ -0,0 +1,102 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.api.auth.services
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.error.definitions.LedgerApiErrors
import com.daml.ledger.api.auth._
import com.daml.ledger.api.v1.admin.user_management_service._
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.ProxyCloseable
import io.grpc.ServerServiceDefinition
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success, Try}
private[daml] final class UserManagementServiceAuthorization(
protected val service: UserManagementServiceGrpc.UserManagementService with AutoCloseable,
private val authorizer: Authorizer,
)(implicit executionContext: ExecutionContext, loggingContext: LoggingContext)
extends UserManagementServiceGrpc.UserManagementService
with ProxyCloseable
with GrpcApiService {
private val logger = ContextualizedLogger.get(this.getClass)
private implicit val errorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
override def createUser(request: CreateUserRequest): Future[User] =
authorizer.requireAdminClaims(service.createUser)(request)
override def getUser(request: GetUserRequest): Future[User] =
defaultToAuthenticatedUser(request.userId) match {
case Failure(ex) => Future.failed(ex)
case Success(Some(userId)) => service.getUser(request.copy(userId = userId))
case Success(None) => authorizer.requireAdminClaims(service.getUser)(request)
}
override def deleteUser(request: DeleteUserRequest): Future[DeleteUserResponse] =
authorizer.requireAdminClaims(service.deleteUser)(request)
override def listUsers(request: ListUsersRequest): Future[ListUsersResponse] =
authorizer.requireAdminClaims(service.listUsers)(request)
override def grantUserRights(request: GrantUserRightsRequest): Future[GrantUserRightsResponse] =
authorizer.requireAdminClaims(service.grantUserRights)(request)
override def revokeUserRights(
request: RevokeUserRightsRequest
): Future[RevokeUserRightsResponse] =
authorizer.requireAdminClaims(service.revokeUserRights)(request)
override def listUserRights(request: ListUserRightsRequest): Future[ListUserRightsResponse] =
defaultToAuthenticatedUser(request.userId) match {
case Failure(ex) => Future.failed(ex)
case Success(Some(userId)) => service.listUserRights(request.copy(userId = userId))
case Success(None) => authorizer.requireAdminClaims(service.listUserRights)(request)
}
override def bindService(): ServerServiceDefinition =
UserManagementServiceGrpc.bindService(this, executionContext)
override def close(): Unit = service.close()
private def defaultToAuthenticatedUser(userId: String): Try[Option[String]] = {
// Note: doing all of this computation within a `Try` instead of a `Future` is very important
// as the authorization claims are stored in thread-local storage; and we thus must avoid switching
// the executing thread.
if (userId.isEmpty) {
authorizer
.authenticatedClaimsFromContext()
.flatMap(claims =>
if (claims.resolvedFromUser)
claims.applicationId match {
case Some(applicationId) => Success(Some(applicationId))
case None =>
Failure(
LedgerApiErrors.AuthorizationChecks.InternalAuthorizationError
.Reject(
"unexpectedly the user-id is not set in authenticated claims",
new RuntimeException(),
)
.asGrpcError
)
}
else {
// This case can be hit both when running without authentication and when using custom Daml tokens.
Failure(
LedgerApiErrors.RequestValidation.InvalidArgument
.Reject(
"requests with an empty user-id are only supported if there is an authenticated user"
)
.asGrpcError
)
}
)
} else
Success(None)
}
}

View File

@ -18,23 +18,34 @@ class AuthServiceJWTCodecSpec
with ScalaCheckDrivenPropertyChecks {
/** Serializes a [[AuthServiceJWTPayload]] to JSON, then parses it back to a AuthServiceJWTPayload */
private def serializeAndParse(value: AuthServiceJWTPayload): Try[AuthServiceJWTPayload] = {
import AuthServiceJWTCodec.JsonImplicits._
private def serializeAndParse(value: SupportedJWTPayload): Try[SupportedJWTPayload] = {
import SupportedJWTCodec.JsonImplicits._
for {
serialized <- Try(value.toJson.prettyPrint)
json <- Try(serialized.parseJson)
parsed <- Try(json.convertTo[AuthServiceJWTPayload])
parsed <- Try(json.convertTo[SupportedJWTPayload])
} yield parsed
}
/** Parses a [[AuthServiceJWTPayload]] */
private def parse(serialized: String): Try[AuthServiceJWTPayload] = {
import AuthServiceJWTCodec.JsonImplicits._
/** Parses a string as a [[CustomDamlJWTPayload]] */
private def parseCustomToken(serialized: String): Try[CustomDamlJWTPayload] = {
import SupportedJWTCodec.JsonImplicits._
for {
json <- Try(serialized.parseJson)
parsed <- Try(json.convertTo[AuthServiceJWTPayload])
// FIXME: understand how to avoid unwrapping and re-wrapping the `parsed` here
CustomDamlJWTPayload(parsed) <- Try(json.convertTo[SupportedJWTPayload])
} yield CustomDamlJWTPayload(parsed)
}
/** Parses a [[SupportedJWTPayload]] */
private def parse(serialized: String): Try[SupportedJWTPayload] = {
import SupportedJWTCodec.JsonImplicits._
for {
json <- Try(serialized.parseJson)
parsed <- Try(json.convertTo[SupportedJWTPayload])
} yield parsed
}
@ -52,10 +63,31 @@ class AuthServiceJWTCodecSpec
"serializing and parsing a value" should {
"work for arbitrary values" in forAll(
"work for arbitrary custom Daml token values" in forAll(
Gen.resultOf(AuthServiceJWTPayload),
minSuccessful(100),
)(value => serializeAndParse(value) shouldBe Success(value))
)(v0 => {
val value = CustomDamlJWTPayload(v0)
serializeAndParse(value) shouldBe Success(value)
})
"work for arbitrary standard Daml token values" in forAll(
Gen.resultOf(AuthServiceJWTPayload),
minSuccessful(100),
)(v0 => {
val value = StandardJWTPayload(
AuthServiceJWTPayload(
ledgerId = None,
participantId = v0.participantId,
applicationId = Some(v0.applicationId.getOrElse("default-user")),
exp = v0.exp,
admin = false,
actAs = List.empty,
readAs = List.empty,
)
)
serializeAndParse(value) shouldBe Success(value)
})
"support OIDC compliant sandbox format" in {
val serialized =
@ -71,7 +103,8 @@ class AuthServiceJWTCodecSpec
| "exp": 0
|}
""".stripMargin
val expected = AuthServiceJWTPayload(
val expected = CustomDamlJWTPayload(
AuthServiceJWTPayload(
ledgerId = Some("someLedgerId"),
participantId = Some("someParticipantId"),
applicationId = Some("someApplicationId"),
@ -80,9 +113,10 @@ class AuthServiceJWTCodecSpec
actAs = List("Alice"),
readAs = List("Alice", "Bob"),
)
val result = parse(serialized)
)
val result = parseCustomToken(serialized)
result shouldBe Success(expected)
result.map(_.party) shouldBe Success(None)
result.map(_.payload.party) shouldBe Success(None)
}
"support legacy sandbox format" in {
@ -97,7 +131,8 @@ class AuthServiceJWTCodecSpec
| "readAs": ["Alice", "Bob"]
|}
""".stripMargin
val expected = AuthServiceJWTPayload(
val expected = CustomDamlJWTPayload(
AuthServiceJWTPayload(
ledgerId = Some("someLedgerId"),
participantId = Some("someParticipantId"),
applicationId = Some("someApplicationId"),
@ -106,9 +141,10 @@ class AuthServiceJWTCodecSpec
actAs = List("Alice"),
readAs = List("Alice", "Bob"),
)
val result = parse(serialized)
)
val result = parseCustomToken(serialized)
result shouldBe Success(expected)
result.map(_.party) shouldBe Success(None)
result.map(_.payload.party) shouldBe Success(None)
}
"support legacy JSON API format" in {
@ -119,7 +155,8 @@ class AuthServiceJWTCodecSpec
| "party": "Alice"
|}
""".stripMargin
val expected = AuthServiceJWTPayload(
val expected = CustomDamlJWTPayload(
AuthServiceJWTPayload(
ledgerId = Some("someLedgerId"),
participantId = None,
applicationId = Some("someApplicationId"),
@ -128,14 +165,39 @@ class AuthServiceJWTCodecSpec
actAs = List("Alice"),
readAs = List.empty,
)
)
val result = parseCustomToken(serialized)
result shouldBe Success(expected)
result.map(_.payload.party) shouldBe Success(Some("Alice"))
}
"support standard JWT claims" in {
val serialized =
"""{
| "aud": "someParticipantId",
| "sub": "someUserId",
| "exp": 100
|}
""".stripMargin
val expected = StandardJWTPayload(
AuthServiceJWTPayload(
ledgerId = None,
participantId = Some("someParticipantId"),
applicationId = Some("someUserId"),
exp = Some(Instant.ofEpochSecond(100)),
admin = false,
actAs = List.empty,
readAs = List.empty,
)
)
val result = parse(serialized)
result shouldBe Success(expected)
result.map(_.party) shouldBe Success(Some("Alice"))
}
"have stable default values" in {
val serialized = "{}"
val expected = AuthServiceJWTPayload(
val expected = CustomDamlJWTPayload(
AuthServiceJWTPayload(
ledgerId = None,
participantId = None,
applicationId = None,
@ -144,9 +206,10 @@ class AuthServiceJWTCodecSpec
actAs = List.empty,
readAs = List.empty,
)
val result = parse(serialized)
)
val result = parseCustomToken(serialized)
result shouldBe Success(expected)
result.map(_.party) shouldBe Success(None)
result.map(_.payload.party) shouldBe Success(None)
}
}

View File

@ -12,8 +12,10 @@ import org.mockito.{ArgumentMatchersSugar, MockitoSugar}
import org.scalatest.Assertion
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import java.util.concurrent.CompletableFuture
import com.daml.ledger.participant.state.index.v2.UserManagementStore
import scala.concurrent.ExecutionContext.global
import scala.concurrent.Promise
import scala.util.Success
@ -49,6 +51,7 @@ class AuthorizationInterceptorSpec
usesSelfServiceErrorCodes: Boolean
)(assertRpcStatus: (Status, Metadata) => Assertion) = {
val authService = mock[AuthService]
val userManagementService = mock[UserManagementStore]
val serverCall = mock[ServerCall[Nothing, Nothing]]
val failedMetadataDecode = CompletableFuture.supplyAsync[ClaimSet](() =>
throw new RuntimeException("some internal failure")
@ -63,7 +66,7 @@ class AuthorizationInterceptorSpec
val errorCodesStatusSwitcher = new ErrorCodesVersionSwitcher(usesSelfServiceErrorCodes)
val authorizationInterceptor =
AuthorizationInterceptor(authService, global, errorCodesStatusSwitcher)
AuthorizationInterceptor(authService, userManagementService, global, errorCodesStatusSwitcher)
val statusCaptor = ArgCaptor[Status]
val metadataCaptor = ArgCaptor[Metadata]

View File

@ -35,20 +35,12 @@ class AuthorizerSpec extends AsyncFlatSpec with Matchers {
behavior of s"$className.authorize (V1 error codes)"
it should "return unauthenticated if missing claims" in {
testUnauthenticated(selfServiceErrorCodes = false)
}
it should "return permission denied on authorization error" in {
testPermissionDenied(selfServiceErrorCodes = false)
}
behavior of s"$className.authorize (V2 error codes)"
it should "return unauthenticated if missing claims" in {
testUnauthenticated(selfServiceErrorCodes = true)
}
it should "return permission denied on authorization error" in {
testPermissionDenied(selfServiceErrorCodes = true)
}
@ -63,16 +55,6 @@ class AuthorizerSpec extends AsyncFlatSpec with Matchers {
)
)
private def testUnauthenticated(selfServiceErrorCodes: Boolean) =
contextWithoutClaims {
authorizer(selfServiceErrorCodes).authorize(dummyReqRes)(allAuthorized)(dummyRequest)
}
.transform(
assertExpectedFailure(selfServiceErrorCodes = selfServiceErrorCodes)(
Status.UNAUTHENTICATED.getCode
)
)
private def assertExpectedFailure[T](
selfServiceErrorCodes: Boolean
)(expectedStatusCode: Status.Code): Try[T] => Try[Assertion] = {
@ -85,8 +67,6 @@ class AuthorizerSpec extends AsyncFlatSpec with Matchers {
case ex => fail(s"Expected a failure with StatusRuntimeException but got $ex")
}
private def contextWithoutClaims[R](f: => R): R = io.grpc.Context.ROOT.call(() => f)
private def contextWithClaims[R](f: => R): R =
io.grpc.Context.ROOT
.withValue(AuthorizationInterceptor.contextKeyClaimSet, ClaimSet.Claims.Wildcard)

View File

@ -40,7 +40,7 @@ final class LedgerClientAuthIT
)
private val ClientConfiguration = ClientConfigurationWithoutToken.copy(
token = Some(toHeader(readOnlyToken("Read-only party")))
token = Some(customTokenToHeader(readOnlyToken("Read-only party")))
)
override protected def config: SandboxConfig = super.config.copy(
@ -78,7 +78,7 @@ final class LedgerClientAuthIT
.allocateParty(
hint = Some(partyName),
displayName = Some(partyName),
token = Some(toHeader(adminToken)),
token = Some(customTokenToHeader(adminToken)),
)
} yield {
allocatedParty.displayName should be(Some(partyName))

View File

@ -71,6 +71,17 @@ class FieldValidations private (errorFactories: ErrorFactories) {
} yield parties + party
}
def requireUserId(
s: String,
fieldName: String,
)(implicit
contextualizedErrorLogger: ContextualizedErrorLogger
): Either[StatusRuntimeException, Ref.UserId] =
Ref.UserId.fromString(s) match {
case Right(userId) => Right(userId)
case Left(msg) => Left(invalidField(fieldName, msg, definiteAnswer = Some(false)))
}
def requireLedgerString(
s: String,
fieldName: String,
@ -92,12 +103,10 @@ class FieldValidations private (errorFactories: ErrorFactories) {
def validateSubmissionId(s: String)(implicit
contextualizedErrorLogger: ContextualizedErrorLogger
): Either[StatusRuntimeException, Option[domain.SubmissionId]] =
if (s.isEmpty) {
Right(None)
} else {
optionalString(s) { nonEmptyString =>
Ref.SubmissionId
.fromString(s)
.map(submissionId => Some(domain.SubmissionId(submissionId)))
.fromString(nonEmptyString)
.map(domain.SubmissionId(_))
.left
.map(invalidField("submission_id", _, definiteAnswer = Some(false)))
}
@ -144,6 +153,11 @@ class FieldValidations private (errorFactories: ErrorFactories) {
en <- requireDottedName(identifier.entityName, "entity_name")
} yield Ref.Identifier(packageId, Ref.QualifiedName(mn, en))
def optionalString[T](s: String)(
someValidation: String => Either[StatusRuntimeException, T]
): Either[StatusRuntimeException, Option[T]] =
if (s.isEmpty) Right(None)
else someValidation(s).map(Option(_))
}
object FieldValidations {

View File

@ -13,6 +13,8 @@ import com.daml.ledger.api.v1.admin.participant_pruning_service.ParticipantPruni
import com.daml.ledger.api.v1.admin.participant_pruning_service.ParticipantPruningServiceGrpc.ParticipantPruningService
import com.daml.ledger.api.v1.admin.party_management_service.PartyManagementServiceGrpc
import com.daml.ledger.api.v1.admin.party_management_service.PartyManagementServiceGrpc.PartyManagementService
import com.daml.ledger.api.v1.admin.user_management_service.UserManagementServiceGrpc
import com.daml.ledger.api.v1.admin.user_management_service.UserManagementServiceGrpc.UserManagementService
import com.daml.ledger.api.v1.command_completion_service.CommandCompletionServiceGrpc
import com.daml.ledger.api.v1.command_completion_service.CommandCompletionServiceGrpc.CommandCompletionService
import com.daml.ledger.api.v1.command_service.CommandServiceGrpc
@ -84,4 +86,7 @@ private[infrastructure] final class LedgerServices(
val version: VersionService =
VersionServiceGrpc.stub(participant)
val userManagement: UserManagementService =
UserManagementServiceGrpc.stub(participant)
}

View File

@ -39,6 +39,7 @@ import com.daml.ledger.api.v1.admin.party_management_service.{
ListKnownPartiesRequest,
PartyDetails,
}
import com.daml.ledger.api.v1.admin.user_management_service.UserManagementServiceGrpc.UserManagementService
import com.daml.ledger.api.v1.command_completion_service.{
Checkpoint,
CompletionEndRequest,
@ -791,6 +792,9 @@ private[testtool] final class ParticipantTestContext private[participant] (
_ <- waitForParties(participantsUnderTest, parties.toSet)
} yield parties
def userManagement: UserManagementService =
services.userManagement // TODO (i12059) perhaps remove and create granular accessors
private def reservePartyNames(n: Int): Future[Vector[Party]] =
Future.successful(Vector.fill(n)(Party(nextPartyHintId())))
}

View File

@ -0,0 +1,173 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.api.testtool.suites
import java.util.UUID
import com.daml.error.ErrorCode
import com.daml.error.definitions.LedgerApiErrors
import com.daml.ledger.api.testtool.infrastructure.Allocation._
import com.daml.ledger.api.testtool.infrastructure.Assertions._
import com.daml.ledger.api.testtool.infrastructure.LedgerTestSuite
import com.daml.ledger.api.v1.admin.user_management_service.{
CreateUserRequest,
DeleteUserRequest,
GetUserRequest,
GrantUserRightsRequest,
ListUserRightsRequest,
RevokeUserRightsRequest,
User,
Right => Permission,
}
import com.daml.ledger.api.v1.admin.{user_management_service => proto}
import io.grpc.Status
import scala.concurrent.Future
final class UserManagementServiceIT extends LedgerTestSuite {
// TODO (i12059): complete testing
// create user
// - invalid user-name
// - invalid rights
// get user
// delete user
// list users
// grant user rights
// revoke user rights
// list user rights
test(
"UserManagementCreateUserInvalidArguments",
"Test argument validation for UserManagement#CreateUser",
allocate(NoParties),
)(implicit ec => { case Participants(Participant(ledger)) =>
val userId = UUID.randomUUID.toString
def createAndCheck(
problem: String,
user: User,
rights: Seq[proto.Right],
errorCode: ErrorCode,
): Future[Unit] =
for {
error <- ledger.userManagement
.createUser(CreateUserRequest(Some(user), rights))
.mustFail(problem)
} yield assertGrpcError(ledger, error, Status.Code.INVALID_ARGUMENT, errorCode, None)
for {
_ <- createAndCheck(
"empty user-id",
User(""),
List.empty,
LedgerApiErrors.RequestValidation.InvalidField,
)
_ <- createAndCheck(
"invalid user-id",
User("!!"),
List.empty,
LedgerApiErrors.RequestValidation.InvalidField,
)
_ <- createAndCheck(
"invalid primary-party",
User("u1-" + userId, "party2-!!"),
List.empty,
LedgerApiErrors.RequestValidation.InvalidArgument,
)
r = proto.Right(proto.Right.Kind.CanActAs(proto.Right.CanActAs("party3-!!")))
_ <- createAndCheck(
"invalid party in right",
User("u2-" + userId),
List(r),
LedgerApiErrors.RequestValidation.InvalidArgument,
)
} yield ()
})
test(
"UserManagementGetUserInvalidArguments",
"Test argument validation for UserManagement#GetUser",
allocate(NoParties),
)(implicit ec => { case Participants(Participant(ledger)) =>
def getAndCheck(problem: String, userId: String, errorCode: ErrorCode): Future[Unit] =
for {
error <- ledger.userManagement
.getUser(GetUserRequest(userId))
.mustFail(problem)
} yield assertGrpcError(ledger, error, Status.Code.INVALID_ARGUMENT, errorCode, None)
for {
_ <- getAndCheck("empty user-id", "", LedgerApiErrors.RequestValidation.InvalidArgument)
_ <- getAndCheck("invalid user-id", "!!", LedgerApiErrors.RequestValidation.InvalidField)
} yield ()
})
test(
"TestAllUserManagementRpcs",
"Exercise every rpc once with success and once with a failure",
allocate(NoParties),
)(implicit ec => { case Participants(Participant(ledger)) =>
for {
// TODO: actually exercise all RPCs
createResult <- ledger.userManagement.createUser(CreateUserRequest(Some(User("a", "b")), Nil))
createAgainError <- ledger.userManagement
.createUser(CreateUserRequest(Some(User("a", "b")), Nil))
.mustFail("allocating a duplicate user")
getUserResult <- ledger.userManagement.getUser(GetUserRequest("a"))
getUserError <- ledger.userManagement
.getUser(GetUserRequest("b"))
.mustFail("retrieving non-existent user")
grantResult <- ledger.userManagement.grantUserRights(
GrantUserRightsRequest(
"a",
List(Permission(Permission.Kind.ParticipantAdmin(Permission.ParticipantAdmin()))),
)
)
listRightsResult <- ledger.userManagement.listUserRights(ListUserRightsRequest("a"))
revokeResult <- ledger.userManagement.revokeUserRights(
RevokeUserRightsRequest(
"a",
List(Permission(Permission.Kind.ParticipantAdmin(Permission.ParticipantAdmin()))),
)
)
_ <- ledger.userManagement.deleteUser(DeleteUserRequest("a"))
} yield {
assertGrpcError(
ledger,
createAgainError,
Status.Code.ALREADY_EXISTS,
LedgerApiErrors.AdminServices.UserAlreadyExists,
None,
)
assertGrpcError(
ledger,
getUserError,
Status.Code.NOT_FOUND,
LedgerApiErrors.AdminServices.UserNotFound,
None,
)
assert(createResult == User("a", "b"))
assert(getUserResult == User("a", "b"))
assert(
grantResult.newlyGrantedRights == List(
Permission(Permission.Kind.ParticipantAdmin(Permission.ParticipantAdmin()))
)
)
assert(
revokeResult.newlyRevokedRights == List(
Permission(Permission.Kind.ParticipantAdmin(Permission.ParticipantAdmin()))
)
)
assert(
listRightsResult.rights.toSet == Set(
Permission(Permission.Kind.ParticipantAdmin(Permission.ParticipantAdmin()))
// Permission(Permission.Kind.CanActAs(Permission.CanActAs("acting-party"))),
// Permission(Permission.Kind.CanReadAs(Permission.CanReadAs("reader-party"))),
)
)
}
})
}

View File

@ -78,6 +78,7 @@ object Tests {
new MonotonicRecordTimeIT,
new TLSOnePointThreeIT,
new TLSAtLeastOnePointTwoIT,
new UserManagementServiceIT, // TODO (i12076): make this a default test once it reads a feature descriptor
// TODO sandbox-classic removal: Remove
new DeprecatedSandboxClassicMemoryContractKeysIT,
new DeprecatedSandboxClassicMemoryExceptionsIT,

View File

@ -36,6 +36,7 @@ import com.daml.platform.apiserver.services.admin.{
ApiPackageManagementService,
ApiParticipantPruningService,
ApiPartyManagementService,
ApiUserManagementService,
}
import com.daml.platform.apiserver.services.transaction.ApiTransactionService
import com.daml.platform.configuration.{
@ -76,6 +77,7 @@ private[daml] object ApiServices {
participantId: Ref.ParticipantId,
optWriteService: Option[state.WriteService],
indexService: IndexService,
userManagementService: UserManagementStore,
authorizer: Authorizer,
engine: Engine,
timeProvider: TimeProvider,
@ -201,6 +203,9 @@ private[daml] object ApiServices {
val apiHealthService = new GrpcHealthService(healthChecks, errorsVersionsSwitcher)
val apiUserManagementService =
new ApiUserManagementService(userManagementService, errorsVersionsSwitcher)
apiTimeServiceOpt.toList :::
writeServiceBackedApiServices :::
List(
@ -213,6 +218,7 @@ private[daml] object ApiServices {
apiReflectionService,
apiHealthService,
apiVersionService,
new UserManagementServiceAuthorization(apiUserManagementService, authorizer),
)
}

View File

@ -3,6 +3,8 @@
package com.daml.platform.apiserver
import java.time.Clock
import akka.actor.ActorSystem
import akka.stream.Materializer
import com.daml.api.util.TimeProvider
@ -12,6 +14,8 @@ import com.daml.ledger.api.auth.interceptor.AuthorizationInterceptor
import com.daml.ledger.api.auth.{AuthService, Authorizer}
import com.daml.ledger.api.health.HealthChecks
import com.daml.ledger.configuration.LedgerId
import com.daml.ledger.participant.state.index.impl.inmemory.InMemoryUserManagementStore
import com.daml.ledger.participant.state.index.v2.IndexService
import com.daml.ledger.participant.state.{v2 => state}
import com.daml.ledger.resources.ResourceOwner
import com.daml.lf.data.Ref
@ -25,12 +29,9 @@ import com.daml.platform.configuration.{
}
import com.daml.platform.services.time.TimeProviderType
import com.daml.ports.{Port, PortFiles}
import com.daml.telemetry.TelemetryContext
import io.grpc.{BindableService, ServerInterceptor}
import scalaz.{-\/, \/-}
import java.time.Clock
import com.daml.ledger.participant.state.index.v2.IndexService
import com.daml.telemetry.TelemetryContext
import scala.collection.immutable
import scala.concurrent.ExecutionContextExecutor
@ -87,6 +88,8 @@ object StandaloneApiServer {
)
val healthChecksWithIndexService = healthChecks + ("index" -> indexService)
val userManagementService = new InMemoryUserManagementStore
for {
executionSequencerFactory <- new ExecutionSequencerFactoryOwner()
apiServicesOwner = new ApiServices.Owner(
@ -113,6 +116,7 @@ object StandaloneApiServer {
managementServiceTimeout = config.managementServiceTimeout,
enableSelfServiceErrorCodes = config.enableSelfServiceErrorCodes,
checkOverloaded = checkOverloaded,
userManagementService = userManagementService,
)(materializer, executionSequencerFactory, loggingContext)
.map(_.withServices(otherServices))
apiServer <- new LedgerApiServer(
@ -123,6 +127,7 @@ object StandaloneApiServer {
config.tlsConfig,
AuthorizationInterceptor(
authService,
userManagementService,
servicesExecutionContext,
errorCodesVersionSwitcher,
) :: otherInterceptors,

View File

@ -0,0 +1,208 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.apiserver.services.admin
import com.daml.error.definitions.LedgerApiErrors
import com.daml.error.{
ContextualizedErrorLogger,
DamlContextualizedErrorLogger,
ErrorCodesVersionSwitcher,
}
import com.daml.ledger.api.domain._
import com.daml.ledger.api.v1.admin.{user_management_service => proto}
import com.daml.ledger.participant.state.index.v2.UserManagementStore
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.validation.{ErrorFactories, FieldValidations}
import io.grpc.{ServerServiceDefinition, StatusRuntimeException}
import scalaz.std.either._
import scalaz.syntax.traverse._
import scalaz.std.list._
import scala.concurrent.{ExecutionContext, Future}
private[apiserver] final class ApiUserManagementService(
userManagementService: UserManagementStore,
errorCodesVersionSwitcher: ErrorCodesVersionSwitcher,
)(implicit
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends proto.UserManagementServiceGrpc.UserManagementService
with GrpcApiService {
import ApiUserManagementService._
private implicit val logger: ContextualizedLogger = ContextualizedLogger.get(this.getClass)
private val errorFactories = ErrorFactories(errorCodesVersionSwitcher)
private implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
private val fieldValidations = FieldValidations(errorFactories)
import fieldValidations._
override def close(): Unit = ()
override def bindService(): ServerServiceDefinition =
proto.UserManagementServiceGrpc.bindService(this, executionContext)
override def createUser(request: proto.CreateUserRequest): Future[proto.User] =
withValidation {
for {
pUser <- requirePresence(request.user, "user")
pUserId <- requireUserId(pUser.id, "id")
pOptPrimaryParty <- optionalString(pUser.primaryParty)(requireParty)
pRights <- fromProtoRights(request.rights)
} yield (User(pUserId, pOptPrimaryParty), pRights)
} { case (user, pRights) =>
userManagementService
.createUser(
user = user,
rights = pRights,
)
.flatMap(handleResult("create user"))
.map(_ => request.user.get)
}
override def getUser(request: proto.GetUserRequest): Future[proto.User] =
withValidation(
requireUserId(request.userId, "user_id")
)(userId =>
userManagementService
.getUser(userId)
.flatMap(handleResult("get user"))
.map(toProtoUser)
)
override def deleteUser(request: proto.DeleteUserRequest): Future[proto.DeleteUserResponse] =
withValidation(
requireUserId(request.userId, "user_id")
)(userId =>
userManagementService
.deleteUser(userId)
.flatMap(handleResult("delete user"))
.map(_ => proto.DeleteUserResponse())
)
override def listUsers(request: proto.ListUsersRequest): Future[proto.ListUsersResponse] =
userManagementService
.listUsers()
.flatMap(handleResult("list users"))
.map(
_.map(toProtoUser)
)
.map(proto.ListUsersResponse(_))
override def grantUserRights(
request: proto.GrantUserRightsRequest
): Future[proto.GrantUserRightsResponse] =
withValidation(
for {
userId <- requireUserId(request.userId, "user_id")
rights <- fromProtoRights(request.rights)
} yield (userId, rights)
) { case (userId, rights) =>
userManagementService
.grantRights(
id = userId,
rights = rights,
)
.flatMap(handleResult("grant user rights"))
.map(_.view.map(toProtoRight).toList)
.map(proto.GrantUserRightsResponse(_))
}
override def revokeUserRights(
request: proto.RevokeUserRightsRequest
): Future[proto.RevokeUserRightsResponse] =
withValidation(
for {
userId <- fieldValidations.requireUserId(request.userId, "user_id")
rights <- fromProtoRights(request.rights)
} yield (userId, rights)
) { case (userId, rights) =>
userManagementService
.revokeRights(
id = userId,
rights = rights,
)
.flatMap(handleResult("revoke user rights"))
.map(_.view.map(toProtoRight).toList)
.map(proto.RevokeUserRightsResponse(_))
}
override def listUserRights(
request: proto.ListUserRightsRequest
): Future[proto.ListUserRightsResponse] =
withValidation(
requireUserId(request.userId, "user_id")
)(userId =>
userManagementService
.listUserRights(userId)
.flatMap(handleResult("list user rights"))
.map(_.view.map(toProtoRight).toList)
.map(proto.ListUserRightsResponse(_))
)
private def handleResult[T](operation: String)(result: UserManagementStore.Result[T]): Future[T] =
result match {
case Left(UserManagementStore.UserNotFound(id)) =>
Future.failed(
LedgerApiErrors.AdminServices.UserNotFound.Reject(operation, id.toString).asGrpcError
)
case Left(UserManagementStore.UserExists(id)) =>
Future.failed(
LedgerApiErrors.AdminServices.UserAlreadyExists.Reject(operation, id.toString).asGrpcError
)
case scala.util.Right(t) =>
Future.successful(t)
}
private def withValidation[A, B](validatedResult: Either[StatusRuntimeException, A])(
f: A => Future[B]
): Future[B] =
validatedResult.fold(Future.failed, Future.successful).flatMap(f)
private val fromProtoRight: proto.Right => Either[StatusRuntimeException, UserRight] = {
case proto.Right(_: proto.Right.Kind.ParticipantAdmin) =>
Right(UserRight.ParticipantAdmin)
case proto.Right(proto.Right.Kind.CanActAs(r)) =>
requireParty(r.party).map(UserRight.CanActAs(_))
case proto.Right(proto.Right.Kind.CanReadAs(r)) =>
requireParty(r.party).map(UserRight.CanReadAs(_))
case proto.Right(proto.Right.Kind.Empty) =>
Left(
LedgerApiErrors.RequestValidation.InvalidArgument
.Reject(
"unknown kind of right - check that the Ledger API version of the server is recent enough"
)
.asGrpcError
)
}
private def fromProtoRights(
rights: Seq[proto.Right]
): Either[StatusRuntimeException, Set[UserRight]] =
rights.toList.traverse(fromProtoRight).map(_.toSet)
}
object ApiUserManagementService {
private def toProtoUser(user: User): proto.User =
proto.User(
id = user.id.toString,
primaryParty = user.primaryParty.getOrElse(""),
)
private val toProtoRight: UserRight => proto.Right = {
case UserRight.ParticipantAdmin =>
proto.Right(proto.Right.Kind.ParticipantAdmin(proto.Right.ParticipantAdmin()))
case UserRight.CanActAs(party) =>
proto.Right(proto.Right.Kind.CanActAs(proto.Right.CanActAs(party)))
case UserRight.CanReadAs(party) =>
proto.Right(proto.Right.Kind.CanReadAs(proto.Right.CanReadAs(party)))
}
}

View File

@ -0,0 +1,48 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.index
// TODO (i12057): complete unit testing and move to the right place sitting side-by-side with the implementation
//import org.scalatest.BeforeAndAfterAll
//import org.scalatest.concurrent.Eventually
import org.scalatest.matchers.should.Matchers
//import org.scalatest.time.{Millis, Span}
import org.scalatest.wordspec.AnyWordSpec
//import scala.concurrent.duration._
final class InMemoryUserManagementStoreSpec
extends AnyWordSpec
with Matchers
// with Eventually
// with BeforeAndAfterAll
{
// override implicit val patienceConfig: PatienceConfig =
// PatienceConfig(timeout = scaled(Span(2000, Millis)), interval = scaled(Span(50, Millis)))
// tests for
// deleteUser
// getUser
// createUser
"in-memory user management should" should {
"allow creating a fresh user" in {}
"disallow re-creating an existing user" in {}
"find a freshly created user" in {}
"not find a non-existent user" in {}
}
// tests for:
// listUserRights
// revokeRights
// grantRights
"in-memory user rights management should" should {}
// override def afterAll(): Unit = {
// }
}

View File

@ -8,7 +8,7 @@ load(
da_scala_library(
name = "participant-state-index",
srcs = glob(["src/main/scala/com/daml/ledger/participant/state/index/v2/*.scala"]),
srcs = glob(["src/main/scala/com/daml/ledger/participant/state/index/**/*.scala"]),
resources = glob(["src/main/resources/**/*"]),
scala_deps = [
"@maven//:com_typesafe_akka_akka_actor",

View File

@ -0,0 +1,112 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.participant.state.index.impl.inmemory
import com.daml.ledger.api.domain.{User, UserRight}
import com.daml.ledger.participant.state.index.v2.UserManagementStore
import com.daml.ledger.participant.state.index.v2.UserManagementStore._
import com.daml.lf.data.Ref
import scala.collection.mutable
import scala.concurrent.Future
class InMemoryUserManagementStore extends UserManagementStore {
import InMemoryUserManagementStore._
// Underlying mutable map to keep track of UserInfo state.
// Structured so we can use a ConcurrentHashMap (to more closely mimic a real implementation, where performance is key).
// We synchronize on a private object (the mutable map), not the service (which could cause deadlocks).
// (No need to mark state as volatile -- rely on synchronized to establish the JMM's happens-before relation.)
private val state: mutable.Map[Ref.UserId, UserInfo] = mutable.Map(AdminUser.toStateEntry)
override def createUser(user: User, rights: Set[UserRight]): Future[Result[Unit]] =
withoutUser(user.id) {
state.update(user.id, UserInfo(user, rights))
}
override def getUser(id: Ref.UserId): Future[Result[User]] =
withUser(id)(_.user)
override def deleteUser(id: Ref.UserId): Future[Result[Unit]] =
withUser(id) { _ =>
state.remove(id)
()
}
override def grantRights(
id: Ref.UserId,
granted: Set[UserRight],
): Future[Result[Set[UserRight]]] =
withUser(id) { userInfo =>
val newlyGranted = granted.diff(userInfo.rights) // faster than filter
// we're not doing concurrent updates -- assert as backstop and a reminder to handle the collision case in the future
assert(
replaceInfo(userInfo, userInfo.copy(rights = userInfo.rights ++ newlyGranted))
)
newlyGranted
}
override def revokeRights(
id: Ref.UserId,
revoked: Set[UserRight],
): Future[Result[Set[UserRight]]] =
withUser(id) { userInfo =>
val effectivelyRevoked = revoked.intersect(userInfo.rights) // faster than filter
// we're not doing concurrent updates -- assert as backstop and a reminder to handle the collision case in the future
assert(
replaceInfo(userInfo, userInfo.copy(rights = userInfo.rights -- effectivelyRevoked))
)
effectivelyRevoked
}
override def listUserRights(id: Ref.UserId): Future[Result[Set[UserRight]]] =
withUser(id)(_.rights)
def listUsers(): Future[Result[Users]] =
withState {
Right(state.values.map(_.user).toSeq)
}
private def withState[T](t: => T): Future[T] =
synchronized(
Future.successful(t)
)
private def withUser[T](id: Ref.UserId)(f: UserInfo => T): Future[Result[T]] =
withState(
state.get(id) match {
case Some(user) => Right(f(user))
case None => Left(UserNotFound(id))
}
)
private def withoutUser[T](id: Ref.UserId)(t: => T): Future[Result[T]] =
withState(
state.get(id) match {
case Some(_) => Left(UserExists(id))
case None => Right(t)
}
)
private def replaceInfo(oldInfo: UserInfo, newInfo: UserInfo) = state.synchronized {
assert(
oldInfo.user.id == newInfo.user.id,
s"Replace info from if ${oldInfo.user.id} to ${newInfo.user.id} -> ${newInfo.rights}",
)
state.get(oldInfo.user.id) match {
case Some(`oldInfo`) => state.update(newInfo.user.id, newInfo); true
case _ => false
}
}
}
object InMemoryUserManagementStore {
case class UserInfo(user: User, rights: Set[UserRight]) {
def toStateEntry: (Ref.UserId, UserInfo) = user.id -> this
}
private val AdminUser = UserInfo(
user = User(Ref.UserId.assertFromString("participant_admin"), None),
rights = Set(UserRight.ParticipantAdmin),
)
}

View File

@ -0,0 +1,36 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.ledger.participant.state.index.v2
import com.daml.ledger.api.domain.{User, UserRight}
import com.daml.lf.data.Ref
import scala.concurrent.Future
trait UserManagementStore {
import UserManagementStore._
def createUser(user: User, rights: Set[UserRight]): Future[Result[Unit]]
def getUser(id: Ref.UserId): Future[Result[User]]
def deleteUser(id: Ref.UserId): Future[Result[Unit]]
def grantRights(id: Ref.UserId, rights: Set[UserRight]): Future[Result[Set[UserRight]]]
def revokeRights(id: Ref.UserId, rights: Set[UserRight]): Future[Result[Set[UserRight]]]
def listUserRights(id: Ref.UserId): Future[Result[Set[UserRight]]]
def listUsers(): Future[Result[Users]]
}
object UserManagementStore {
type Result[T] = Either[Error, T]
type Users = Seq[User]
sealed trait Error
final case class UserNotFound(userId: Ref.UserId) extends Error
final case class UserExists(userId: Ref.UserId) extends Error
}

View File

@ -45,11 +45,13 @@ import com.daml.platform.services.time.TimeProviderType
import com.daml.platform.store.{FlywayMigrations, LfValueTranslationCache}
import com.daml.ports.Port
import scalaz.syntax.tag._
import java.io.File
import java.nio.file.Files
import java.time.Instant
import java.util.concurrent.Executors
import com.daml.ledger.participant.state.index.impl.inmemory.InMemoryUserManagementStore
import scala.concurrent.duration.DurationInt
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.jdk.CollectionConverters._
@ -305,6 +307,8 @@ final class SandboxServer(
case None => "in-memory"
}
val userManagementService = new InMemoryUserManagementStore
for {
servicesExecutionContext <- ResourceOwner
.forExecutorService(() => Executors.newWorkStealingPool())
@ -401,6 +405,7 @@ final class SandboxServer(
managementServiceTimeout = config.managementServiceTimeout,
enableSelfServiceErrorCodes = config.enableSelfServiceErrorCodes,
checkOverloaded = _ => None,
userManagementService = userManagementService,
)(materializer, executionSequencerFactory, loggingContext)
.map(_.withServices(List(resetService)))
apiServer <- new LedgerApiServer(
@ -413,7 +418,8 @@ final class SandboxServer(
List(
AuthorizationInterceptor(
authService,
executionContext,
userManagementService,
servicesExecutionContext,
errorCodesVersionSwitcher,
),
resetService,

View File

@ -7,7 +7,7 @@ import java.util.UUID
trait AdminServiceCallAuthTests extends SecuredServiceCallAuthTests {
private val signedIncorrectly = Option(toHeader(adminToken, UUID.randomUUID.toString))
private val signedIncorrectly = Option(customTokenToHeader(adminToken, UUID.randomUUID.toString))
it should "deny calls with an invalid signature" in {
expectUnauthenticated(serviceCallWithToken(signedIncorrectly))
@ -27,7 +27,12 @@ trait AdminServiceCallAuthTests extends SecuredServiceCallAuthTests {
it should "allow calls with admin token without expiration" in {
expectSuccess(serviceCallWithToken(canReadAsAdmin))
}
it should "allow calls with standard token for 'participant_admin' without expiration" in {
expectSuccess(serviceCallWithToken(canReadAsAdminStandardJWT))
}
it should "deny calls with standard token for 'unknown_user' without expiration" in {
expectUnauthenticated(serviceCallWithToken(canReadAsUnknownUserStandardJWT))
}
it should "allow calls with the correct ledger ID" in {
expectSuccess(serviceCallWithToken(canReadAsAdminActualLedgerId))
}

View File

@ -45,10 +45,10 @@ trait ExpiringStreamServiceCallAuthTests[T]
}
private def canActAsMainActorExpiresInFiveSeconds =
toHeader(expiringIn(Duration.ofSeconds(5), readWriteToken(mainActor)))
customTokenToHeader(expiringIn(Duration.ofSeconds(5), readWriteToken(mainActor)))
private def canReadAsMainActorExpiresInFiveSeconds =
toHeader(expiringIn(Duration.ofSeconds(5), readOnlyToken(mainActor)))
customTokenToHeader(expiringIn(Duration.ofSeconds(5), readOnlyToken(mainActor)))
it should "break a stream in flight upon read-only token expiration" in {
val _ = Delayed.Future.by(10.seconds)(submitAndWait())

View File

@ -28,7 +28,6 @@ trait MultiPartyServiceCallAuthTests extends SecuredServiceCallAuthTests {
private[this] val submitters: RequestSubmitters = RequestSubmitters("", actAs, readAs)
private[this] val singleParty = UUID.randomUUID.toString
private[this] val randomParty = UUID.randomUUID.toString
private[this] val randomActAs: List[String] = List.fill(actorsCount)(UUID.randomUUID.toString)
private[this] val randomReadAs: List[String] = List.fill(readersCount)(UUID.randomUUID.toString)
@ -42,7 +41,9 @@ trait MultiPartyServiceCallAuthTests extends SecuredServiceCallAuthTests {
tokenParties: TokenParties,
requestSubmitters: RequestSubmitters,
): Future[Any] = {
val token = Option(toHeader(multiPartyToken(tokenParties.actAs, tokenParties.readAs)))
val token = Option(
customTokenToHeader(multiPartyToken(tokenParties.actAs, tokenParties.readAs))
)
serviceCallWithToken(token, requestSubmitters)
}

View File

@ -15,6 +15,13 @@ trait PublicServiceCallAuthTests extends SecuredServiceCallAuthTests {
expectSuccess(serviceCallWithToken(canReadAsRandomParty))
}
it should "allow calls with non-expired 'participant_admin' standard token" in {
expectSuccess(serviceCallWithToken(canReadAsAdminStandardJWT))
}
it should "deny calls with non-expired 'unknown_user' standard token" in {
expectUnauthenticated(serviceCallWithToken(canReadAsUnknownUserStandardJWT))
}
it should "deny calls with an expired read/write token" in {
expectUnauthenticated(serviceCallWithToken(canActAsRandomPartyExpired))
}

View File

@ -44,6 +44,9 @@ trait ServiceCallAuthTests
protected def expectUnauthenticated(f: Future[Any]): Future[Assertion] =
expectFailure(f, Status.Code.UNAUTHENTICATED)
protected def expectInvalidArgument(f: Future[Any]): Future[Assertion] =
expectFailure(f, Status.Code.INVALID_ARGUMENT)
protected def expectUnimplemented(f: Future[Any]): Future[Assertion] =
expectFailure(f, Status.Code.UNIMPLEMENTED)
@ -59,49 +62,83 @@ trait ServiceCallAuthTests
protected def ledgerBegin: LedgerOffset =
LedgerOffset(LedgerOffset.Value.Boundary(LedgerOffset.LedgerBoundary.LEDGER_BEGIN))
protected val randomParty: String = UUID.randomUUID.toString
protected val canActAsRandomParty: Option[String] =
Option(toHeader(readWriteToken(UUID.randomUUID.toString)))
Option(customTokenToHeader(readWriteToken(randomParty)))
protected val canActAsRandomPartyExpired: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(-1), readWriteToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(expiringIn(Duration.ofDays(-1), readWriteToken(UUID.randomUUID.toString)))
)
protected val canActAsRandomPartyExpiresTomorrow: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(1), readWriteToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(expiringIn(Duration.ofDays(1), readWriteToken(UUID.randomUUID.toString)))
)
protected val canReadAsRandomParty: Option[String] =
Option(toHeader(readOnlyToken(UUID.randomUUID.toString)))
Option(customTokenToHeader(readOnlyToken(randomParty)))
protected val canReadAsRandomPartyExpired: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(-1), readOnlyToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(expiringIn(Duration.ofDays(-1), readOnlyToken(UUID.randomUUID.toString)))
)
protected val canReadAsRandomPartyExpiresTomorrow: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(1), readOnlyToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(expiringIn(Duration.ofDays(1), readOnlyToken(UUID.randomUUID.toString)))
)
protected val canReadAsAdmin: Option[String] =
Option(toHeader(adminToken))
Option(customTokenToHeader(adminToken))
protected val canReadAsAdminExpired: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(-1), adminToken)))
Option(customTokenToHeader(expiringIn(Duration.ofDays(-1), adminToken)))
protected val canReadAsAdminExpiresTomorrow: Option[String] =
Option(toHeader(expiringIn(Duration.ofDays(1), adminToken)))
Option(customTokenToHeader(expiringIn(Duration.ofDays(1), adminToken)))
// Standard tokens for user authentication
protected val canReadAsAdminStandardJWT: Option[String] =
Option(toHeader(adminTokenStandardJWT))
protected val canReadAsUnknownUserStandardJWT: Option[String] =
Option(toHeader(unknownUserTokenStandardJWT))
// Special tokens to test decoding users and rights from custom tokens
protected val randomUserCanReadAsRandomParty: Option[String] =
Option(customTokenToHeader(readOnlyToken(randomParty).copy(applicationId = Some(randomUserId))))
protected val randomUserCanActAsRandomParty: Option[String] =
Option(
customTokenToHeader(readWriteToken(randomParty).copy(applicationId = Some(randomUserId)))
)
// Note: lazy val, because the ledger ID is only known after the sandbox start
protected lazy val canReadAsRandomPartyActualLedgerId: Option[String] =
Option(toHeader(forLedgerId(unwrappedLedgerId, readOnlyToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(forLedgerId(unwrappedLedgerId, readOnlyToken(UUID.randomUUID.toString)))
)
protected val canReadAsRandomPartyRandomLedgerId: Option[String] =
Option(toHeader(forLedgerId(UUID.randomUUID.toString, readOnlyToken(UUID.randomUUID.toString))))
Option(
customTokenToHeader(
forLedgerId(UUID.randomUUID.toString, readOnlyToken(UUID.randomUUID.toString))
)
)
protected val canReadAsRandomPartyActualParticipantId: Option[String] =
Option(
toHeader(forParticipantId("sandbox-participant", readOnlyToken(UUID.randomUUID.toString)))
customTokenToHeader(
forParticipantId("sandbox-participant", readOnlyToken(UUID.randomUUID.toString))
)
)
protected val canReadAsRandomPartyRandomParticipantId: Option[String] =
Option(
toHeader(forParticipantId(UUID.randomUUID.toString, readOnlyToken(UUID.randomUUID.toString)))
customTokenToHeader(
forParticipantId(UUID.randomUUID.toString, readOnlyToken(UUID.randomUUID.toString))
)
)
// Note: lazy val, because the ledger ID is only known after the sandbox start
protected lazy val canReadAsAdminActualLedgerId: Option[String] =
Option(toHeader(forLedgerId(unwrappedLedgerId, adminToken)))
Option(customTokenToHeader(forLedgerId(unwrappedLedgerId, adminToken)))
protected val canReadAsAdminRandomLedgerId: Option[String] =
Option(toHeader(forLedgerId(UUID.randomUUID.toString, adminToken)))
Option(customTokenToHeader(forLedgerId(UUID.randomUUID.toString, adminToken)))
protected val canReadAsAdminActualParticipantId: Option[String] =
Option(toHeader(forParticipantId("sandbox-participant", adminToken)))
Option(customTokenToHeader(forParticipantId("sandbox-participant", adminToken)))
protected val canReadAsAdminRandomParticipantId: Option[String] =
Option(toHeader(forParticipantId(UUID.randomUUID.toString, adminToken)))
Option(customTokenToHeader(forParticipantId(UUID.randomUUID.toString, adminToken)))
}

View File

@ -11,7 +11,7 @@ trait ServiceCallWithMainActorAuthTests extends SecuredServiceCallAuthTests {
protected val mainActor: String = UUID.randomUUID.toString
private val signedIncorrectly =
Option(toHeader(readWriteToken(mainActor), UUID.randomUUID.toString))
Option(customTokenToHeader(readWriteToken(mainActor), UUID.randomUUID.toString))
it should "deny calls authorized to read/write as the wrong party" in {
expectPermissionDenied(serviceCallWithToken(canActAsRandomParty))
@ -24,45 +24,53 @@ trait ServiceCallWithMainActorAuthTests extends SecuredServiceCallAuthTests {
}
protected val canReadAsMainActor =
Option(toHeader(readOnlyToken(mainActor)))
Option(customTokenToHeader(readOnlyToken(mainActor)))
protected val canReadAsMainActorExpired =
Option(toHeader(expiringIn(Duration.ofDays(-1), readOnlyToken(mainActor))))
Option(customTokenToHeader(expiringIn(Duration.ofDays(-1), readOnlyToken(mainActor))))
protected val canReadAsMainActorExpiresTomorrow =
Option(toHeader(expiringIn(Duration.ofDays(1), readOnlyToken(mainActor))))
Option(customTokenToHeader(expiringIn(Duration.ofDays(1), readOnlyToken(mainActor))))
protected val canActAsMainActor =
Option(toHeader(readWriteToken(mainActor)))
Option(customTokenToHeader(readWriteToken(mainActor)))
protected val canActAsMainActorExpired =
Option(toHeader(expiringIn(Duration.ofDays(-1), readWriteToken(mainActor))))
Option(customTokenToHeader(expiringIn(Duration.ofDays(-1), readWriteToken(mainActor))))
protected val canActAsMainActorExpiresTomorrow =
Option(toHeader(expiringIn(Duration.ofDays(1), readWriteToken(mainActor))))
Option(customTokenToHeader(expiringIn(Duration.ofDays(1), readWriteToken(mainActor))))
// Note: lazy val, because the ledger ID is only known after the sandbox start
protected lazy val canReadAsMainActorActualLedgerId =
Option(toHeader(forLedgerId(unwrappedLedgerId, readOnlyToken(mainActor))))
Option(customTokenToHeader(forLedgerId(unwrappedLedgerId, readOnlyToken(mainActor))))
protected val canReadAsMainActorRandomLedgerId =
Option(toHeader(forLedgerId(UUID.randomUUID.toString, readOnlyToken(mainActor))))
Option(customTokenToHeader(forLedgerId(UUID.randomUUID.toString, readOnlyToken(mainActor))))
protected val canReadAsMainActorActualParticipantId =
Option(toHeader(forParticipantId("sandbox-participant", readOnlyToken(mainActor))))
Option(customTokenToHeader(forParticipantId("sandbox-participant", readOnlyToken(mainActor))))
protected val canReadAsMainActorRandomParticipantId =
Option(toHeader(forParticipantId(UUID.randomUUID.toString, readOnlyToken(mainActor))))
Option(
customTokenToHeader(forParticipantId(UUID.randomUUID.toString, readOnlyToken(mainActor)))
)
protected val canReadAsMainActorActualApplicationId =
Option(toHeader(forApplicationId(serviceCallName, readOnlyToken(mainActor))))
Option(customTokenToHeader(forApplicationId(serviceCallName, readOnlyToken(mainActor))))
protected val canReadAsMainActorRandomApplicationId =
Option(toHeader(forApplicationId(UUID.randomUUID.toString, readOnlyToken(mainActor))))
Option(
customTokenToHeader(forApplicationId(UUID.randomUUID.toString, readOnlyToken(mainActor)))
)
// Note: lazy val, because the ledger ID is only known after the sandbox start
protected lazy val canActAsMainActorActualLedgerId =
Option(toHeader(forLedgerId(unwrappedLedgerId, readWriteToken(mainActor))))
Option(customTokenToHeader(forLedgerId(unwrappedLedgerId, readWriteToken(mainActor))))
protected val canActAsMainActorRandomLedgerId =
Option(toHeader(forLedgerId(UUID.randomUUID.toString, readWriteToken(mainActor))))
Option(customTokenToHeader(forLedgerId(UUID.randomUUID.toString, readWriteToken(mainActor))))
protected val canActAsMainActorActualParticipantId =
Option(toHeader(forParticipantId("sandbox-participant", readWriteToken(mainActor))))
Option(customTokenToHeader(forParticipantId("sandbox-participant", readWriteToken(mainActor))))
protected val canActAsMainActorRandomParticipantId =
Option(toHeader(forParticipantId(UUID.randomUUID.toString, readWriteToken(mainActor))))
Option(
customTokenToHeader(forParticipantId(UUID.randomUUID.toString, readWriteToken(mainActor)))
)
protected val canActAsMainActorActualApplicationId =
Option(toHeader(forApplicationId(serviceCallName, readWriteToken(mainActor))))
Option(customTokenToHeader(forApplicationId(serviceCallName, readWriteToken(mainActor))))
protected val canActAsMainActorRandomApplicationId =
Option(toHeader(forApplicationId(UUID.randomUUID.toString, readWriteToken(mainActor))))
Option(
customTokenToHeader(forApplicationId(UUID.randomUUID.toString, readWriteToken(mainActor)))
)
}

View File

@ -14,7 +14,7 @@ import scala.concurrent.Future
trait SubmitAndWaitDummyCommand extends TestCommands { self: ServiceCallWithMainActorAuthTests =>
protected def submitAndWait(): Future[Empty] =
submitAndWait(Option(toHeader(readWriteToken(mainActor))))
submitAndWait(Option(customTokenToHeader(readWriteToken(mainActor))))
protected def dummySubmitAndWaitRequest: SubmitAndWaitRequest =
SubmitAndWaitRequest(

View File

@ -0,0 +1,22 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.sandbox.auth
import java.util.UUID
import com.daml.ledger.api.v1.admin.user_management_service._
import scala.concurrent.Future
final class CreateUserAuthIT extends AdminServiceCallAuthTests {
override def serviceCallName: String = "UserManagementService#CreateUser"
override def serviceCallWithToken(token: Option[String]): Future[Any] = {
val userId = "fresh-user-" + UUID.randomUUID().toString
val req = CreateUserRequest(Some(User(userId)))
stub(UserManagementServiceGrpc.stub(channel), token).createUser(req)
}
}

View File

@ -0,0 +1,35 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.sandbox.auth
import java.util.UUID
import com.daml.ledger.api.v1.admin.user_management_service._
import io.grpc.{Status, StatusRuntimeException}
import scala.concurrent.Future
class GetUserWithGivenUserIdAuthIT extends AdminServiceCallAuthTests {
override def serviceCallName: String = "UserManagementService#GetUser(given-user-id)"
// only admin users are allowed to specify a user-id for which to retrieve a user
override def serviceCallWithToken(token: Option[String]): Future[Any] = {
for {
// test for an existing user
_ <- stub(UserManagementServiceGrpc.stub(channel), token)
.getUser(GetUserRequest("participant_admin"))
// test for a non-existent user
_ <- stub(UserManagementServiceGrpc.stub(channel), token)
.getUser(GetUserRequest("non-existent-user-" + UUID.randomUUID().toString))
.transform({
case scala.util.Success(u) =>
scala.util.Failure(new RuntimeException(s"User $u unexpectedly exists."))
case scala.util.Failure(e: StatusRuntimeException)
if e.getStatus.getCode == Status.Code.NOT_FOUND =>
scala.util.Success(())
case scala.util.Failure(e: Throwable) => scala.util.Failure(e)
})
} yield ()
}
}

View File

@ -0,0 +1,41 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.sandbox.auth
import com.daml.ledger.api.v1.admin.user_management_service.{
GetUserRequest,
User,
UserManagementServiceGrpc,
}
import org.scalatest.Assertion
import scala.concurrent.Future
/** Tests covering the special behaviour of GetUser when not specifying a user-id. */
class GetUserWithNoUserIdAuthIT extends ServiceCallAuthTests {
override def serviceCallName: String = "UserManagementService#GetUser(<no-user-id>)"
override def serviceCallWithToken(token: Option[String]): Future[Any] =
stub(UserManagementServiceGrpc.stub(channel), token).getUser(GetUserRequest())
protected def expectUser(token: Option[String], expectedUser: User): Future[Assertion] =
serviceCallWithToken(token).map(assertResult(expectedUser)(_))
behavior of serviceCallName
it should "deny unauthenticated access" in {
expectUnauthenticated(serviceCallWithToken(None))
}
it should "deny access for a standard token referring to an unknown user" in {
expectUnauthenticated(serviceCallWithToken(canReadAsUnknownUserStandardJWT))
}
it should "return the 'participant_admin' user when using its standard token" in {
expectUser(canReadAsAdminStandardJWT, User("participant_admin", ""))
}
it should "return invalid argument for custom token" in {
expectInvalidArgument(serviceCallWithToken(canReadAsAdmin))
}
}

View File

@ -0,0 +1,37 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.sandbox.auth
import java.util.UUID
import com.daml.ledger.api.v1.admin.user_management_service._
import io.grpc.{Status, StatusRuntimeException}
import scala.concurrent.Future
class ListUserRightsWithGivenUserIdAuthIT extends AdminServiceCallAuthTests {
override def serviceCallName: String = "UserManagementService#ListUserRights(given-user-id)"
// only admin users are allowed to specify a user-id for which to retrieve rights
override def serviceCallWithToken(token: Option[String]): Future[Any] = {
for {
// test for an existing user
_ <- stub(UserManagementServiceGrpc.stub(channel), token)
.listUserRights(ListUserRightsRequest("participant_admin"))
// test for a non-existent user
_ <- stub(UserManagementServiceGrpc.stub(channel), token)
.listUserRights(ListUserRightsRequest("non-existent-user-" + UUID.randomUUID().toString))
.transform({
case scala.util.Success(u) =>
scala.util.Failure(new RuntimeException(s"User $u unexpectedly exists."))
case scala.util.Failure(e: StatusRuntimeException)
if e.getStatus.getCode == Status.Code.NOT_FOUND =>
scala.util.Success(())
case scala.util.Failure(e: Throwable) => scala.util.Failure(e)
})
} yield ()
}
}

View File

@ -0,0 +1,44 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.platform.sandbox.auth
import com.daml.ledger.api.v1.admin.user_management_service._
import org.scalatest.Assertion
import scala.concurrent.Future
class ListUserRightsWithNoUserIdAuthIT extends ServiceCallAuthTests {
override def serviceCallName: String = "UserManagementService#ListUserRights(<no-user-id>)"
override def serviceCallWithToken(token: Option[String]): Future[Any] =
stub(UserManagementServiceGrpc.stub(channel), token).listUserRights(ListUserRightsRequest())
protected def expectRights(
token: Option[String],
expectedRights: Vector[Right],
): Future[Assertion] =
serviceCallWithToken(token).map(assertResult(ListUserRightsResponse(expectedRights))(_))
behavior of serviceCallName
it should "deny unauthenticated access" in {
expectUnauthenticated(serviceCallWithToken(None))
}
it should "deny access for a standard token referring to an unknown user" in {
expectUnauthenticated(serviceCallWithToken(canReadAsUnknownUserStandardJWT))
}
it should "return rights of the 'participant_admin' when using its standard token" in {
expectRights(
canReadAsAdminStandardJWT,
Vector(Right(Right.Kind.ParticipantAdmin(Right.ParticipantAdmin()))),
)
}
it should "return invalid argument for custom token" in {
expectInvalidArgument(serviceCallWithToken(canReadAsAdmin))
}
}

View File

@ -32,7 +32,7 @@ final class ReflectionIT
"accessed" should {
"provide a list of exposed services" in {
val expectedServiceCount: Int = 17
val expectedServiceCount: Int = 18
for {
response <- execRequest(listServices)
} yield {

View File

@ -11,8 +11,11 @@ import com.daml.jwt.{HMAC256Verifier, JwtSigner}
import com.daml.ledger.api.auth.{
AuthService,
AuthServiceJWT,
AuthServiceJWTCodec,
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
SupportedJWTCodec,
SupportedJWTPayload,
}
import com.daml.ledger.api.domain.LedgerId
import org.scalatest.Suite
@ -34,9 +37,25 @@ trait SandboxRequiringAuthorization {
readAs = Nil,
)
protected val adminToken: AuthServiceJWTPayload = emptyToken.copy(admin = true)
protected def standardToken(userId: String) = StandardJWTPayload(
AuthServiceJWTPayload(
ledgerId = None,
participantId = None,
applicationId = Some(userId),
exp = None,
admin = false,
actAs = Nil,
readAs = Nil,
)
)
protected lazy val wrappedLedgerId: LedgerId = ledgerId(Some(toHeader(adminToken)))
protected val randomUserId: String = UUID.randomUUID().toString
protected val adminToken: AuthServiceJWTPayload = emptyToken.copy(admin = true)
protected val adminTokenStandardJWT: SupportedJWTPayload = standardToken("participant_admin")
protected val unknownUserTokenStandardJWT: SupportedJWTPayload = standardToken("unknown_user")
protected lazy val wrappedLedgerId: LedgerId = ledgerId(Some(customTokenToHeader(adminToken)))
protected lazy val unwrappedLedgerId: String = wrappedLedgerId.unwrap
override protected def authService: Option[AuthService] = {
@ -66,12 +85,18 @@ trait SandboxRequiringAuthorization {
protected def forApplicationId(id: String, p: AuthServiceJWTPayload): AuthServiceJWTPayload =
p.copy(applicationId = Some(id))
protected def toHeader(payload: AuthServiceJWTPayload, secret: String = jwtSecret): String =
protected def customTokenToHeader(
payload: AuthServiceJWTPayload,
secret: String = jwtSecret,
): String =
signed(CustomDamlJWTPayload(payload), secret)
protected def toHeader(payload: SupportedJWTPayload, secret: String = jwtSecret): String =
signed(payload, secret)
private def signed(payload: AuthServiceJWTPayload, secret: String): String =
private def signed(payload: SupportedJWTPayload, secret: String): String =
JwtSigner.HMAC256
.sign(DecodedJwt(jwtHeader, AuthServiceJWTCodec.compactPrint(payload)), secret)
.sign(DecodedJwt(jwtHeader, SupportedJWTCodec.compactPrint(payload)), secret)
.getOrElse(sys.error("Failed to generate token"))
.value
}

View File

@ -32,7 +32,7 @@ class Jwt
override protected def ledgerClientConfiguration: LedgerClientConfiguration =
super.ledgerClientConfiguration.copy(
token = Some(toHeader(forApplicationId("custom app id", readWriteToken(party))))
token = Some(customTokenToHeader(forApplicationId("custom app id", readWriteToken(party))))
)
private val party = "AliceAuth"