Add support for EC in the JWKS (#19320)

* Add support for EC in the JWKS

* Fix test evidence
This commit is contained in:
mziolekda 2024-06-05 19:10:24 +02:00 committed by GitHub
parent ce8f2ff7ca
commit 37969fddc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 40 deletions

View File

@ -4,10 +4,11 @@
package com.daml.jwt package com.daml.jwt
import java.net.{URI, URL} import java.net.{URI, URL}
import java.security.interfaces.RSAPublicKey import java.security.interfaces.{ECPublicKey, RSAPublicKey}
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import com.auth0.jwk.UrlJwkProvider import com.auth0.jwk.{JwkException, UrlJwkProvider}
import com.auth0.jwt.algorithms.Algorithm
import com.google.common.cache.{Cache, CacheBuilder} import com.google.common.cache.{Cache, CacheBuilder}
import scalaz.{-\/, Show, \/} import scalaz.{-\/, Show, \/}
import scalaz.syntax.show._ import scalaz.syntax.show._
@ -56,9 +57,23 @@ class JwksVerifier(
.build() .build()
private[this] def getVerifier(keyId: String): Error \/ JwtVerifier = { private[this] def getVerifier(keyId: String): Error \/ JwtVerifier = {
try {
val jwk = http.get(keyId) val jwk = http.get(keyId)
val publicKey = jwk.getPublicKey.asInstanceOf[RSAPublicKey] val publicKey = jwk.getPublicKey
RSA256Verifier(publicKey, jwtTimestampLeeway) publicKey match {
case rsa: RSAPublicKey => RSA256Verifier(rsa, jwtTimestampLeeway)
case ec: ECPublicKey if ec.getParams.getCurve.getField.getFieldSize == 256 =>
ECDSAVerifier(Algorithm.ECDSA256(ec, null), jwtTimestampLeeway)
case ec: ECPublicKey if ec.getParams.getCurve.getField.getFieldSize == 521 =>
ECDSAVerifier(Algorithm.ECDSA512(ec, null), jwtTimestampLeeway)
case key =>
-\/(Error(Symbol("getVerifier"), s"Unsupported public key format ${key.getFormat}"))
}
} catch {
case e: JwkException => -\/(Error(Symbol("getVerifier"), s"Couldn't get jwk from http: $e"))
case _: Throwable =>
-\/(Error(Symbol("getVerifier"), s"Unknown error while getting jwk from http"))
}
} }
/** Looks up the verifier for the given keyId from the local cache. /** Looks up the verifier for the given keyId from the local cache.

View File

@ -97,7 +97,7 @@ object KeyUtils {
} }
} yield key } yield key
/** Generates a JWKS JSON object for the given map of KeyID->Key /** Generates a JWKS JSON object for the given map of KeyID->Key for RSA
* *
* Note: this uses the same format as Google OAuth, see https://www.googleapis.com/oauth2/v3/certs * Note: this uses the same format as Google OAuth, see https://www.googleapis.com/oauth2/v3/certs
*/ */
@ -121,4 +121,33 @@ object KeyUtils {
|} |}
""".stripMargin """.stripMargin
} }
/** Generates a JWKS JSON object for the given map of KeyID->Key for EC
*
* Note: this uses the same format as Google OAuth, see https://www.gstatic.com/iap/verify/public_key-jwk
*/
def generateECJwks(keys: Map[String, ECPublicKey]): String = {
def generateKeyEntry(keyId: String, key: ECPublicKey): String =
s""" {
| "kid": "$keyId",
| "kty": "EC",
| "alg": "ES${key.getParams.getCurve.getField.getFieldSize}",
| "use": "sig",
| "crv": "P-${key.getParams.getCurve.getField.getFieldSize}",
| "x": "${java.util.Base64.getUrlEncoder.encodeToString(
key.getW.getAffineX.toByteArray
)}",
| "y": "${java.util.Base64.getUrlEncoder.encodeToString(
key.getW.getAffineY.toByteArray
)}"
| }""".stripMargin
s"""
|{
| "keys": [
|${keys.toList.map { case (keyId, key) => generateKeyEntry(keyId, key) }.mkString(",\n")}
| ]
|}
""".stripMargin
}
} }

View File

@ -11,12 +11,14 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers import org.scalatest.matchers.should.Matchers
import scalaz.\/ import scalaz.\/
import scalaz.syntax.show._ import scalaz.syntax.show._
import java.security.{KeyPairGenerator, PrivateKey, PublicKey}
import java.security.interfaces.{ECPrivateKey, ECPublicKey, RSAPrivateKey, RSAPublicKey}
import java.security.spec.ECGenParameterSpec
import java.security.KeyPairGenerator import com.auth0.jwt.algorithms.Algorithm
import java.security.interfaces.{RSAPrivateKey, RSAPublicKey}
import com.daml.http.test.SimpleHttpServer import com.daml.http.test.SimpleHttpServer
class JwksSpec extends AnyFlatSpec with Matchers { trait JwksSpec extends AnyFlatSpec with Matchers { self: JwksSpecKeys =>
val securityAsset: SecurityTest = val securityAsset: SecurityTest =
SecurityTest(property = Authenticity, asset = "JWKS-configured Resource") SecurityTest(property = Authenticity, asset = "JWKS-configured Resource")
@ -27,16 +29,10 @@ class JwksSpec extends AnyFlatSpec with Matchers {
mitigation = s"Refuse to verify authenticity of the token", mitigation = s"Refuse to verify authenticity of the token",
) )
private def generateToken(keyId: String, privateKey: RSAPrivateKey): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "RS256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.RSA256.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey)
}
it should "successfully verify against provided correct key by JWKS server" taggedAs securityAsset it should "successfully verify against provided correct key by JWKS server" taggedAs securityAsset
.setHappyCase( .setHappyCase(
"Successfully verify against provided correct key by JWKS server" "Successfully verify against provided correct key by JWKS server"
) in new JwksSpec.Scope { ) in {
val token = generateToken("test-key-1", privateKey1) val token = generateToken("test-key-1", privateKey1)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x) .fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token) val result = verifier.verify(token)
@ -48,7 +44,7 @@ class JwksSpec extends AnyFlatSpec with Matchers {
} }
it should "raise an error by verifying a token with an unknown key id" taggedAs securityAsset it should "raise an error by verifying a token with an unknown key id" taggedAs securityAsset
.setAttack(attack(threat = "Present an unknown key-id")) in new JwksSpec.Scope { .setAttack(attack(threat = "Present an unknown key-id")) in {
val token = generateToken("test-key-unknown", privateKey1) val token = generateToken("test-key-unknown", privateKey1)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x) .fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token) val result = verifier.verify(token)
@ -59,7 +55,7 @@ class JwksSpec extends AnyFlatSpec with Matchers {
it should "raise an error by verifying a token with wrong public key" taggedAs securityAsset it should "raise an error by verifying a token with wrong public key" taggedAs securityAsset
.setAttack( .setAttack(
attack(threat = "Present a known key-id, but not the one used for the token encryption") attack(threat = "Present a known key-id, but not the one used for the token encryption")
) in new JwksSpec.Scope { ) in {
val token = generateToken("test-key-1", privateKey2) val token = generateToken("test-key-1", privateKey2)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x) .fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token) val result = verifier.verify(token)
@ -71,32 +67,100 @@ class JwksSpec extends AnyFlatSpec with Matchers {
} }
} }
object JwksSpec { trait JwksSpecKeys {
trait Scope {
// Generate some RSA key pairs
private val keySize = 2048
private val kpg = KeyPairGenerator.getInstance("RSA")
kpg.initialize(keySize)
protected type PublicKeyType <: PublicKey
protected type PrivateKeyType <: PrivateKey
protected def kpg: KeyPairGenerator
protected def jwks: String
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt
// Generate some RSA key pairs
private val keyPair1 = kpg.generateKeyPair() private val keyPair1 = kpg.generateKeyPair()
private val publicKey1 = keyPair1.getPublic.asInstanceOf[RSAPublicKey] protected val publicKey1: PublicKeyType = keyPair1.getPublic.asInstanceOf[PublicKeyType]
val privateKey1 = keyPair1.getPrivate.asInstanceOf[RSAPrivateKey] val privateKey1: PrivateKeyType = keyPair1.getPrivate.asInstanceOf[PrivateKeyType]
private val keyPair2 = kpg.generateKeyPair() private val keyPair2 = kpg.generateKeyPair()
private val publicKey2 = keyPair2.getPublic.asInstanceOf[RSAPublicKey] protected val publicKey2: PublicKeyType = keyPair2.getPublic.asInstanceOf[PublicKeyType]
val privateKey2 = keyPair2.getPrivate.asInstanceOf[RSAPrivateKey] val privateKey2: PrivateKeyType = keyPair2.getPrivate.asInstanceOf[PrivateKeyType]
// Start a JWKS server and create a verifier using the JWKS server private val server = SimpleHttpServer.start(jwks)
private val jwks = KeyUtils.generateJwks( private val url = SimpleHttpServer.responseUrl(server)
protected val verifier: JwksVerifier = JwksVerifier(url)
}
class JwksSpecRSA extends JwksSpec with JwksSpecKeys {
private val keySize = 2048
override type PublicKeyType = RSAPublicKey
override type PrivateKeyType = RSAPrivateKey
override def kpg: KeyPairGenerator = KeyPairGenerator.getInstance("RSA")
kpg.initialize(keySize)
override def jwks: String = KeyUtils.generateJwks(
Map( Map(
"test-key-1" -> publicKey1, "test-key-1" -> publicKey1,
"test-key-2" -> publicKey2, "test-key-2" -> publicKey2,
) )
) )
private val server = SimpleHttpServer.start(jwks) override def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
private val url = SimpleHttpServer.responseUrl(server) val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "RS256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.RSA256.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey)
}
}
val verifier = JwksVerifier(url) class JwksSpecES256 extends JwksSpec with JwksSpecKeys {
override type PublicKeyType = ECPublicKey
override type PrivateKeyType = ECPrivateKey
protected def kpg: KeyPairGenerator = {
val gen = KeyPairGenerator.getInstance("EC")
gen.initialize(new ECGenParameterSpec("secp256r1"))
gen
}
protected def jwks: String = KeyUtils.generateECJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
)
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "ES256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.ECDSA.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey, Algorithm.ECDSA256(null, _))
}
}
class JwksSpecES512 extends JwksSpec with JwksSpecKeys {
override type PublicKeyType = ECPublicKey
override type PrivateKeyType = ECPrivateKey
protected def kpg: KeyPairGenerator = {
val gen = KeyPairGenerator.getInstance("EC")
gen.initialize(new ECGenParameterSpec("secp521r1"))
gen
}
protected def jwks: String = KeyUtils.generateECJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
)
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "ES512", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.ECDSA.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey, Algorithm.ECDSA512(null, _))
} }
} }

View File

@ -41,6 +41,7 @@ write_scalatest_runpath(
"//libs-scala/jwt:tests-lib", "//libs-scala/jwt:tests-lib",
], ],
runtime_deps = [ runtime_deps = [
"//libs-scala/http-test-utils",
"//libs-scala/jwt", "//libs-scala/jwt",
"//libs-scala/scalatest-utils", "//libs-scala/scalatest-utils",
], ],