Use auth middleware in trigger service /v1/start endpoint (#7654)

* Authorize trigger service on middleware

changelog_begin
changelog_end

* Trigger service auth callback handler

* Forward token

* Do not pin the application ID in the access token

The trigger service will assign an individual application ID to each
trigger based on its UUID. Requiring tokens on the granularity of
application IDs would break the idea of storing the token in a cookie to
be able to use it across multiple requests.

changelog_begin
changelog_end

* todo persist trigger token

* Add a state parameter to middleware login

* add documentation comments

* typo

* fmt

* Align Party type between middleware and trigger service

The middleware was using `com.daml.lf.data.Ref.Party` while the trigger
service is using `com.daml.ledger.api.refinements.ApiTypes.Party` which
requires conversions. This aligns the types to avoid such conversions.

* optional application id in oauth2 test server

* align party types

* configure auth middleware in trigger service tests

* handle empty cookie header

* follow redirects in trigger service tests

* keep track of cookies

* keep track of cookies

* Replace any previous Cookie header

Otherwise on old daml-ledger-token cookie might persist and be preferred
over a newly added instance.

* DEBUG

* Configure test ledger client readAs claims

* fmt

* docstrings

* remove debug output

* Avoid endless redirect loops

When the replay still fails to authorize on the middleware then we do
not want to attempt another login flow.

* Store callback routes in authCallbacks

* fmt

* Push AuthTestConfig into test target

https://github.com/digital-asset/daml/pull/7654#discussion_r506510193

* Unbind oauth2 server after middleware

https://github.com/digital-asset/daml/pull/7654/files#r506513251

Co-authored-by: Andreas Herrmann <andreas.herrmann@tweag.io>
This commit is contained in:
Andreas Herrmann 2020-10-16 17:37:36 +02:00 committed by GitHub
parent f025dc3065
commit 60fe244e1b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 375 additions and 52 deletions

View File

@ -45,6 +45,7 @@ da_scala_library(
"//libs-scala/contextualized-logging",
"//libs-scala/scala-utils",
"//triggers/runner:trigger-runner-lib",
"//triggers/service/auth:oauth-middleware", # TODO[AH] Extract request/response types
"@maven//:com_chuusai_shapeless_2_12",
"@maven//:com_github_scopt_scopt_2_12",
"@maven//:com_lihaoyi_sourcecode_2_12",
@ -53,6 +54,7 @@ da_scala_library(
"@maven//:com_typesafe_akka_akka_http_2_12",
"@maven//:com_typesafe_akka_akka_http_core_2_12",
"@maven//:com_typesafe_akka_akka_http_spray_json_2_12",
"@maven//:com_typesafe_akka_akka_parsing_2_12",
"@maven//:com_typesafe_akka_akka_stream_2_12",
"@maven//:com_typesafe_scala_logging_scala_logging_2_12",
"@maven//:com_zaxxer_HikariCP",
@ -101,7 +103,9 @@ da_scala_test(
"//language-support/scala/bindings-akka",
"//ledger-api/rs-grpc-bridge",
"//ledger-service/cli-opts",
"//ledger-service/jwt",
"//ledger/caching",
"//ledger/ledger-api-auth",
"//ledger/ledger-api-common",
"//ledger/participant-integration-api",
"//ledger/participant-state",
@ -110,9 +114,12 @@ da_scala_test(
"//libs-scala/ports",
"//libs-scala/postgresql-testing",
"//libs-scala/timer-utils",
"//triggers/service/auth:oauth-middleware",
"//triggers/service/auth:oauth-test-server",
"@maven//:ch_qos_logback_logback_classic",
"@maven//:com_typesafe_akka_akka_actor_typed_2_12",
"@maven//:com_typesafe_akka_akka_http_core_2_12",
"@maven//:com_typesafe_akka_akka_parsing_2_12",
"@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_3",
"@maven//:io_spray_spray_json_2_12",
"@maven//:org_scalatest_scalatest_2_12",

View File

@ -18,6 +18,7 @@ da_scala_library(
deps = [
":oauth-test-server", # TODO[AH] Extract OAuth2 request/response types
"//daml-lf/data",
"//language-support/scala/bindings",
"//ledger-service/jwt",
"//ledger/ledger-api-auth",
"//libs-scala/ports",
@ -112,6 +113,7 @@ da_scala_test(
":oauth-middleware",
":oauth-test-server",
"//daml-lf/data",
"//language-support/scala/bindings",
"//ledger-api/rs-grpc-bridge",
"//ledger-api/testing-utils",
"//ledger-service/jwt",

View File

@ -6,7 +6,7 @@ package com.daml.oauth.middleware
import akka.http.scaladsl.model.Uri
import com.daml.ports.Port
private[middleware] case class Config(
case class Config(
// Port the middleware listens on
port: Port,
// OAuth2 server endpoints
@ -17,7 +17,7 @@ private[middleware] case class Config(
clientSecret: String,
)
private[middleware] object Config {
object Config {
private val Empty =
Config(
port = Port.Dynamic,

View File

@ -51,7 +51,6 @@ repository](https://github.com/digital-asset/ex-secure-daml-infra).
context.accessToken[namespace] = {
// NOTE change the ledger ID to match your deployment.
"ledgerId": "2D105384-CE61-4CCC-8E0E-37248BA935A3",
"applicationId": context.clientName,
"actAs": actAs,
"readAs": readAs,
"admin": admin

View File

@ -6,7 +6,7 @@ package com.daml.oauth.middleware
import akka.http.scaladsl.marshalling.Marshaller
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.unmarshalling.Unmarshaller
import com.daml.lf.data.Ref.Party
import com.daml.ledger.api.refinements.ApiTypes.Party
import spray.json.{
DefaultJsonProtocol,
JsString,
@ -48,9 +48,9 @@ object Request {
if (w == "admin") {
admin = true
} else if (w.startsWith("actAs:")) {
actAs.append(Party.assertFromString(w.stripPrefix("actAs:")))
actAs.append(Party(w.stripPrefix("actAs:")))
} else if (w.startsWith("readAs:")) {
readAs.append(Party.assertFromString(w.stripPrefix("readAs:")))
readAs.append(Party(w.stripPrefix("readAs:")))
} else {
throw new IllegalArgumentException(s"Expected claim but got $w")
}
@ -62,14 +62,26 @@ object Request {
/** Auth endpoint query parameters
*/
case class Auth(claims: Claims)
case class Auth(claims: Claims) {
def toQuery: Uri.Query = Uri.Query("claims" -> claims.toQueryString())
}
/** Login endpoint query parameters
*
* @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service.
* @param claims Required ledger claims.
* @param state State that will be forwarded to the callback URI after authentication and authorization.
*/
case class Login(redirectUri: Uri, claims: Claims)
case class Login(redirectUri: Uri, claims: Claims, state: Option[String]) {
def toQuery: Uri.Query = {
var params = Seq(
"redirect_uri" -> redirectUri.toString,
"claims" -> claims.toQueryString(),
)
state.foreach(x => params ++= Seq("state" -> x))
Uri.Query(params: _*)
}
}
}

View File

@ -95,11 +95,15 @@ object Server extends StrictLogging {
}
private def login(config: Config, requests: TrieMap[UUID, Uri]) =
parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims]))
parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims], 'state ?))
.as[Request.Login](Request.Login) { login =>
extractRequest { request =>
val requestId = UUID.randomUUID
requests += (requestId -> login.redirectUri)
requests += (requestId -> {
var query = login.redirectUri.query().to[Seq]
login.state.foreach(x => query ++= Seq("state" -> x))
login.redirectUri.withQuery(Uri.Query(query: _*))
})
val authorize = OAuthRequest.Authorize(
responseType = "code",
clientId = config.clientId,

View File

@ -11,14 +11,14 @@ case class Config(
// Ledger ID of issued tokens
ledgerId: String,
// Application ID of issued tokens
applicationId: String,
applicationId: Option[String],
// Secret used to sign JWTs
jwtSecret: String
)
object Config {
private val Empty =
Config(port = Port.Dynamic, ledgerId = null, applicationId = null, jwtSecret = null)
Config(port = Port.Dynamic, ledgerId = null, applicationId = None, jwtSecret = null)
def parseConfig(args: Seq[String]): Option[Config] =
configParser.parse(args, Empty)
@ -37,7 +37,7 @@ object Config {
.action((x, c) => c.copy(ledgerId = x))
opt[String]("application-id")
.action((x, c) => c.copy(applicationId = x))
.action((x, c) => c.copy(applicationId = Some(x)))
opt[String]("secret")
.action((x, c) => c.copy(jwtSecret = x))

View File

@ -41,7 +41,7 @@ object Server {
}
AuthServiceJWTPayload(
ledgerId = Some(config.ledgerId),
applicationId = Some(config.applicationId),
applicationId = config.applicationId,
// Not required by the default auth service
participantId = None,
// Only for testing, expire never.

View File

@ -12,8 +12,8 @@ import akka.http.scaladsl.unmarshalling.Unmarshal
import com.daml.jwt.JwtSigner
import com.daml.jwt.domain.DecodedJwt
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
import com.daml.ledger.api.refinements.ApiTypes
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
import com.daml.lf.data.Ref.Party
import com.daml.oauth.server.{Response => OAuthResponse}
import org.scalatest.AsyncWordSpec
@ -33,8 +33,8 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
participantId = None,
exp = None,
admin = claims.admin,
actAs = claims.actAs,
readAs = claims.readAs
actAs = claims.actAs.map(ApiTypes.Party.unwrap(_)),
readAs = claims.readAs.map(ApiTypes.Party.unwrap(_))
)
OAuthResponse.Token(
accessToken = JwtSigner.HMAC256
@ -63,7 +63,7 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
}
}
"return the token from a cookie" in {
val claims = Request.Claims(actAs = List(Party.assertFromString("Alice")))
val claims = Request.Claims(actAs = List(ApiTypes.Party("Alice")))
val token = makeToken(claims)
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
val req = HttpRequest(
@ -83,13 +83,13 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
}
}
"return unauthorized on insufficient claims" in {
val token = makeToken(Request.Claims(actAs = List(Party.assertFromString("Alice"))))
val token = makeToken(Request.Claims(actAs = List(ApiTypes.Party("Alice"))))
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
val req = HttpRequest(
uri = middlewareUri
.withPath(Path./("auth"))
.withQuery(Query(
("claims", Request.Claims(actAs = List(Party.assertFromString("Bob"))).toQueryString))),
.withQuery(
Query(("claims", Request.Claims(actAs = List(ApiTypes.Party("Bob"))).toQueryString))),
headers = List(cookieHeader)
)
for {

View File

@ -30,7 +30,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin
OAuthConfig(
port = Port.Dynamic,
ledgerId = ledgerId,
applicationId = applicationId,
applicationId = Some(applicationId),
jwtSecret = jwtSecret))
serverUri = Uri()
.withScheme("http")

View File

@ -29,7 +29,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin
Config(
port = Port.Dynamic,
ledgerId = ledgerId,
applicationId = applicationId,
applicationId = Some(applicationId),
jwtSecret = jwtSecret))
client <- Resources.authClient(
Client.Config(

View File

@ -4,19 +4,28 @@
package com.daml.lf
package engine.trigger
import akka.actor.ActorSystem
import akka.actor.typed.{ActorRef, Behavior}
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
import akka.actor.typed.scaladsl.adapter._
import akka.http.scaladsl.Http
import akka.http.scaladsl.Http.ServerBinding
import akka.http.scaladsl.model._
import akka.http.scaladsl.model.Uri.Path
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import akka.http.scaladsl.server.{Directive, Directive1}
import akka.http.scaladsl.server.Directives._
import akka.http.scaladsl.server.Route
import akka.http.scaladsl.unmarshalling.Unmarshal
import akka.stream.Materializer
import akka.util.ByteString
import spray.json.DefaultJsonProtocol._
import spray.json._
import com.daml.oauth.middleware.{
JsonProtocol => AuthJsonProtocol,
Request => AuthRequest,
Response => AuthResponse
}
import com.daml.ledger.api.refinements.ApiTypes.Party
import com.daml.lf.archive.{Dar, DarReader, Decode}
import com.daml.lf.archive.Reader.ParseError
@ -43,7 +52,10 @@ import java.util.UUID
import java.util.zip.ZipInputStream
import java.time.LocalDateTime
import scala.collection.concurrent.TrieMap
class Server(
authConfig: AuthConfig,
ledgerConfig: LedgerConfig,
restartConfig: TriggerRestartConfig,
triggerDao: RunningTriggerDao)(
@ -96,7 +108,8 @@ class Server(
private def restartTriggers(triggers: Vector[RunningTrigger]): Either[String, Unit] = {
import cats.implicits._ // needed for traverse
triggers.traverse_(t => startTrigger(t.triggerParty, t.triggerName, Some(t.triggerInstance)))
triggers.traverse_(t =>
startTrigger(t.triggerParty, t.triggerName, t.triggerToken, Some(t.triggerInstance)))
}
private def triggerRunnerName(triggerInstance: UUID): String =
@ -110,6 +123,7 @@ class Server(
private def startTrigger(
party: Party,
triggerName: Identifier,
token: Option[String],
existingInstance: Option[UUID] = None): Either[String, JsValue] = {
for {
trigger <- Trigger.fromIdentifier(compiledPackages, triggerName)
@ -117,7 +131,7 @@ class Server(
case None =>
val newInstance = UUID.randomUUID
triggerDao
.addRunningTrigger(RunningTrigger(newInstance, triggerName, party))
.addRunningTrigger(RunningTrigger(newInstance, triggerName, party, token))
.map(_ => newInstance)
case Some(instance) => Right(instance)
}
@ -127,6 +141,7 @@ class Server(
ctx.self,
triggerInstance,
party,
token,
compiledPackages,
trigger,
ledgerConfig,
@ -171,20 +186,112 @@ class Server(
private def getTriggerStatus(uuid: UUID): Vector[(LocalDateTime, String)] =
triggerLog.getOrDefault(uuid, Vector.empty)
private val route = concat(
// TODO[AH] Make sure this is bounded in size.
private val authCallbacks: TrieMap[UUID, Route] = TrieMap()
// This directive requires authorization for the given claims via the auth middleware, if configured.
// If no auth middleware is configured, then the request will proceed without attempting authorization.
//
// Authorization follows the steps defined in `triggers/service/authentication.md`.
// First asking for a token on the `/auth` endpoint and redirecting to `/login` if none was returned.
// If a login is required then this will store the current continuation in `authCallbacks`
// to proceed once the login flow completed and authentication succeeded.
private def authorize(claims: AuthRequest.Claims)(
implicit ec: ExecutionContext,
system: ActorSystem): Directive1[Option[String]] =
authConfig match {
case NoAuth => provide(None)
case AuthMiddleware(authUri) =>
// Attempt to obtain the access token from the middleware's /auth endpoint.
// Forwards the current request's cookies.
// Fails if the response is not OK or Unauthorized.
def auth: Directive1[Option[String]] = {
val uri = authUri
.withPath(Path./("auth"))
.withQuery(AuthRequest.Auth(claims).toQuery)
import AuthJsonProtocol._
extract(_.request.headers[headers.Cookie]).flatMap { cookies =>
onSuccess(Http().singleRequest(HttpRequest(uri = uri, headers = cookies))).flatMap {
case HttpResponse(StatusCodes.OK, _, entity, _) =>
onSuccess(Unmarshal(entity).to[AuthResponse.Authorize]).map { auth =>
Some(auth.accessToken): Option[String]
}
case HttpResponse(StatusCodes.Unauthorized, _, _, _) =>
provide(None)
case resp @ HttpResponse(code, _, _, _) =>
onSuccess(Unmarshal(resp).to[String]).flatMap { msg =>
logger.error(s"Failed to authorize with middleware ($code): $msg")
complete(
errorResponse(
StatusCodes.InternalServerError,
"Failed to authorize with middleware"))
}
}
}
}
Directive { inner =>
auth {
// Authorization successful - pass token to continuation
case Some(token) => inner(Tuple1(Some(token)))
// Authorization failed - login and retry on callback request.
case None => { ctx =>
val requestId = UUID.randomUUID()
authCallbacks.update(
requestId, {
auth {
case None => {
// Authorization failed after login - respond with 401
// TODO[AH] Add WWW-Authenticate header
complete(errorResponse(StatusCodes.Unauthorized))
}
case Some(token) =>
// Authorization successful after login - use old request context and pass token to continuation.
mapRequestContext(_ => ctx) {
inner(Tuple1(Some(token)))
}
}
}
)
// TODO[AH] Make the redirect URI configurable, especially the authority. E.g. when running behind nginx.
val callbackUri = Uri()
.withScheme(ctx.request.uri.scheme)
.withAuthority(ctx.request.uri.authority)
.withPath(Path./("cb"))
val uri = authUri
.withPath(Path./("login"))
.withQuery(AuthRequest.Login(callbackUri, claims, Some(requestId.toString)).toQuery)
ctx.redirect(uri, StatusCodes.Found)
}
}
}
}
private def authCallback(requestId: UUID): Route =
authCallbacks.remove(requestId) match {
case None => complete(StatusCodes.NotFound)
case Some(callback) => callback
}
private def route(implicit ec: ExecutionContext, system: ActorSystem) = concat(
post {
concat(
// Start a new trigger given its identifier and the party it
// should be running as. Returns a UUID for the newly
// started trigger.
path("v1" / "start") {
entity(as[StartParams]) { params =>
startTrigger(params.party, params.triggerName) match {
case Left(err) =>
complete(errorResponse(StatusCodes.UnprocessableEntity, err))
case Right(triggerInstance) =>
complete(successResponse(triggerInstance))
}
entity(as[StartParams]) {
params =>
val claims =
AuthRequest.Claims(actAs = List(params.party))
// TODO[AH] Why do we need to pass ec, system explicitly?
authorize(claims)(ec, system) { token =>
startTrigger(params.party, params.triggerName, token) match {
case Left(err) =>
complete(errorResponse(StatusCodes.UnprocessableEntity, err))
case Right(triggerInstance) =>
complete(successResponse(triggerInstance))
}
}
}
},
// upload a DAR as a multi-part form request with a single field called
@ -257,6 +364,14 @@ class Server(
}
}
},
// Authorization callback endpoint
path("cb") {
get {
parameters('state.as[UUID]) { requestId =>
authCallback(requestId)
}
}
},
)
}
@ -265,6 +380,7 @@ object Server {
def apply(
host: String,
port: Int,
authConfig: AuthConfig,
ledgerConfig: LedgerConfig,
restartConfig: TriggerRestartConfig,
initialDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
@ -299,11 +415,11 @@ object Server {
val (dao, server): (RunningTriggerDao, Server) = jdbcConfig match {
case None =>
val dao = InMemoryTriggerDao()
val server = new Server(ledgerConfig, restartConfig, dao)
val server = new Server(authConfig, ledgerConfig, restartConfig, dao)
(dao, server)
case Some(c) =>
val dao = DbTriggerDao(c)
val server = new Server(ledgerConfig, restartConfig, dao)
val server = new Server(authConfig, ledgerConfig, restartConfig, dao)
val recovery: Either[String, Unit] = for {
_ <- dao.initialize
packages <- dao.readPackages

View File

@ -6,6 +6,7 @@ package com.daml.lf.engine.trigger
import java.nio.file.{Path, Paths}
import java.time.Duration
import akka.http.scaladsl.model.Uri
import com.daml.cliopts
import com.daml.platform.services.time.TimeProviderType
import scalaz.Show
@ -21,6 +22,7 @@ private[trigger] final case class ServiceConfig(
httpPort: Int,
ledgerHost: String,
ledgerPort: Int,
authUri: Option[Uri],
maxInboundMessageSize: Int,
minRestartInterval: FiniteDuration,
maxRestartInterval: FiniteDuration,
@ -105,6 +107,13 @@ private[trigger] object ServiceConfig {
.action((t, c) => c.copy(ledgerPort = t))
.text("Ledger port.")
opt[String]("auth")
.optional()
.action((t, c) => c.copy(authUri = Some(Uri(t))))
.text("Auth middleware URI.")
// TODO[AH] Expose once the feature is fully implemented.
.hidden()
opt[Int]("max-inbound-message-size")
.action((x, c) => c.copy(maxInboundMessageSize = x))
.optional()
@ -157,6 +166,7 @@ private[trigger] object ServiceConfig {
httpPort = DefaultHttpPort,
ledgerHost = null,
ledgerPort = 0,
authUri = None,
maxInboundMessageSize = DefaultMaxInboundMessageSize,
minRestartInterval = DefaultMinRestartInterval,
maxRestartInterval = DefaultMaxRestartInterval,

View File

@ -27,6 +27,7 @@ object ServiceMain {
def startServer(
host: String,
port: Int,
authConfig: AuthConfig,
ledgerConfig: LedgerConfig,
restartConfig: TriggerRestartConfig,
encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
@ -38,6 +39,7 @@ object ServiceMain {
Server(
host,
port,
authConfig,
ledgerConfig,
restartConfig,
encodedDar,
@ -66,6 +68,10 @@ object ServiceMain {
case Success(dar) => dar
}
}
val authConfig: AuthConfig = config.authUri match {
case None => NoAuth
case Some(uri) => AuthMiddleware(uri)
}
val ledgerConfig =
LedgerConfig(
config.ledgerHost,
@ -83,6 +89,7 @@ object ServiceMain {
Server(
config.address,
config.httpPort,
authConfig,
ledgerConfig,
restartConfig,
encodedDar,

View File

@ -31,7 +31,7 @@ object TriggerRunnerImpl {
server: ActorRef[Message],
triggerInstance: UUID,
party: Party,
// TODO(SF, 2020-06-09): Add access token field here in the presence of authentication.
token: Option[String],
compiledPackages: CompiledPackages,
trigger: Trigger,
ledgerConfig: LedgerConfig,
@ -67,9 +67,7 @@ object TriggerRunnerImpl {
commandClient = CommandClientConfiguration.default.copy(
defaultDeduplicationTime = config.ledgerConfig.commandTtl),
sslContext = None,
// TODO(SF, 2020-06-09): In the presence of an authorization
// service, get an access token and pass it through here!
token = None,
token = config.token,
maxInboundMessageSize = config.ledgerConfig.maxInboundMessageSize,
)

View File

@ -149,7 +149,8 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
triggerInstance: UUID,
party: String,
fullTriggerName: String): Either[String, RunningTrigger] = {
Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party)))
// TODO[AH] Persist the access and refresh token.
Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party), None))
}
// Drop all tables and other objects associated with the database.

View File

@ -10,10 +10,15 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
import com.daml.lf.data.Ref.Identifier
import com.daml.platform.services.time.TimeProviderType
import akka.http.scaladsl.model.Uri
import scala.concurrent.duration.FiniteDuration
package trigger {
sealed trait AuthConfig
case object NoAuth extends AuthConfig
final case class AuthMiddleware(uri: Uri) extends AuthConfig
case class LedgerConfig(
host: String,
port: Int,
@ -32,7 +37,6 @@ package trigger {
triggerInstance: UUID,
triggerName: Identifier,
triggerParty: Party,
// TODO(SF, 2020-0610): Add access token field here in the
// presence of authentication.
triggerToken: Option[String],
)
}

View File

@ -7,14 +7,21 @@ import java.io.File
import java.net.InetAddress
import java.time.Duration
import akka.actor.ActorSystem
import akka.actor.typed.{ActorSystem => TypedActorSystem}
import akka.http.scaladsl.Http.ServerBinding
import akka.http.scaladsl.model.Uri
import akka.http.scaladsl.model.Uri.Path
import akka.stream.Materializer
import com.daml.bazeltools.BazelRunfiles
import com.daml.daml_lf_dev.DamlLf
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.jwt.domain.DecodedJwt
import com.daml.jwt.{HMAC256Verifier, JwtSigner}
import com.daml.ledger.api.auth
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.refinements.ApiTypes
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
import com.daml.ledger.client.LedgerClient
import com.daml.ledger.client.configuration.{
@ -24,6 +31,8 @@ import com.daml.ledger.client.configuration.{
}
import com.daml.lf.archive.Dar
import com.daml.lf.data.Ref._
import com.daml.oauth.middleware.{Config => MiddlewareConfig, Server => MiddlewareServer}
import com.daml.oauth.server.{Config => OAuthConfig, Server => OAuthServer}
import com.daml.platform.common.LedgerIdMode
import com.daml.platform.sandbox
import com.daml.platform.sandbox.SandboxServer
@ -40,6 +49,13 @@ import scala.concurrent._
import scala.concurrent.duration._
import scala.sys.process.Process
private[trigger] final case class AuthTestConfig(
// HMAC256 signature secret.
jwtSecret: String,
// Grant readAs claims for these parties to the ledger client provided to test cases.
parties: List[ApiTypes.Party],
)
object TriggerServiceFixture extends StrictLogging {
// Use a small initial interval so we can test restart behaviour more easily.
@ -50,10 +66,12 @@ object TriggerServiceFixture extends StrictLogging {
dars: List[File],
encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
jdbcConfig: Option[JdbcConfig],
authTestConfig: Option[AuthTestConfig],
)(testFn: (Uri, LedgerClient, Proxy) => Future[A])(
implicit mat: Materializer,
aesf: ExecutionSequencerFactory,
ec: ExecutionContext,
system: ActorSystem,
pos: source.Position,
): Future[A] = {
logger.info(s"${pos.fileName}:${pos.lineNumber}: setting up trigger service")
@ -77,9 +95,43 @@ object TriggerServiceFixture extends StrictLogging {
val ledgerId = LedgerId(testName)
val applicationId = ApplicationId(testName)
val authF: Future[(AuthConfig, () => Future[Unit])] = authTestConfig match {
case None => Future((NoAuth, () => Future(())))
case Some(AuthTestConfig(secret, _)) =>
for {
oauth <- OAuthServer.start(
OAuthConfig(
port = Port.Dynamic,
ledgerId = LedgerId.unwrap(ledgerId),
// TODO[AH] Choose application ID, see https://github.com/digital-asset/daml/issues/7671
applicationId = None,
jwtSecret = secret,
))
serverUri = Uri()
.withScheme("http")
.withAuthority(oauth.localAddress.getHostString, oauth.localAddress.getPort)
middleware <- MiddlewareServer.start(
MiddlewareConfig(
port = Port.Dynamic,
oauthAuth = serverUri.withPath(Path./("authorize")),
oauthToken = serverUri.withPath(Path./("token")),
clientId = "oauth-middleware-id",
clientSecret = "oauth-middleware-secret",
))
middlewareUri = Uri()
.withScheme("http")
.withAuthority(middleware.localAddress.getHostString, middleware.localAddress.getPort)
cleanup = () =>
for {
_ <- middleware.unbind()
_ <- oauth.unbind()
} yield ()
} yield (AuthMiddleware(middlewareUri), cleanup)
}
val ledgerF = for {
(_, toxiproxyClient) <- toxiproxyF
ledger <- Future(new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId), mat))
ledger <- Future(
new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId, authTestConfig), mat))
ledgerPort <- ledger.portF
ledgerProxyPort = LockedFreePort.find()
ledgerProxy = toxiproxyClient.createProxy(
@ -97,12 +149,13 @@ object TriggerServiceFixture extends StrictLogging {
client <- LedgerClient.singleHost(
host.getHostName,
ledgerPort.value,
clientConfig(applicationId),
clientConfig(applicationId, authTestConfig),
)
} yield client
// Configure the service with the ledger's *proxy* port.
val serviceF: Future[(ServerBinding, TypedActorSystem[Message])] = for {
(authConfig, _) <- authF
(_, _, ledgerProxyPort, _) <- ledgerF
ledgerConfig = LedgerConfig(
host.getHostName,
@ -119,6 +172,7 @@ object TriggerServiceFixture extends StrictLogging {
service <- ServiceMain.startServer(
host.getHostName,
servicePort.port.value,
authConfig,
ledgerConfig,
restartConfig,
encodedDar,
@ -152,7 +206,10 @@ object TriggerServiceFixture extends StrictLogging {
proc.destroy()
proc.exitValue() // destroy is async
})
result <- (ta.failed.toOption orElse se orElse le orElse te)
ae <- optErr(authF.flatMap {
case (_, cleanup) => cleanup()
})
result <- (ta.failed.toOption orElse se orElse le orElse te orElse ae)
.cata(Future.failed, Future fromTry ta)
} yield result
}
@ -164,24 +221,43 @@ object TriggerServiceFixture extends StrictLogging {
private def ledgerConfig(
ledgerPort: Port,
dars: List[File],
ledgerId: LedgerId
ledgerId: LedgerId,
authTestConfig: Option[AuthTestConfig]
): SandboxConfig =
sandbox.DefaultConfig.copy(
port = ledgerPort,
damlPackages = dars,
timeProviderType = Some(TimeProviderType.Static),
ledgerIdMode = LedgerIdMode.Static(ledgerId),
authService = None,
authService = for {
cfg <- authTestConfig
verifier <- HMAC256Verifier(cfg.jwtSecret).toOption
} yield auth.AuthServiceJWT(verifier)
)
private def clientConfig[A](
applicationId: ApplicationId,
token: Option[String] = None): LedgerClientConfiguration =
authTestConfig: Option[AuthTestConfig]): LedgerClientConfiguration =
LedgerClientConfiguration(
applicationId = ApplicationId.unwrap(applicationId),
ledgerIdRequirement = LedgerIdRequirement.none,
commandClient = CommandClientConfiguration.default,
sslContext = None,
token = token,
token = for {
cfg <- authTestConfig
header = """{"alg": "HS256", "typ": "JWT"}"""
payload = AuthServiceJWTPayload(
ledgerId = None,
applicationId = None,
participantId = None,
exp = None,
admin = true,
actAs = cfg.parties.map(ApiTypes.Party.unwrap),
readAs = List(),
)
jwt <- JwtSigner.HMAC256
.sign(DecodedJwt(header, AuthServiceJWTCodec.compactPrint(payload)), cfg.jwtSecret)
.toOption
} yield jwt.value,
)
}

View File

@ -35,15 +35,82 @@ import com.daml.ledger.api.v1.transaction_filter.{Filters, InclusiveFilters, Tra
import com.daml.ledger.client.LedgerClient
import com.daml.lf.engine.trigger.dao.DbTriggerDao
import com.daml.testing.postgresql.PostgresAroundAll
import com.typesafe.scalalogging.StrictLogging
import eu.rekawek.toxiproxy._
abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually with Matchers {
import scala.collection.concurrent.TrieMap
import scala.util.Success
/**
* A test-fixture that persists cookies between http requests for each test-case.
*/
trait HttpCookies extends BeforeAndAfterEach { this: Suite =>
private val cookieJar = TrieMap[String, String]()
override protected def afterEach(): Unit = {
try super.afterEach()
finally cookieJar.clear()
}
/**
* Adds a Cookie header for the currently stored cookies and performs the given http request.
*/
def httpRequest(request: HttpRequest)(
implicit system: ActorSystem,
ec: ExecutionContext): Future[HttpResponse] = {
Http()
.singleRequest {
if (cookieJar.nonEmpty) {
val cookies = headers.Cookie(values = cookieJar.to[Seq]: _*)
request.addHeader(cookies)
} else {
request
}
}
.andThen {
case Success(resp) =>
resp.headers.foreach {
case headers.`Set-Cookie`(cookie) =>
cookieJar.update(cookie.name, cookie.value)
case _ =>
}
}
}
/**
* Same as [[httpRequest]] but will follow redirections.
*/
def httpRequestFollow(request: HttpRequest, maxRedirections: Int = 10)(
implicit system: ActorSystem,
ec: ExecutionContext): Future[HttpResponse] = {
httpRequest(request).flatMap {
case resp @ HttpResponse(StatusCodes.Redirection(_), _, _, _) =>
if (maxRedirections == 0) {
throw new RuntimeException("Too many redirections")
} else {
val uri = resp.header[headers.Location].get.uri
httpRequestFollow(HttpRequest(uri = uri), maxRedirections - 1)
}
case resp => Future(resp)
}
}
}
abstract class AbstractTriggerServiceTest
extends AsyncFlatSpec
with HttpCookies
with Eventually
with Matchers
with StrictLogging {
import AbstractTriggerServiceTest.CompatAssertion
// Abstract member for testing with and without a database
def jdbcConfig: Option[JdbcConfig]
// Abstract member for testing with and without authentication/authorization
def authTestConfig: Option[AuthTestConfig]
// Default retry config for `eventually`
override implicit def patienceConfig: PatienceConfig =
PatienceConfig(timeout = scaled(Span(15, Seconds)), interval = scaled(Span(1, Seconds)))
@ -77,7 +144,12 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually
def withTriggerService[A](encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]])(
testFn: (Uri, LedgerClient, Proxy) => Future[A])(implicit pos: source.Position): Future[A] =
TriggerServiceFixture.withTriggerService(testId, List(darPath), encodedDar, jdbcConfig)(testFn)
TriggerServiceFixture.withTriggerService(
testId,
List(darPath),
encodedDar,
jdbcConfig,
authTestConfig)(testFn)
def startTrigger(uri: Uri, triggerName: String, party: Party): Future[HttpResponse] = {
val req = HttpRequest(
@ -88,7 +160,7 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually
s"""{"triggerName": "$triggerName", "party": "$party"}"""
)
)
Http().singleRequest(req)
httpRequestFollow(req)
}
def listTriggers(uri: Uri, party: Party): Future[HttpResponse] = {
@ -446,6 +518,7 @@ object AbstractTriggerServiceTest {
class TriggerServiceTestInMem extends AbstractTriggerServiceTest {
override def jdbcConfig: Option[JdbcConfig] = None
override def authTestConfig: Option[AuthTestConfig] = None
}
@ -456,6 +529,7 @@ class TriggerServiceTestWithDb
with PostgresAroundAll {
override def jdbcConfig: Option[JdbcConfig] = Some(jdbcConfig_)
override def authTestConfig: Option[AuthTestConfig] = None
// Lazy because the postgresDatabase is only available once the tests start
private lazy val jdbcConfig_ = JdbcConfig(postgresDatabase.url, "operator", "password")
@ -530,3 +604,16 @@ class TriggerServiceTestWithDb
} yield succeed)
}
// Tests for auth mode only go here
class TriggerServiceTestAuth extends AbstractTriggerServiceTest {
override def jdbcConfig: Option[JdbcConfig] = None
override def authTestConfig: Option[AuthTestConfig] =
Some(
AuthTestConfig(
jwtSecret = "secret",
parties = List(alice, bob),
))
}