[Trigger Service/Oauth2-Middleware] Hocon config refactor (#12228)

* Changes to add a pureconfig-util module with some shared config readers, and cleanup some code from oauth2-middleware hocon

CHANGELOG_BEGIN
CHANGELOG_END

* Update triggers/service/auth/src/test/scala/com/daml/auth/middleware/oauth2/CliSpec.scala

Co-authored-by: Stephen Compall <stephen.compall@daml.com>

Co-authored-by: Stephen Compall <stephen.compall@daml.com>
This commit is contained in:
akshayshirahatti-da 2022-01-07 10:35:31 +00:00 committed by GitHub
parent 8fdc871048
commit 19fe4266ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 267 additions and 165 deletions

View File

@ -0,0 +1,63 @@
# Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_library",
"da_scala_test",
"lf_scalacopts",
"silencer_plugin",
)
da_scala_library(
name = "pureconfig-utils",
srcs = glob(["src/main/scala/**/*.scala"]),
plugins = [
silencer_plugin,
],
scala_deps = [
"@maven//:com_chuusai_shapeless",
"@maven//:com_github_pureconfig_pureconfig_core",
"@maven//:com_github_pureconfig_pureconfig_generic",
"@maven//:com_typesafe_akka_akka_http_core",
"@maven//:com_typesafe_akka_akka_parsing",
"@maven//:com_typesafe_scala_logging_scala_logging",
"@maven//:com_github_scopt_scopt",
"@maven//:org_scalaz_scalaz_core",
],
scalacopts = lf_scalacopts,
# tags = ["maven_coordinates=com.daml:pureconfig-utils:__VERSION__"],
visibility = [
"//visibility:public",
],
runtime_deps = [
"@maven//:ch_qos_logback_logback_classic",
],
deps = [
"//ledger-service/jwt",
"//ledger/ledger-api-common",
"//libs-scala/db-utils",
"@maven//:com_auth0_java_jwt",
"@maven//:com_typesafe_config",
],
)
da_scala_test(
name = "tests",
size = "medium",
srcs = glob(["src/test/scala/**/*.scala"]),
scala_deps = [
"@maven//:com_chuusai_shapeless",
"@maven//:com_github_pureconfig_pureconfig_core",
"@maven//:com_github_pureconfig_pureconfig_generic",
"@maven//:org_scalatest_scalatest_core",
"@maven//:org_scalatest_scalatest_matchers_core",
"@maven//:org_scalatest_scalatest_shouldmatchers",
"@maven//:org_scalatest_scalatest_wordspec",
],
scalacopts = lf_scalacopts,
deps = [
":pureconfig-utils",
"@maven//:org_scalatest_scalatest_compatible",
],
)

View File

@ -0,0 +1,69 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.pureconfigutils
import akka.http.scaladsl.model.Uri
import com.auth0.jwt.algorithms.Algorithm
import com.daml.dbutils.JdbcConfig
import com.daml.jwt.{ECDSAVerifier, HMAC256Verifier, JwksVerifier, JwtVerifierBase, RSA256Verifier}
import com.daml.platform.services.time.TimeProviderType
import pureconfig.{ConfigReader, ConvertHelpers}
import pureconfig.generic.semiauto.deriveReader
import java.nio.file.Path
import scala.concurrent.duration.FiniteDuration
final case class HttpServerConfig(address: String, port: Int, portFile: Option[Path] = None)
final case class LedgerApiConfig(address: String, port: Int)
final case class MetricsConfig(reporter: String, reportingInterval: FiniteDuration)
object SharedConfigReaders {
implicit val tokenVerifierReader: ConfigReader[JwtVerifierBase] =
ConfigReader.forProduct2[JwtVerifierBase, String, String]("type", "uri") {
case (t: String, p: String) =>
// hs256-unsafe, rs256-crt, es256-crt, es512-crt, rs256-jwks
t match {
case "hs256-unsafe" =>
HMAC256Verifier(p)
.valueOr(err => sys.error(s"Failed to create HMAC256 verifier: $err"))
case "rs256-crt" =>
RSA256Verifier
.fromCrtFile(p)
.valueOr(err => sys.error(s"Failed to create RSA256 verifier: $err"))
case "es256-crt" =>
ECDSAVerifier
.fromCrtFile(p, Algorithm.ECDSA256(_, null))
.valueOr(err => sys.error(s"Failed to create ECDSA256 verifier: $err"))
case "es512-crt" =>
ECDSAVerifier
.fromCrtFile(p, Algorithm.ECDSA512(_, null))
.valueOr(err => sys.error(s"Failed to create ECDSA512 verifier: $err"))
case "rs256-jwks" =>
JwksVerifier(p)
}
}
implicit val uriCfgReader: ConfigReader[Uri] =
ConfigReader.fromString[Uri](ConvertHelpers.catchReadError(s => Uri(s)))
implicit val timeProviderTypeCfgReader: ConfigReader[TimeProviderType] =
ConfigReader.fromString[TimeProviderType](ConvertHelpers.catchReadError { s =>
s.toLowerCase() match {
case "static" => TimeProviderType.Static
case "wall-clock" => TimeProviderType.WallClock
case s =>
throw new IllegalArgumentException(
s"Value '$s' for time-provider-type is not one of 'static' or 'wall-clock'"
)
}
})
implicit val jdbcCfgReader: ConfigReader[JdbcConfig] = deriveReader[JdbcConfig]
implicit val httpServerCfgReader: ConfigReader[HttpServerConfig] =
deriveReader[HttpServerConfig]
implicit val ledgerApiConfReader: ConfigReader[LedgerApiConfig] =
deriveReader[LedgerApiConfig]
implicit val metricsConfigReader: ConfigReader[MetricsConfig] = deriveReader[MetricsConfig]
}

View File

@ -0,0 +1,47 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.pureconfigutils
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import pureconfig.generic.semiauto.deriveReader
import pureconfig.{ConfigReader, ConfigSource}
class SharedConfigReadersTest extends AsyncWordSpec with Matchers {
import SharedConfigReaders._
case class SampleServiceConfig(
server: HttpServerConfig,
ledgerApi: LedgerApiConfig,
metrics: MetricsConfig,
)
implicit val serviceConfigReader: ConfigReader[SampleServiceConfig] =
deriveReader[SampleServiceConfig]
"should be able to parse a sample config with shared config objects" in {
val conf = """
|{
| server {
| address = "127.0.0.1"
| port = 8890
| port-file = "port-file"
| }
| ledger-api {
| address = "127.0.0.1"
| port = 8098
| }
| metrics {
| reporter = "console"
| reporting-interval = 10s
| }
|}
|""".stripMargin
ConfigSource.string(conf).load[SampleServiceConfig] match {
case Right(_) => succeed
case Left(ex) => fail(s"Failed to successfully parse service conf: ${ex.head.description}")
}
}
}

View File

@ -67,6 +67,7 @@ da_scala_library(
"//ledger-api/rs-grpc-akka",
"//ledger-api/rs-grpc-bridge",
"//ledger-service/cli-opts",
"//ledger-service/pureconfig-utils",
"//ledger/ledger-api-client",
"//ledger/ledger-api-common",
"//libs-scala/contextualized-logging",
@ -231,6 +232,7 @@ da_scala_test_suite(
"//ledger-api/testing-utils",
"//ledger-service/cli-opts",
"//ledger-service/jwt",
"//ledger-service/pureconfig-utils",
"//ledger/ledger-api-auth",
"//ledger/ledger-api-common",
"//ledger/ledger-resources",

View File

@ -87,10 +87,10 @@ da_scala_library(
"//language-support/scala/bindings",
"//ledger-service/cli-opts",
"//ledger-service/jwt",
"//ledger-service/pureconfig-utils",
"//ledger/cli-opts",
"//ledger/ledger-api-auth",
"//libs-scala/ports",
"@maven//:com_auth0_java_jwt",
"@maven//:com_typesafe_config",
"@maven//:org_slf4j_slf4j_api",
],
@ -193,6 +193,7 @@ da_scala_test(
"@maven//:com_typesafe_scala_logging_scala_logging",
"@maven//:io_spray_spray_json",
"@maven//:org_scalaz_scalaz_core",
"@maven//:com_github_pureconfig_pureconfig_core",
],
scala_runtime_deps = [
"@maven//:com_typesafe_akka_akka_stream_testkit",

View File

@ -4,7 +4,6 @@
package com.daml.auth.middleware.oauth2
import akka.http.scaladsl.model.Uri
import com.auth0.jwt.algorithms.Algorithm
import com.daml.auth.middleware.oauth2.Config.{
DefaultCookieSecure,
DefaultHttpPort,
@ -12,35 +11,20 @@ import com.daml.auth.middleware.oauth2.Config.{
DefaultMaxLoginRequests,
}
import com.daml.cliopts
import com.daml.jwt.{
ECDSAVerifier,
HMAC256Verifier,
JwksVerifier,
JwtVerifierBase,
JwtVerifierConfigurationCli,
RSA256Verifier,
}
import com.daml.jwt.{JwtVerifierBase, JwtVerifierConfigurationCli}
import com.typesafe.scalalogging.StrictLogging
import pureconfig.{ConfigReader, ConfigSource, ConvertHelpers}
import pureconfig.error.ConfigReaderException
import pureconfig.ConfigSource
import pureconfig.error.ConfigReaderFailures
import scopt.OptionParser
import java.io.File
import pureconfig.generic.semiauto._
import java.nio.file.{Path, Paths}
import scala.concurrent.duration
import scala.concurrent.duration.FiniteDuration
import scalaz.syntax.std.option._
sealed trait ConfigError extends Product with Serializable {
def msg: String
}
case object MissingConfigError extends ConfigError {
val msg = "Missing auth middleware config file"
}
final case class ConfigParseError(msg: String) extends ConfigError
final case class Cli(
private[oauth2] final case class Cli(
configFile: Option[File] = None,
// Host and port the middleware listens on
address: String = cliopts.Http.defaultAddress,
@ -64,24 +48,13 @@ final case class Cli(
clientSecret: SecretString,
// Token verification
tokenVerifier: JwtVerifierBase,
) {
) extends StrictLogging {
import Cli._
def loadConfigFromFile: Either[ConfigError, Config] = {
require(configFile.nonEmpty, "Config file should be defined to load app config")
configFile
.map(f =>
try {
Right(ConfigSource.file(f).loadOrThrow[Config])
} catch {
case ex: ConfigReaderException[_] => Left(ConfigParseError(ex.failures.head.description))
}
)
.get
def loadFromConfigFile: Option[Either[ConfigReaderFailures, Config]] = {
configFile.map(cf => ConfigSource.file(cf).load[Config])
}
def loadConfigFromCliArgs: Config = {
def loadFromCliArgs: Config = {
val cfg = Config(
address,
port,
@ -102,40 +75,28 @@ final case class Cli(
cfg.validate
cfg
}
def loadConfig: Option[Config] = {
loadFromConfigFile.cata(
{
case Right(cfg) => Some(cfg)
case Left(ex) =>
logger.error(
s"Error loading oauth2-middleware config from file ${configFile}",
ex.prettyPrint(),
)
None
}, {
logger.warn("Using cli opts for running oauth2-middleware is deprecated")
Some(loadFromCliArgs)
},
)
}
}
object Cli extends StrictLogging {
implicit val tokenVerifierReader: ConfigReader[JwtVerifierBase] =
ConfigReader.forProduct2[JwtVerifierBase, String, String]("type", "uri") {
case (t: String, p: String) =>
// hs256-unsafe, rs256-crt, es256-crt, es512-crt, rs256-jwks
t match {
case "hs256-unsafe" =>
HMAC256Verifier(p)
.valueOr(err => sys.error(s"Failed to create HMAC256 verifier: $err"))
case "rs256-crt" =>
RSA256Verifier
.fromCrtFile(p)
.valueOr(err => sys.error(s"Failed to create RSA256 verifier: $err"))
case "es256-crt" =>
ECDSAVerifier
.fromCrtFile(p, Algorithm.ECDSA256(_, null))
.valueOr(err => sys.error(s"Failed to create ECDSA256 verifier: $err"))
case "es512-crt" =>
ECDSAVerifier
.fromCrtFile(p, Algorithm.ECDSA512(_, null))
.valueOr(err => sys.error(s"Failed to create ECDSA512 verifier: $err"))
case "rs256-jwks" =>
JwksVerifier(p)
}
}
lazy implicit val uriReader: ConfigReader[Uri] =
ConfigReader.fromString[Uri](ConvertHelpers.catchReadError(s => Uri(s)))
lazy implicit val clientSecretReader: ConfigReader[SecretString] =
ConfigReader.fromString[SecretString](ConvertHelpers.catchReadError(s => SecretString(s)))
lazy implicit val cfgReader: ConfigReader[Config] = deriveReader[Config]
private[oauth2] object Cli {
private val Empty =
private[oauth2] val Default =
Cli(
configFile = None,
address = cliopts.Http.defaultAddress,
@ -259,22 +220,10 @@ object Cli extends StrictLogging {
override def showUsageOnError: Option[Boolean] = Some(true)
}
def parse(args: Array[String]): Option[Cli] = parser.parse(args, Empty)
def parse(args: Array[String]): Option[Cli] = parser.parse(args, Default)
def parseConfig(args: Array[String]): Option[Config] = {
val cli = parse(args)
cli.flatMap { c =>
if (c.configFile.isDefined) {
c.loadConfigFromFile match {
case Right(conf) => Some(conf)
case Left(err) =>
logger.error(s"Unable to start oauth2-middleware using config: ${err.msg}")
None
}
} else {
logger.warn("Using cli opts for running oauth2-middleware is deprecated")
Some(c.loadConfigFromCliArgs)
}
}
cli.flatMap(_.loadConfig)
}
}

View File

@ -8,11 +8,13 @@ import akka.http.scaladsl.model.Uri
import com.daml.auth.middleware.oauth2.Config._
import com.daml.cliopts
import com.daml.jwt.JwtVerifierBase
import com.daml.pureconfigutils.SharedConfigReaders._
import pureconfig.{ConfigReader, ConvertHelpers}
import pureconfig.generic.semiauto.deriveReader
import scala.concurrent.duration
import scala.concurrent.duration.FiniteDuration
import scala.concurrent.duration._
case class Config(
final case class Config(
// Host and port the middleware listens on
address: String = cliopts.Http.defaultAddress,
port: Int = DefaultHttpPort,
@ -27,9 +29,9 @@ case class Config(
oauthAuth: Uri,
oauthToken: Uri,
// OAuth2 server request templates
oauthAuthTemplate: Option[Path],
oauthTokenTemplate: Option[Path],
oauthRefreshTemplate: Option[Path],
oauthAuthTemplate: Option[Path] = None,
oauthTokenTemplate: Option[Path] = None,
oauthRefreshTemplate: Option[Path] = None,
// OAuth2 client properties
clientId: String,
clientSecret: SecretString,
@ -45,7 +47,7 @@ case class Config(
}
}
case class SecretString(value: String) {
final case class SecretString(value: String) {
override def toString: String = "###"
}
@ -53,5 +55,9 @@ object Config {
val DefaultHttpPort: Int = 3000
val DefaultCookieSecure: Boolean = true
val DefaultMaxLoginRequests: Int = 100
val DefaultLoginTimeout: FiniteDuration = FiniteDuration(5, duration.MINUTES)
val DefaultLoginTimeout: FiniteDuration = 5.minutes
implicit val clientSecretReader: ConfigReader[SecretString] =
ConfigReader.fromString[SecretString](ConvertHelpers.catchReadError(s => SecretString(s)))
implicit val cfgReader: ConfigReader[Config] = deriveReader[Config]
}

View File

@ -8,11 +8,21 @@ import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import com.daml.bazeltools.BazelRunfiles.requiredResource
import com.daml.jwt.JwksVerifier
import org.scalatest.Inside.inside
import pureconfig.error.{CannotReadFile, ConfigReaderFailures}
import java.nio.file.Paths
import scala.concurrent.duration._
class CliSpec extends AsyncWordSpec with Matchers {
val minimalCfg = Config(
oauthAuth = Uri("https://oauth2/uri"),
oauthToken = Uri("https://oauth2/token"),
callbackUri = Some(Uri("https://example.com/auth/cb")),
clientId = sys.env.getOrElse("DAML_CLIENT_ID", "foo"),
clientSecret = SecretString(sys.env.getOrElse("DAML_CLIENT_SECRET", "bar")),
tokenVerifier = null,
)
val confFile = "triggers/service/auth/src/test/resources/oauth2-middleware.conf"
def loadCli(file: String): Cli = {
Cli.parse(Array("--config", file)).getOrElse(fail("Could not load Cli on parse"))
@ -29,30 +39,14 @@ class CliSpec extends AsyncWordSpec with Matchers {
requiredResource("triggers/service/auth/src/test/resources/oauth2-middleware-minimal.conf")
val cli = loadCli(file.getAbsolutePath)
cli.configFile should not be empty
cli.loadConfigFromFile match {
case Left(ex) => fail(ex.msg)
case Right(c) =>
c.address shouldBe "127.0.0.1"
c.port shouldBe Config.DefaultHttpPort
c.callbackUri shouldBe Some(Uri("https://example.com/auth/cb"))
c.maxLoginRequests shouldBe Config.DefaultMaxLoginRequests
c.loginTimeout shouldBe Config.DefaultLoginTimeout
c.cookieSecure shouldBe Config.DefaultCookieSecure
c.oauthAuth shouldBe Uri("https://oauth2/uri")
c.oauthToken shouldBe Uri("https://oauth2/token")
c.oauthAuthTemplate shouldBe None
c.oauthTokenTemplate shouldBe None
c.oauthRefreshTemplate shouldBe None
c.clientId shouldBe sys.env.getOrElse("DAML_CLIENT_ID", "foo")
c.clientSecret shouldBe SecretString(sys.env.getOrElse("DAML_CLIENT_SECRET", "bar"))
// token verifier needs to be set.
c.tokenVerifier match {
case _: JwksVerifier => succeed
case _ => fail("expected JwksVerifier based on supplied config")
}
val cfg = cli.loadFromConfigFile
inside(cfg) { case Some(Right(c)) =>
c.copy(tokenVerifier = null) shouldBe minimalCfg
// token verifier needs to be set.
c.tokenVerifier match {
case _: JwksVerifier => succeed
case _ => fail("expected JwksVerifier based on supplied config")
}
}
}
@ -60,46 +54,34 @@ class CliSpec extends AsyncWordSpec with Matchers {
val file = requiredResource(confFile)
val cli = loadCli(file.getAbsolutePath)
cli.configFile should not be empty
cli.loadConfigFromFile match {
case Left(ex) => fail(ex.msg)
case Right(c) =>
c.address shouldBe "127.0.0.1"
c.port shouldBe 3000
c.callbackUri shouldBe Some(Uri("https://example.com/auth/cb"))
c.maxLoginRequests shouldBe 10
c.loginTimeout shouldBe FiniteDuration(60, SECONDS)
c.cookieSecure shouldBe false
c.oauthAuth shouldBe Uri("https://oauth2/uri")
c.oauthToken shouldBe Uri("https://oauth2/token")
c.oauthAuthTemplate shouldBe Some(Paths.get("auth_template"))
c.oauthTokenTemplate shouldBe Some(Paths.get("token_template"))
c.oauthRefreshTemplate shouldBe Some(Paths.get("refresh_template"))
c.clientId shouldBe sys.env.getOrElse("DAML_CLIENT_ID", "foo")
c.clientSecret shouldBe SecretString(sys.env.getOrElse("DAML_CLIENT_SECRET", "bar"))
c.tokenVerifier match {
case _: JwksVerifier => succeed
case _ => fail("expected JwksVerifier based on supplied config")
}
val cfg = cli.loadFromConfigFile
inside(cfg) { case Some(Right(c)) =>
c.copy(tokenVerifier = null) shouldBe minimalCfg.copy(
port = 3000,
maxLoginRequests = 10,
loginTimeout = 60.seconds,
cookieSecure = false,
oauthAuthTemplate = Some(Paths.get("auth_template")),
oauthTokenTemplate = Some(Paths.get("token_template")),
oauthRefreshTemplate = Some(Paths.get("refresh_template")),
)
// token verifier needs to be set.
c.tokenVerifier shouldBe a[JwksVerifier]
}
}
"parse should raise error on non-existent config file" in {
val cli = loadCli("missingFile.conf")
cli.configFile should not be empty
val cfg = cli.loadConfigFromFile
cfg match {
case Left(err) => err shouldBe a[ConfigParseError]
case _ => fail("Expected a `ConfigParseError` on missing conf file")
val cfg = cli.loadFromConfigFile
inside(cfg) { case Some(Left(ConfigReaderFailures(head))) =>
head shouldBe a[CannotReadFile]
}
//parseConfig for non-existent file should return a None
Cli.parseConfig(
Array(
"--config-file",
"--config",
"missingFile.conf",
)
) shouldBe None

View File

@ -9,6 +9,8 @@ import com.daml.lf.speedy.Compiler
import com.daml.platform.services.time.TimeProviderType
import pureconfig.{ConfigReader, ConvertHelpers}
import com.daml.auth.middleware.api.{Client => AuthClient}
import com.daml.pureconfigutils.LedgerApiConfig
import com.daml.pureconfigutils.SharedConfigReaders._
import pureconfig.error.FailureReason
import pureconfig.generic.semiauto.deriveReader
@ -16,12 +18,6 @@ import java.nio.file.Path
import java.time.Duration
import scala.concurrent.duration.FiniteDuration
private[trigger] object LedgerApiConfig {
implicit val ledgerApiCfgReader: ConfigReader[LedgerApiConfig] =
deriveReader[LedgerApiConfig]
}
private[trigger] final case class LedgerApiConfig(address: String, port: Int)
private[trigger] object AuthorizationConfig {
final case object AuthConfigFailure extends FailureReason {
val description =
@ -33,9 +29,6 @@ private[trigger] object AuthorizationConfig {
(ac.authCommonUri.isEmpty && ac.authExternalUri.nonEmpty && ac.authInternalUri.nonEmpty)
}
implicit val uriCfgReader: ConfigReader[Uri] =
ConfigReader.fromString[Uri](ConvertHelpers.catchReadError(s => Uri(s)))
implicit val redirectToLoginCfgReader: ConfigReader[AuthClient.RedirectToLogin] =
ConfigReader.fromString[AuthClient.RedirectToLogin](
ConvertHelpers.catchReadError(s => Cli.redirectToLogin(s))
@ -68,18 +61,7 @@ private[trigger] object TriggerServiceAppConf {
)
}
})
implicit val timeProviderTypeCfgReader: ConfigReader[TimeProviderType] =
ConfigReader.fromString[TimeProviderType](ConvertHelpers.catchReadError { s =>
s.toLowerCase() match {
case "static" => TimeProviderType.Static
case "wall-clock" => TimeProviderType.WallClock
case s =>
throw new IllegalArgumentException(
s"Value '$s' for time-provider-type is not one of 'static' or 'wall-clock'"
)
}
})
implicit val jdbcCfgReader: ConfigReader[JdbcConfig] = deriveReader[JdbcConfig]
implicit val serviceCfgReader: ConfigReader[TriggerServiceAppConf] =
deriveReader[TriggerServiceAppConf]
}

View File

@ -13,6 +13,7 @@ import org.scalatest.Inside.inside
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
import pureconfig.error.{CannotReadFile, ConfigReaderFailures}
import com.daml.pureconfigutils.LedgerApiConfig
import java.nio.file.Paths
import scala.concurrent.duration._