Accept new form of JWT tokens [DPP-1102] (#14323)

* Accept new form of jwt tokens [DPP-1102]

CHANGELOG_BEGIN
CHANGELOG_END
This commit is contained in:
Sergey Kisel 2022-07-04 17:41:08 +02:00 committed by GitHub
parent 5db3cc4bd5
commit f9521f27eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 247 additions and 33 deletions

View File

@ -17,7 +17,12 @@ import com.daml.http.util.Logging.{InstanceUUID, instanceUUIDLogCtx}
import com.daml.http.{HttpService, StartSettings, nonrepudiation}
import com.daml.jwt.JwtSigner
import com.daml.jwt.domain.DecodedJwt
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, CustomDamlJWTPayload, StandardJWTPayload}
import com.daml.ledger.api.auth.{
AuthServiceJWTCodec,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.domain.{User, UserRight}
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
import com.daml.ledger.api.testing.utils.{
@ -140,6 +145,7 @@ trait JsonApiFixture
userId = userId,
participantId = None,
exp = None,
format = StandardJWTTokenFormat.Scope,
)
val header = """{"alg": "HS256", "typ": "JWT"}"""
val jwt = DecodedJwt[String](header, AuthServiceJWTCodec.writeToString(payload))

View File

@ -54,6 +54,21 @@ final case class CustomDamlJWTPayload(
}
}
/** There are two JWT token formats which are currently supported by `StandardJWTPayload`.
* The format is identified by `aud` claim.
*/
sealed trait StandardJWTTokenFormat
object StandardJWTTokenFormat {
/** `Scope` format is for the tokens where scope field contains `daml_ledger_api`.
*/
final case object Scope extends StandardJWTTokenFormat
/** `ParticipantId` format is for the tokens where `aud` claim starts with `https://daml.com/jwt/aud/participant/`
*/
final case object ParticipantId extends StandardJWTTokenFormat
}
/** Payload parsed from the standard "sub", "aud", "exp" claims as specified in
* https://datatracker.ietf.org/doc/html/rfc7519#section-4.1
*
@ -69,6 +84,7 @@ final case class StandardJWTPayload(
userId: String,
participantId: Option[String],
exp: Option[Instant],
format: StandardJWTTokenFormat,
) extends AuthServiceJWTPayload
/** Codec for writing and reading [[AuthServiceJWTPayload]] to and from JSON.
@ -93,9 +109,11 @@ object AuthServiceJWTCodec {
// Unique scope for standard tokens, following the pattern of https://developers.google.com/identity/protocols/oauth2/scopes
final val scopeLedgerApiFull: String = "daml_ledger_api"
private[this] final val audPrefix: String = "https://daml.com/jwt/aud/participant/"
private[this] final val propLedgerId: String = "ledgerId"
private[this] final val propParticipantId: String = "participantId"
private[this] final val propApplicationId: String = "applicationId"
private[this] final val propAud: String = "aud"
private[this] final val propAdmin: String = "admin"
private[this] final val propActAs: String = "actAs"
private[this] final val propReadAs: String = "readAs"
@ -121,13 +139,19 @@ object AuthServiceJWTCodec {
),
propExp -> writeOptionalInstant(v.exp),
)
case v: StandardJWTPayload =>
case v: StandardJWTPayload if v.format == StandardJWTTokenFormat.Scope =>
JsObject(
"aud" -> writeOptionalString(v.participantId),
propAud -> writeOptionalString(v.participantId),
"sub" -> JsString(v.userId),
"exp" -> writeOptionalInstant(v.exp),
"scope" -> JsString(scopeLedgerApiFull),
)
case v: StandardJWTPayload =>
JsObject(
propAud -> JsString(audPrefix + v.participantId.getOrElse("")),
"sub" -> JsString(v.userId),
"exp" -> writeOptionalInstant(v.exp),
)
}
/** Writes the given payload to a compact JSON string */
@ -158,14 +182,34 @@ object AuthServiceJWTCodec {
val scopes = scope.toList.collect({ case JsString(scope) => scope.split(" ") }).flatten
// We're using this rather restrictive test to ensure we continue parsing all legacy sandbox tokens that
// are in use before the 2.0 release; and thereby maintain full backwards compatibility.
if (scopes.contains(scopeLedgerApiFull))
// Standard JWT payload
val audienceValue = readOptionalStringOrSingletonArray(propAud, fields)
if (audienceValue.exists(_.startsWith(audPrefix))) {
// Tokens with audience which starts with `https://daml.com/participant/jwt/aud/participant/${participantId}`
// where `${participantId}` is non-empty string are supported.
// As required for JWTs, additional fields can be in a token but will be ignored (including scope)
audienceValue.map(_.substring(audPrefix.length)).filter(_.nonEmpty) match {
case Some(participantId) =>
StandardJWTPayload(
participantId = Some(participantId),
userId = readOptionalString("sub", fields).get, // guarded by if-clause above
exp = readInstant("exp", fields),
format = StandardJWTTokenFormat.ParticipantId,
)
case _ =>
deserializationError(
s"Could not read ${value.prettyPrint} as AuthServiceJWTPayload: " +
s"`aud` must include participantId value prefixed by $audPrefix"
)
}
} else if (scopes.contains(scopeLedgerApiFull)) {
// We support the tokens with scope containing `daml_ledger_api`, there is no restriction of `aud` field.
StandardJWTPayload(
participantId = readOptionalString("aud", fields),
participantId = audienceValue,
userId = readOptionalString("sub", fields).get, // guarded by if-clause above
exp = readInstant("exp", fields),
format = StandardJWTTokenFormat.Scope,
)
else {
} else {
if (scope.nonEmpty)
logger.warn(
s"Access token with unknown scope \"${scope.get}\" is being parsed as a custom claims token. Issue tokens with adjusted or no scope to get rid of this warning."
@ -191,11 +235,11 @@ object AuthServiceJWTCodec {
.getOrElse(
oidcNamespace,
deserializationError(
s"Can't read ${value.prettyPrint} as AuthServiceJWTPayload: namespace missing"
s"Could not read ${value.prettyPrint} as AuthServiceJWTPayload: namespace missing"
),
)
.asJsObject(
s"Can't read ${value.prettyPrint} as AuthServiceJWTPayload: namespace is not an object"
s"Could not read ${value.prettyPrint} as AuthServiceJWTPayload: namespace is not an object"
)
.fields
CustomDamlJWTPayload(
@ -211,7 +255,7 @@ object AuthServiceJWTCodec {
}
case _ =>
deserializationError(
s"Can't read ${value.prettyPrint} as AuthServiceJWTPayload: value is not an object"
s"Could not read ${value.prettyPrint} as AuthServiceJWTPayload: value is not an object"
)
}
@ -221,7 +265,20 @@ object AuthServiceJWTCodec {
case Some(JsNull) => None
case Some(JsString(value)) => Some(value)
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as string for $name")
deserializationError(s"Could not read ${value.prettyPrint} as string for $name")
}
private[this] def readOptionalStringOrSingletonArray(
name: String,
fields: Map[String, JsValue],
): Option[String] =
fields.get(name) match {
case None => None
case Some(JsNull) => None
case Some(JsString(value)) => Some(value)
case Some(JsArray(Vector(JsString(value)))) => Some(value)
case Some(value) =>
deserializationError(s"Could not read ${value.prettyPrint} as string for $name")
}
private[this] def readOptionalStringList(
@ -234,10 +291,10 @@ object AuthServiceJWTCodec {
values.toList.map {
case JsString(value) => value
case value =>
deserializationError(s"Can't read ${value.prettyPrint} as string element for $name")
deserializationError(s"Could not read ${value.prettyPrint} as string element for $name")
}
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as string list for $name")
deserializationError(s"Could not read ${value.prettyPrint} as string list for $name")
}
private[this] def readOptionalBoolean(
@ -248,7 +305,7 @@ object AuthServiceJWTCodec {
case Some(JsNull) => None
case Some(JsBoolean(value)) => Some(value)
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as boolean for $name")
deserializationError(s"Could not read ${value.prettyPrint} as boolean for $name")
}
private[this] def readInstant(name: String, fields: Map[String, JsValue]): Option[Instant] =
@ -257,7 +314,7 @@ object AuthServiceJWTCodec {
case Some(JsNull) => None
case Some(JsNumber(epochSeconds)) => Some(Instant.ofEpochSecond(epochSeconds.longValue))
case Some(value) =>
deserializationError(s"Can't read ${value.prettyPrint} as epoch seconds for $name")
deserializationError(s"Could not read ${value.prettyPrint} as epoch seconds for $name")
}
// ------------------------------------------------------------------------------------------------------------------

View File

@ -10,7 +10,7 @@ import org.scalatest.wordspec.AnyWordSpec
import spray.json._
import java.time.Instant
import scala.util.{Success, Try}
import scala.util.{Failure, Success, Try}
class AuthServiceJWTCodecSpec
extends AnyWordSpec
@ -48,6 +48,20 @@ class AuthServiceJWTCodecSpec
}
}
private implicit val arbFormat: Arbitrary[StandardJWTTokenFormat] =
Arbitrary(
Gen.oneOf[StandardJWTTokenFormat](
StandardJWTTokenFormat.ParticipantId,
StandardJWTTokenFormat.Scope,
)
)
// participantId is mandatory for the format `StandardJWTTokenFormat.ParticipantId`
private val StandardJWTPayloadGen = Gen.resultOf(StandardJWTPayload).filterNot { payload =>
!payload.participantId
.exists(_.nonEmpty) && payload.format == StandardJWTTokenFormat.ParticipantId
}
"AuthServiceJWTPayload codec" when {
"serializing and parsing a value" should {
@ -60,11 +74,11 @@ class AuthServiceJWTCodecSpec
})
"work for arbitrary standard Daml token values" in forAll(
Gen.resultOf(StandardJWTPayload),
StandardJWTPayloadGen,
minSuccessful(100),
)(value => {
) { value =>
serializeAndParse(value) shouldBe Success(value)
})
}
"support OIDC compliant sandbox format" in {
val serialized =
@ -174,6 +188,7 @@ class AuthServiceJWTCodecSpec
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.Scope,
)
parse(serialized) shouldBe Success(expected)
}
@ -191,6 +206,7 @@ class AuthServiceJWTCodecSpec
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.Scope,
)
parse(serialized) shouldBe Success(expected)
}
@ -210,6 +226,122 @@ class AuthServiceJWTCodecSpec
parse(serialized) shouldBe Success(expected)
}
"support additional daml user token with prefixed audience" in {
val serialized =
"""{
| "aud": "https://daml.com/jwt/aud/participant/someParticipantId",
| "sub": "someUserId",
| "exp": 100
|}
""".stripMargin
val expected = StandardJWTPayload(
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.ParticipantId,
)
parse(serialized) shouldBe Success(expected)
}
"treat a singleton array of audiences equivalent to a string of its first element" in {
val prefixed =
"""{
| "aud": ["https://daml.com/jwt/aud/participant/someParticipantId"],
| "sub": "someUserId",
| "exp": 100
|}
""".stripMargin
parse(prefixed) shouldBe Success(
StandardJWTPayload(
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.ParticipantId,
)
)
val standard =
"""{
| "aud": ["someParticipantId"],
| "sub": "someUserId",
| "exp": 100,
| "scope": "dummy-scope1 daml_ledger_api dummy-scope2"
|}
""".stripMargin
parse(standard) shouldBe Success(
StandardJWTPayload(
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.Scope,
)
)
}
"support additional daml user token with prefixed audience and provided scope" in {
val serialized =
"""{
| "aud": ["https://daml.com/jwt/aud/participant/someParticipantId"],
| "sub": "someUserId",
| "exp": 100,
| "scope": "daml_ledger_api"
|}
""".stripMargin
val expected = StandardJWTPayload(
participantId = Some("someParticipantId"),
userId = "someUserId",
exp = Some(Instant.ofEpochSecond(100)),
format = StandardJWTTokenFormat.ParticipantId,
)
parse(serialized) shouldBe Success(expected)
}
"reject the token of ParticipantId format with multiple audiences" in {
val serialized =
"""{
| "aud": ["https://daml.com/jwt/aud/participant/someParticipantId",
| "https://daml.com/jwt/aud/participant/someParticipantId2"],
| "sub": "someUserId",
| "exp": 100
|}
""".stripMargin
parse(serialized) shouldBe Failure(
DeserializationException(
"Could not read [\"https://daml.com/jwt/aud/participant/someParticipantId\", " +
"\"https://daml.com/jwt/aud/participant/someParticipantId2\"] as string for aud"
)
)
}
"reject the token of Scope format with multiple audiences" in {
val serialized =
"""{
| "aud": ["someParticipantId",
| "someParticipantId2"],
| "sub": "someUserId",
| "exp": 100,
| "scope": "daml_ledger_api"
|}
""".stripMargin
parse(serialized) shouldBe Failure(
DeserializationException(
"Could not read [\"someParticipantId\", " +
"\"someParticipantId2\"] as string for aud"
)
)
}
"reject the ParticipantId format token with empty participantId" in {
val serialized =
"""{
| "aud": ["https://daml.com/jwt/aud/participant/"],
| "sub": "someUserId",
| "exp": 100
|}
""".stripMargin
parse(serialized).failed.get.getMessage
.contains("must include participantId value prefixed by") shouldBe true
}
}
}
}

View File

@ -6,7 +6,7 @@ package com.daml.ledger.api.benchtool
import com.daml.jwt.JwtSigner
import com.daml.jwt.domain.DecodedJwt
import com.daml.ledger.api.auth.client.LedgerCallCredentials
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, StandardJWTPayload}
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, StandardJWTPayload, StandardJWTTokenFormat}
import io.grpc.stub.AbstractStub
object AuthorizationHelper {
@ -24,6 +24,7 @@ class AuthorizationHelper(val authorizationTokenSecret: String) {
participantId = None,
userId = userId,
exp = None,
format = StandardJWTTokenFormat.Scope,
)
JwtSigner.HMAC256
.sign(

View File

@ -14,6 +14,7 @@ import com.daml.ledger.api.auth.{
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.domain.LedgerId
import org.scalatest.Suite
@ -43,6 +44,7 @@ trait SandboxRequiringAuthorizationFuns {
participantId = participantId,
userId = userId,
exp = expiresIn.map(delta => Instant.now().plusNanos(delta.toNanos)),
format = StandardJWTTokenFormat.Scope,
)
protected def randomUserId(): String = UUID.randomUUID().toString

View File

@ -197,15 +197,15 @@
- the /auth endpoint given claim token should return unauthorized on insufficient party claims: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L292)
- the /login endpoint should redirect and set the cookie: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L167)
- the /login endpoint should return OK and set cookie without redirectUri: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L196)
- the /login endpoint with an oauth server checking claims should redirect to the configured middleware callback URI: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L402)
- the /login endpoint with an oauth server checking claims should refuse requests when max capacity is reached: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L433)
- the /login endpoint with an oauth server checking claims should refuse requests when max capacity is reached: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L483)
- the /login endpoint with an oauth server checking claims should redirect to the configured middleware callback URI: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L403)
- the /login endpoint with an oauth server checking claims should refuse requests when max capacity is reached: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L434)
- the /login endpoint with an oauth server checking claims should refuse requests when max capacity is reached: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L484)
- the /refresh endpoint should fail on an invalid refresh token: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L261)
- the /refresh endpoint should return a new access token: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L226)
- the TestMiddlewareClientAutoRedirectToLogin client should not redirect to /login for JSON request: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L645)
- the TestMiddlewareClientAutoRedirectToLogin client should redirect to /login for HTML request: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L627)
- the TestMiddlewareClientNoRedirectToLogin client should not redirect to /login: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L544)
- the TestMiddlewareClientYesRedirectToLogin client should redirect to /login: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L593)
- the TestMiddlewareClientAutoRedirectToLogin client should not redirect to /login for JSON request: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L646)
- the TestMiddlewareClientAutoRedirectToLogin client should redirect to /login for HTML request: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L628)
- the TestMiddlewareClientNoRedirectToLogin client should not redirect to /login: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L545)
- the TestMiddlewareClientYesRedirectToLogin client should redirect to /login: [TestMiddleware.scala](triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/TestMiddleware.scala#L594)
## Performance:
- Tail call optimization: Tail recursion does not blow the scala JVM stack.: [TailCallTest.scala](daml-lf/interpreter/src/test/scala/com/digitalasset/daml/lf/speedy/TailCallTest.scala#L16)

View File

@ -5,7 +5,6 @@ package com.daml.auth.oauth2.test.server
import java.time.Instant
import java.util.UUID
import akka.Done
import akka.actor.ActorSystem
import akka.http.scaladsl.Http
@ -22,6 +21,7 @@ import com.daml.ledger.api.auth.{
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.refinements.ApiTypes.Party
@ -101,7 +101,12 @@ class Server(config: Config) {
case _ => ()
})
if (config.yieldUserTokens) // ignore everything but the applicationId
StandardJWTPayload(userId = applicationId getOrElse "", participantId = None, exp = None)
StandardJWTPayload(
userId = applicationId getOrElse "",
participantId = None,
exp = None,
format = StandardJWTTokenFormat.Scope,
)
else
CustomDamlJWTPayload(
ledgerId = Some(config.ledgerId),

View File

@ -4,7 +4,6 @@
package com.daml.auth.middleware.oauth2
import java.time.Duration
import akka.http.scaladsl.Http
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import akka.http.scaladsl.model._
@ -21,6 +20,7 @@ import com.daml.ledger.api.auth.{
AuthServiceJWTPayload,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.refinements.ApiTypes
import com.daml.ledger.api.refinements.ApiTypes.Party
@ -153,7 +153,7 @@ abstract class TestMiddleware
"accept user tokens" in {
import com.daml.auth.middleware.oauth2.Server.rightsProvideClaims
rightsProvideClaims(
StandardJWTPayload("foo", None, None),
StandardJWTPayload("foo", None, None, StandardJWTTokenFormat.Scope),
Claims(
admin = true,
actAs = List(ApiTypes.Party("Alice")),
@ -388,6 +388,7 @@ class TestMiddlewareUserToken extends TestMiddleware {
userId = "test-application",
participantId = None,
exp = expiresIn.map(in => clock.instant.plus(in)),
format = StandardJWTTokenFormat.Scope,
)
}

View File

@ -26,7 +26,12 @@ import com.daml.dbutils.{ConnectionPool, JdbcConfig}
import com.daml.jwt.domain.DecodedJwt
import com.daml.jwt.{JwtSigner, JwtVerifier, JwtVerifierBase}
import com.daml.ledger.api.auth
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, CustomDamlJWTPayload, StandardJWTPayload}
import com.daml.ledger.api.auth.{
AuthServiceJWTCodec,
CustomDamlJWTPayload,
StandardJWTPayload,
StandardJWTTokenFormat,
}
import com.daml.ledger.api.refinements.ApiTypes
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
import com.daml.ledger.api.testing.utils.{AkkaBeforeAndAfterAll, OwnedResource}
@ -169,7 +174,12 @@ trait AuthMiddlewareFixture
) = Some {
val payload =
if (sandboxClientTakesUserToken)
StandardJWTPayload(userId = "", participantId = None, exp = None)
StandardJWTPayload(
userId = "",
participantId = None,
exp = None,
format = StandardJWTTokenFormat.Scope,
)
else
CustomDamlJWTPayload(
ledgerId = None,