Add support for EC in the JWKS (#19320)

* Add support for EC in the JWKS

* Fix test evidence
This commit is contained in:
mziolekda 2024-06-05 19:10:24 +02:00 committed by GitHub
parent ce8f2ff7ca
commit 37969fddc3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 149 additions and 40 deletions

View File

@ -4,10 +4,11 @@
package com.daml.jwt
import java.net.{URI, URL}
import java.security.interfaces.RSAPublicKey
import java.security.interfaces.{ECPublicKey, RSAPublicKey}
import java.util.concurrent.TimeUnit
import com.auth0.jwk.UrlJwkProvider
import com.auth0.jwk.{JwkException, UrlJwkProvider}
import com.auth0.jwt.algorithms.Algorithm
import com.google.common.cache.{Cache, CacheBuilder}
import scalaz.{-\/, Show, \/}
import scalaz.syntax.show._
@ -56,9 +57,23 @@ class JwksVerifier(
.build()
private[this] def getVerifier(keyId: String): Error \/ JwtVerifier = {
val jwk = http.get(keyId)
val publicKey = jwk.getPublicKey.asInstanceOf[RSAPublicKey]
RSA256Verifier(publicKey, jwtTimestampLeeway)
try {
val jwk = http.get(keyId)
val publicKey = jwk.getPublicKey
publicKey match {
case rsa: RSAPublicKey => RSA256Verifier(rsa, jwtTimestampLeeway)
case ec: ECPublicKey if ec.getParams.getCurve.getField.getFieldSize == 256 =>
ECDSAVerifier(Algorithm.ECDSA256(ec, null), jwtTimestampLeeway)
case ec: ECPublicKey if ec.getParams.getCurve.getField.getFieldSize == 521 =>
ECDSAVerifier(Algorithm.ECDSA512(ec, null), jwtTimestampLeeway)
case key =>
-\/(Error(Symbol("getVerifier"), s"Unsupported public key format ${key.getFormat}"))
}
} catch {
case e: JwkException => -\/(Error(Symbol("getVerifier"), s"Couldn't get jwk from http: $e"))
case _: Throwable =>
-\/(Error(Symbol("getVerifier"), s"Unknown error while getting jwk from http"))
}
}
/** Looks up the verifier for the given keyId from the local cache.

View File

@ -97,7 +97,7 @@ object KeyUtils {
}
} yield key
/** Generates a JWKS JSON object for the given map of KeyID->Key
/** Generates a JWKS JSON object for the given map of KeyID->Key for RSA
*
* Note: this uses the same format as Google OAuth, see https://www.googleapis.com/oauth2/v3/certs
*/
@ -121,4 +121,33 @@ object KeyUtils {
|}
""".stripMargin
}
/** Generates a JWKS JSON object for the given map of KeyID->Key for EC
*
* Note: this uses the same format as Google OAuth, see https://www.gstatic.com/iap/verify/public_key-jwk
*/
def generateECJwks(keys: Map[String, ECPublicKey]): String = {
def generateKeyEntry(keyId: String, key: ECPublicKey): String =
s""" {
| "kid": "$keyId",
| "kty": "EC",
| "alg": "ES${key.getParams.getCurve.getField.getFieldSize}",
| "use": "sig",
| "crv": "P-${key.getParams.getCurve.getField.getFieldSize}",
| "x": "${java.util.Base64.getUrlEncoder.encodeToString(
key.getW.getAffineX.toByteArray
)}",
| "y": "${java.util.Base64.getUrlEncoder.encodeToString(
key.getW.getAffineY.toByteArray
)}"
| }""".stripMargin
s"""
|{
| "keys": [
|${keys.toList.map { case (keyId, key) => generateKeyEntry(keyId, key) }.mkString(",\n")}
| ]
|}
""".stripMargin
}
}

View File

@ -11,12 +11,14 @@ import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import scalaz.\/
import scalaz.syntax.show._
import java.security.{KeyPairGenerator, PrivateKey, PublicKey}
import java.security.interfaces.{ECPrivateKey, ECPublicKey, RSAPrivateKey, RSAPublicKey}
import java.security.spec.ECGenParameterSpec
import java.security.KeyPairGenerator
import java.security.interfaces.{RSAPrivateKey, RSAPublicKey}
import com.auth0.jwt.algorithms.Algorithm
import com.daml.http.test.SimpleHttpServer
class JwksSpec extends AnyFlatSpec with Matchers {
trait JwksSpec extends AnyFlatSpec with Matchers { self: JwksSpecKeys =>
val securityAsset: SecurityTest =
SecurityTest(property = Authenticity, asset = "JWKS-configured Resource")
@ -27,16 +29,10 @@ class JwksSpec extends AnyFlatSpec with Matchers {
mitigation = s"Refuse to verify authenticity of the token",
)
private def generateToken(keyId: String, privateKey: RSAPrivateKey): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "RS256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.RSA256.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey)
}
it should "successfully verify against provided correct key by JWKS server" taggedAs securityAsset
.setHappyCase(
"Successfully verify against provided correct key by JWKS server"
) in new JwksSpec.Scope {
) in {
val token = generateToken("test-key-1", privateKey1)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token)
@ -48,7 +44,7 @@ class JwksSpec extends AnyFlatSpec with Matchers {
}
it should "raise an error by verifying a token with an unknown key id" taggedAs securityAsset
.setAttack(attack(threat = "Present an unknown key-id")) in new JwksSpec.Scope {
.setAttack(attack(threat = "Present an unknown key-id")) in {
val token = generateToken("test-key-unknown", privateKey1)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token)
@ -59,7 +55,7 @@ class JwksSpec extends AnyFlatSpec with Matchers {
it should "raise an error by verifying a token with wrong public key" taggedAs securityAsset
.setAttack(
attack(threat = "Present a known key-id, but not the one used for the token encryption")
) in new JwksSpec.Scope {
) in {
val token = generateToken("test-key-1", privateKey2)
.fold(e => fail("Failed to generate signed token: " + e.shows), x => x)
val result = verifier.verify(token)
@ -71,32 +67,100 @@ class JwksSpec extends AnyFlatSpec with Matchers {
}
}
object JwksSpec {
trait Scope {
// Generate some RSA key pairs
private val keySize = 2048
private val kpg = KeyPairGenerator.getInstance("RSA")
kpg.initialize(keySize)
trait JwksSpecKeys {
private val keyPair1 = kpg.generateKeyPair()
private val publicKey1 = keyPair1.getPublic.asInstanceOf[RSAPublicKey]
val privateKey1 = keyPair1.getPrivate.asInstanceOf[RSAPrivateKey]
protected type PublicKeyType <: PublicKey
protected type PrivateKeyType <: PrivateKey
private val keyPair2 = kpg.generateKeyPair()
private val publicKey2 = keyPair2.getPublic.asInstanceOf[RSAPublicKey]
val privateKey2 = keyPair2.getPrivate.asInstanceOf[RSAPrivateKey]
protected def kpg: KeyPairGenerator
protected def jwks: String
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt
// Start a JWKS server and create a verifier using the JWKS server
private val jwks = KeyUtils.generateJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
// Generate some RSA key pairs
private val keyPair1 = kpg.generateKeyPair()
protected val publicKey1: PublicKeyType = keyPair1.getPublic.asInstanceOf[PublicKeyType]
val privateKey1: PrivateKeyType = keyPair1.getPrivate.asInstanceOf[PrivateKeyType]
private val keyPair2 = kpg.generateKeyPair()
protected val publicKey2: PublicKeyType = keyPair2.getPublic.asInstanceOf[PublicKeyType]
val privateKey2: PrivateKeyType = keyPair2.getPrivate.asInstanceOf[PrivateKeyType]
private val server = SimpleHttpServer.start(jwks)
private val url = SimpleHttpServer.responseUrl(server)
protected val verifier: JwksVerifier = JwksVerifier(url)
}
class JwksSpecRSA extends JwksSpec with JwksSpecKeys {
private val keySize = 2048
override type PublicKeyType = RSAPublicKey
override type PrivateKeyType = RSAPrivateKey
override def kpg: KeyPairGenerator = KeyPairGenerator.getInstance("RSA")
kpg.initialize(keySize)
override def jwks: String = KeyUtils.generateJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
)
private val server = SimpleHttpServer.start(jwks)
private val url = SimpleHttpServer.responseUrl(server)
val verifier = JwksVerifier(url)
override def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "RS256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.RSA256.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey)
}
}
class JwksSpecES256 extends JwksSpec with JwksSpecKeys {
override type PublicKeyType = ECPublicKey
override type PrivateKeyType = ECPrivateKey
protected def kpg: KeyPairGenerator = {
val gen = KeyPairGenerator.getInstance("EC")
gen.initialize(new ECGenParameterSpec("secp256r1"))
gen
}
protected def jwks: String = KeyUtils.generateECJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
)
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "ES256", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.ECDSA.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey, Algorithm.ECDSA256(null, _))
}
}
class JwksSpecES512 extends JwksSpec with JwksSpecKeys {
override type PublicKeyType = ECPublicKey
override type PrivateKeyType = ECPrivateKey
protected def kpg: KeyPairGenerator = {
val gen = KeyPairGenerator.getInstance("EC")
gen.initialize(new ECGenParameterSpec("secp521r1"))
gen
}
protected def jwks: String = KeyUtils.generateECJwks(
Map(
"test-key-1" -> publicKey1,
"test-key-2" -> publicKey2,
)
)
protected def generateToken(keyId: String, privateKey: PrivateKeyType): JwtSigner.Error \/ Jwt = {
val jwtPayload = s"""{"test": "JwksSpec"}"""
val jwtHeader = s"""{"alg": "ES512", "typ": "JWT", "kid": "$keyId"}"""
JwtSigner.ECDSA.sign(DecodedJwt(jwtHeader, jwtPayload), privateKey, Algorithm.ECDSA512(null, _))
}
}

View File

@ -41,6 +41,7 @@ write_scalatest_runpath(
"//libs-scala/jwt:tests-lib",
],
runtime_deps = [
"//libs-scala/http-test-utils",
"//libs-scala/jwt",
"//libs-scala/scalatest-utils",
],