From 60fe244e1b92c3925076a2a1452c58479e624944 Mon Sep 17 00:00:00 2001 From: Andreas Herrmann <42969706+aherrmann-da@users.noreply.github.com> Date: Fri, 16 Oct 2020 17:37:36 +0200 Subject: [PATCH] Use auth middleware in trigger service `/v1/start` endpoint (#7654) * Authorize trigger service on middleware changelog_begin changelog_end * Trigger service auth callback handler * Forward token * Do not pin the application ID in the access token The trigger service will assign an individual application ID to each trigger based on its UUID. Requiring tokens on the granularity of application IDs would break the idea of storing the token in a cookie to be able to use it across multiple requests. changelog_begin changelog_end * todo persist trigger token * Add a state parameter to middleware login * add documentation comments * typo * fmt * Align Party type between middleware and trigger service The middleware was using `com.daml.lf.data.Ref.Party` while the trigger service is using `com.daml.ledger.api.refinements.ApiTypes.Party` which requires conversions. This aligns the types to avoid such conversions. * optional application id in oauth2 test server * align party types * configure auth middleware in trigger service tests * handle empty cookie header * follow redirects in trigger service tests * keep track of cookies * keep track of cookies * Replace any previous Cookie header Otherwise on old daml-ledger-token cookie might persist and be preferred over a newly added instance. * DEBUG * Configure test ledger client readAs claims * fmt * docstrings * remove debug output * Avoid endless redirect loops When the replay still fails to authorize on the middleware then we do not want to attempt another login flow. * Store callback routes in authCallbacks * fmt * Push AuthTestConfig into test target https://github.com/digital-asset/daml/pull/7654#discussion_r506510193 * Unbind oauth2 server after middleware https://github.com/digital-asset/daml/pull/7654/files#r506513251 Co-authored-by: Andreas Herrmann --- triggers/service/BUILD.bazel | 7 + triggers/service/auth/BUILD.bazel | 2 + .../com/daml/oauth/middleware/Config.scala | 4 +- .../scala/com/daml/oauth/middleware/README.md | 1 - .../com/daml/oauth/middleware/Request.scala | 22 ++- .../com/daml/oauth/middleware/Server.scala | 8 +- .../scala/com/daml/oauth/server/Config.scala | 6 +- .../scala/com/daml/oauth/server/Server.scala | 2 +- .../com/daml/oauth/middleware/Test.scala | 14 +- .../daml/oauth/middleware/TestFixture.scala | 2 +- .../com/daml/oauth/server/TestFixture.scala | 2 +- .../daml/lf/engine/trigger/Server.scala | 140 ++++++++++++++++-- .../lf/engine/trigger/ServiceConfig.scala | 10 ++ .../daml/lf/engine/trigger/ServiceMain.scala | 7 + .../lf/engine/trigger/TriggerRunnerImpl.scala | 6 +- .../lf/engine/trigger/dao/DbTriggerDao.scala | 3 +- .../daml/lf/engine/trigger/package.scala | 8 +- .../trigger/TriggerServiceFixture.scala | 90 ++++++++++- .../engine/trigger/TriggerServiceTest.scala | 93 +++++++++++- 19 files changed, 375 insertions(+), 52 deletions(-) diff --git a/triggers/service/BUILD.bazel b/triggers/service/BUILD.bazel index ccc1023164..be7bf27966 100644 --- a/triggers/service/BUILD.bazel +++ b/triggers/service/BUILD.bazel @@ -45,6 +45,7 @@ da_scala_library( "//libs-scala/contextualized-logging", "//libs-scala/scala-utils", "//triggers/runner:trigger-runner-lib", + "//triggers/service/auth:oauth-middleware", # TODO[AH] Extract request/response types "@maven//:com_chuusai_shapeless_2_12", "@maven//:com_github_scopt_scopt_2_12", "@maven//:com_lihaoyi_sourcecode_2_12", @@ -53,6 +54,7 @@ da_scala_library( "@maven//:com_typesafe_akka_akka_http_2_12", "@maven//:com_typesafe_akka_akka_http_core_2_12", "@maven//:com_typesafe_akka_akka_http_spray_json_2_12", + "@maven//:com_typesafe_akka_akka_parsing_2_12", "@maven//:com_typesafe_akka_akka_stream_2_12", "@maven//:com_typesafe_scala_logging_scala_logging_2_12", "@maven//:com_zaxxer_HikariCP", @@ -101,7 +103,9 @@ da_scala_test( "//language-support/scala/bindings-akka", "//ledger-api/rs-grpc-bridge", "//ledger-service/cli-opts", + "//ledger-service/jwt", "//ledger/caching", + "//ledger/ledger-api-auth", "//ledger/ledger-api-common", "//ledger/participant-integration-api", "//ledger/participant-state", @@ -110,9 +114,12 @@ da_scala_test( "//libs-scala/ports", "//libs-scala/postgresql-testing", "//libs-scala/timer-utils", + "//triggers/service/auth:oauth-middleware", + "//triggers/service/auth:oauth-test-server", "@maven//:ch_qos_logback_logback_classic", "@maven//:com_typesafe_akka_akka_actor_typed_2_12", "@maven//:com_typesafe_akka_akka_http_core_2_12", + "@maven//:com_typesafe_akka_akka_parsing_2_12", "@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_3", "@maven//:io_spray_spray_json_2_12", "@maven//:org_scalatest_scalatest_2_12", diff --git a/triggers/service/auth/BUILD.bazel b/triggers/service/auth/BUILD.bazel index 5830a02010..3b6416ebb2 100644 --- a/triggers/service/auth/BUILD.bazel +++ b/triggers/service/auth/BUILD.bazel @@ -18,6 +18,7 @@ da_scala_library( deps = [ ":oauth-test-server", # TODO[AH] Extract OAuth2 request/response types "//daml-lf/data", + "//language-support/scala/bindings", "//ledger-service/jwt", "//ledger/ledger-api-auth", "//libs-scala/ports", @@ -112,6 +113,7 @@ da_scala_test( ":oauth-middleware", ":oauth-test-server", "//daml-lf/data", + "//language-support/scala/bindings", "//ledger-api/rs-grpc-bridge", "//ledger-api/testing-utils", "//ledger-service/jwt", diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Config.scala b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Config.scala index 94c007f7cd..f40a4a2f86 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Config.scala +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Config.scala @@ -6,7 +6,7 @@ package com.daml.oauth.middleware import akka.http.scaladsl.model.Uri import com.daml.ports.Port -private[middleware] case class Config( +case class Config( // Port the middleware listens on port: Port, // OAuth2 server endpoints @@ -17,7 +17,7 @@ private[middleware] case class Config( clientSecret: String, ) -private[middleware] object Config { +object Config { private val Empty = Config( port = Port.Dynamic, diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/README.md b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/README.md index 9aa54f2553..9840be1ab1 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/README.md +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/README.md @@ -51,7 +51,6 @@ repository](https://github.com/digital-asset/ex-secure-daml-infra). context.accessToken[namespace] = { // NOTE change the ledger ID to match your deployment. "ledgerId": "2D105384-CE61-4CCC-8E0E-37248BA935A3", - "applicationId": context.clientName, "actAs": actAs, "readAs": readAs, "admin": admin diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Request.scala b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Request.scala index cfae792014..6a0db76ceb 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Request.scala +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Request.scala @@ -6,7 +6,7 @@ package com.daml.oauth.middleware import akka.http.scaladsl.marshalling.Marshaller import akka.http.scaladsl.model.Uri import akka.http.scaladsl.unmarshalling.Unmarshaller -import com.daml.lf.data.Ref.Party +import com.daml.ledger.api.refinements.ApiTypes.Party import spray.json.{ DefaultJsonProtocol, JsString, @@ -48,9 +48,9 @@ object Request { if (w == "admin") { admin = true } else if (w.startsWith("actAs:")) { - actAs.append(Party.assertFromString(w.stripPrefix("actAs:"))) + actAs.append(Party(w.stripPrefix("actAs:"))) } else if (w.startsWith("readAs:")) { - readAs.append(Party.assertFromString(w.stripPrefix("readAs:"))) + readAs.append(Party(w.stripPrefix("readAs:"))) } else { throw new IllegalArgumentException(s"Expected claim but got $w") } @@ -62,14 +62,26 @@ object Request { /** Auth endpoint query parameters */ - case class Auth(claims: Claims) + case class Auth(claims: Claims) { + def toQuery: Uri.Query = Uri.Query("claims" -> claims.toQueryString()) + } /** Login endpoint query parameters * * @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service. * @param claims Required ledger claims. + * @param state State that will be forwarded to the callback URI after authentication and authorization. */ - case class Login(redirectUri: Uri, claims: Claims) + case class Login(redirectUri: Uri, claims: Claims, state: Option[String]) { + def toQuery: Uri.Query = { + var params = Seq( + "redirect_uri" -> redirectUri.toString, + "claims" -> claims.toQueryString(), + ) + state.foreach(x => params ++= Seq("state" -> x)) + Uri.Query(params: _*) + } + } } diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Server.scala b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Server.scala index 8dc2f5aab0..3fb264f5b5 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Server.scala +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/middleware/Server.scala @@ -95,11 +95,15 @@ object Server extends StrictLogging { } private def login(config: Config, requests: TrieMap[UUID, Uri]) = - parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims])) + parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims], 'state ?)) .as[Request.Login](Request.Login) { login => extractRequest { request => val requestId = UUID.randomUUID - requests += (requestId -> login.redirectUri) + requests += (requestId -> { + var query = login.redirectUri.query().to[Seq] + login.state.foreach(x => query ++= Seq("state" -> x)) + login.redirectUri.withQuery(Uri.Query(query: _*)) + }) val authorize = OAuthRequest.Authorize( responseType = "code", clientId = config.clientId, diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/server/Config.scala b/triggers/service/auth/src/main/scala/com/daml/oauth/server/Config.scala index aae0549888..9c7dabe9ca 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/server/Config.scala +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/server/Config.scala @@ -11,14 +11,14 @@ case class Config( // Ledger ID of issued tokens ledgerId: String, // Application ID of issued tokens - applicationId: String, + applicationId: Option[String], // Secret used to sign JWTs jwtSecret: String ) object Config { private val Empty = - Config(port = Port.Dynamic, ledgerId = null, applicationId = null, jwtSecret = null) + Config(port = Port.Dynamic, ledgerId = null, applicationId = None, jwtSecret = null) def parseConfig(args: Seq[String]): Option[Config] = configParser.parse(args, Empty) @@ -37,7 +37,7 @@ object Config { .action((x, c) => c.copy(ledgerId = x)) opt[String]("application-id") - .action((x, c) => c.copy(applicationId = x)) + .action((x, c) => c.copy(applicationId = Some(x))) opt[String]("secret") .action((x, c) => c.copy(jwtSecret = x)) diff --git a/triggers/service/auth/src/main/scala/com/daml/oauth/server/Server.scala b/triggers/service/auth/src/main/scala/com/daml/oauth/server/Server.scala index 37a66173ec..dde8922445 100644 --- a/triggers/service/auth/src/main/scala/com/daml/oauth/server/Server.scala +++ b/triggers/service/auth/src/main/scala/com/daml/oauth/server/Server.scala @@ -41,7 +41,7 @@ object Server { } AuthServiceJWTPayload( ledgerId = Some(config.ledgerId), - applicationId = Some(config.applicationId), + applicationId = config.applicationId, // Not required by the default auth service participantId = None, // Only for testing, expire never. diff --git a/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/Test.scala b/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/Test.scala index 0ab3e6d489..9d022732c7 100644 --- a/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/Test.scala +++ b/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/Test.scala @@ -12,8 +12,8 @@ import akka.http.scaladsl.unmarshalling.Unmarshal import com.daml.jwt.JwtSigner import com.daml.jwt.domain.DecodedJwt import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload} +import com.daml.ledger.api.refinements.ApiTypes import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll -import com.daml.lf.data.Ref.Party import com.daml.oauth.server.{Response => OAuthResponse} import org.scalatest.AsyncWordSpec @@ -33,8 +33,8 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr participantId = None, exp = None, admin = claims.admin, - actAs = claims.actAs, - readAs = claims.readAs + actAs = claims.actAs.map(ApiTypes.Party.unwrap(_)), + readAs = claims.readAs.map(ApiTypes.Party.unwrap(_)) ) OAuthResponse.Token( accessToken = JwtSigner.HMAC256 @@ -63,7 +63,7 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr } } "return the token from a cookie" in { - val claims = Request.Claims(actAs = List(Party.assertFromString("Alice"))) + val claims = Request.Claims(actAs = List(ApiTypes.Party("Alice"))) val token = makeToken(claims) val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue) val req = HttpRequest( @@ -83,13 +83,13 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr } } "return unauthorized on insufficient claims" in { - val token = makeToken(Request.Claims(actAs = List(Party.assertFromString("Alice")))) + val token = makeToken(Request.Claims(actAs = List(ApiTypes.Party("Alice")))) val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue) val req = HttpRequest( uri = middlewareUri .withPath(Path./("auth")) - .withQuery(Query( - ("claims", Request.Claims(actAs = List(Party.assertFromString("Bob"))).toQueryString))), + .withQuery( + Query(("claims", Request.Claims(actAs = List(ApiTypes.Party("Bob"))).toQueryString))), headers = List(cookieHeader) ) for { diff --git a/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/TestFixture.scala b/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/TestFixture.scala index 073157c2ac..b0d06a8edf 100644 --- a/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/TestFixture.scala +++ b/triggers/service/auth/src/test/scala/com/daml/oauth/middleware/TestFixture.scala @@ -30,7 +30,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin OAuthConfig( port = Port.Dynamic, ledgerId = ledgerId, - applicationId = applicationId, + applicationId = Some(applicationId), jwtSecret = jwtSecret)) serverUri = Uri() .withScheme("http") diff --git a/triggers/service/auth/src/test/scala/com/daml/oauth/server/TestFixture.scala b/triggers/service/auth/src/test/scala/com/daml/oauth/server/TestFixture.scala index c0790b58df..826ee99111 100644 --- a/triggers/service/auth/src/test/scala/com/daml/oauth/server/TestFixture.scala +++ b/triggers/service/auth/src/test/scala/com/daml/oauth/server/TestFixture.scala @@ -29,7 +29,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin Config( port = Port.Dynamic, ledgerId = ledgerId, - applicationId = applicationId, + applicationId = Some(applicationId), jwtSecret = jwtSecret)) client <- Resources.authClient( Client.Config( diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Server.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Server.scala index 9d2a38af1e..18dbf624d8 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Server.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/Server.scala @@ -4,19 +4,28 @@ package com.daml.lf package engine.trigger +import akka.actor.ActorSystem import akka.actor.typed.{ActorRef, Behavior} import akka.actor.typed.scaladsl.{ActorContext, Behaviors} import akka.actor.typed.scaladsl.adapter._ import akka.http.scaladsl.Http import akka.http.scaladsl.Http.ServerBinding import akka.http.scaladsl.model._ +import akka.http.scaladsl.model.Uri.Path import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._ +import akka.http.scaladsl.server.{Directive, Directive1} import akka.http.scaladsl.server.Directives._ import akka.http.scaladsl.server.Route +import akka.http.scaladsl.unmarshalling.Unmarshal import akka.stream.Materializer import akka.util.ByteString import spray.json.DefaultJsonProtocol._ import spray.json._ +import com.daml.oauth.middleware.{ + JsonProtocol => AuthJsonProtocol, + Request => AuthRequest, + Response => AuthResponse +} import com.daml.ledger.api.refinements.ApiTypes.Party import com.daml.lf.archive.{Dar, DarReader, Decode} import com.daml.lf.archive.Reader.ParseError @@ -43,7 +52,10 @@ import java.util.UUID import java.util.zip.ZipInputStream import java.time.LocalDateTime +import scala.collection.concurrent.TrieMap + class Server( + authConfig: AuthConfig, ledgerConfig: LedgerConfig, restartConfig: TriggerRestartConfig, triggerDao: RunningTriggerDao)( @@ -96,7 +108,8 @@ class Server( private def restartTriggers(triggers: Vector[RunningTrigger]): Either[String, Unit] = { import cats.implicits._ // needed for traverse - triggers.traverse_(t => startTrigger(t.triggerParty, t.triggerName, Some(t.triggerInstance))) + triggers.traverse_(t => + startTrigger(t.triggerParty, t.triggerName, t.triggerToken, Some(t.triggerInstance))) } private def triggerRunnerName(triggerInstance: UUID): String = @@ -110,6 +123,7 @@ class Server( private def startTrigger( party: Party, triggerName: Identifier, + token: Option[String], existingInstance: Option[UUID] = None): Either[String, JsValue] = { for { trigger <- Trigger.fromIdentifier(compiledPackages, triggerName) @@ -117,7 +131,7 @@ class Server( case None => val newInstance = UUID.randomUUID triggerDao - .addRunningTrigger(RunningTrigger(newInstance, triggerName, party)) + .addRunningTrigger(RunningTrigger(newInstance, triggerName, party, token)) .map(_ => newInstance) case Some(instance) => Right(instance) } @@ -127,6 +141,7 @@ class Server( ctx.self, triggerInstance, party, + token, compiledPackages, trigger, ledgerConfig, @@ -171,20 +186,112 @@ class Server( private def getTriggerStatus(uuid: UUID): Vector[(LocalDateTime, String)] = triggerLog.getOrDefault(uuid, Vector.empty) - private val route = concat( + // TODO[AH] Make sure this is bounded in size. + private val authCallbacks: TrieMap[UUID, Route] = TrieMap() + + // This directive requires authorization for the given claims via the auth middleware, if configured. + // If no auth middleware is configured, then the request will proceed without attempting authorization. + // + // Authorization follows the steps defined in `triggers/service/authentication.md`. + // First asking for a token on the `/auth` endpoint and redirecting to `/login` if none was returned. + // If a login is required then this will store the current continuation in `authCallbacks` + // to proceed once the login flow completed and authentication succeeded. + private def authorize(claims: AuthRequest.Claims)( + implicit ec: ExecutionContext, + system: ActorSystem): Directive1[Option[String]] = + authConfig match { + case NoAuth => provide(None) + case AuthMiddleware(authUri) => + // Attempt to obtain the access token from the middleware's /auth endpoint. + // Forwards the current request's cookies. + // Fails if the response is not OK or Unauthorized. + def auth: Directive1[Option[String]] = { + val uri = authUri + .withPath(Path./("auth")) + .withQuery(AuthRequest.Auth(claims).toQuery) + import AuthJsonProtocol._ + extract(_.request.headers[headers.Cookie]).flatMap { cookies => + onSuccess(Http().singleRequest(HttpRequest(uri = uri, headers = cookies))).flatMap { + case HttpResponse(StatusCodes.OK, _, entity, _) => + onSuccess(Unmarshal(entity).to[AuthResponse.Authorize]).map { auth => + Some(auth.accessToken): Option[String] + } + case HttpResponse(StatusCodes.Unauthorized, _, _, _) => + provide(None) + case resp @ HttpResponse(code, _, _, _) => + onSuccess(Unmarshal(resp).to[String]).flatMap { msg => + logger.error(s"Failed to authorize with middleware ($code): $msg") + complete( + errorResponse( + StatusCodes.InternalServerError, + "Failed to authorize with middleware")) + } + } + } + } + Directive { inner => + auth { + // Authorization successful - pass token to continuation + case Some(token) => inner(Tuple1(Some(token))) + // Authorization failed - login and retry on callback request. + case None => { ctx => + val requestId = UUID.randomUUID() + authCallbacks.update( + requestId, { + auth { + case None => { + // Authorization failed after login - respond with 401 + // TODO[AH] Add WWW-Authenticate header + complete(errorResponse(StatusCodes.Unauthorized)) + } + case Some(token) => + // Authorization successful after login - use old request context and pass token to continuation. + mapRequestContext(_ => ctx) { + inner(Tuple1(Some(token))) + } + } + } + ) + // TODO[AH] Make the redirect URI configurable, especially the authority. E.g. when running behind nginx. + val callbackUri = Uri() + .withScheme(ctx.request.uri.scheme) + .withAuthority(ctx.request.uri.authority) + .withPath(Path./("cb")) + val uri = authUri + .withPath(Path./("login")) + .withQuery(AuthRequest.Login(callbackUri, claims, Some(requestId.toString)).toQuery) + ctx.redirect(uri, StatusCodes.Found) + } + } + } + } + + private def authCallback(requestId: UUID): Route = + authCallbacks.remove(requestId) match { + case None => complete(StatusCodes.NotFound) + case Some(callback) => callback + } + + private def route(implicit ec: ExecutionContext, system: ActorSystem) = concat( post { concat( // Start a new trigger given its identifier and the party it // should be running as. Returns a UUID for the newly // started trigger. path("v1" / "start") { - entity(as[StartParams]) { params => - startTrigger(params.party, params.triggerName) match { - case Left(err) => - complete(errorResponse(StatusCodes.UnprocessableEntity, err)) - case Right(triggerInstance) => - complete(successResponse(triggerInstance)) - } + entity(as[StartParams]) { + params => + val claims = + AuthRequest.Claims(actAs = List(params.party)) + // TODO[AH] Why do we need to pass ec, system explicitly? + authorize(claims)(ec, system) { token => + startTrigger(params.party, params.triggerName, token) match { + case Left(err) => + complete(errorResponse(StatusCodes.UnprocessableEntity, err)) + case Right(triggerInstance) => + complete(successResponse(triggerInstance)) + } + } } }, // upload a DAR as a multi-part form request with a single field called @@ -257,6 +364,14 @@ class Server( } } }, + // Authorization callback endpoint + path("cb") { + get { + parameters('state.as[UUID]) { requestId => + authCallback(requestId) + } + } + }, ) } @@ -265,6 +380,7 @@ object Server { def apply( host: String, port: Int, + authConfig: AuthConfig, ledgerConfig: LedgerConfig, restartConfig: TriggerRestartConfig, initialDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]], @@ -299,11 +415,11 @@ object Server { val (dao, server): (RunningTriggerDao, Server) = jdbcConfig match { case None => val dao = InMemoryTriggerDao() - val server = new Server(ledgerConfig, restartConfig, dao) + val server = new Server(authConfig, ledgerConfig, restartConfig, dao) (dao, server) case Some(c) => val dao = DbTriggerDao(c) - val server = new Server(ledgerConfig, restartConfig, dao) + val server = new Server(authConfig, ledgerConfig, restartConfig, dao) val recovery: Either[String, Unit] = for { _ <- dao.initialize packages <- dao.readPackages diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceConfig.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceConfig.scala index 5111008980..9228add511 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceConfig.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceConfig.scala @@ -6,6 +6,7 @@ package com.daml.lf.engine.trigger import java.nio.file.{Path, Paths} import java.time.Duration +import akka.http.scaladsl.model.Uri import com.daml.cliopts import com.daml.platform.services.time.TimeProviderType import scalaz.Show @@ -21,6 +22,7 @@ private[trigger] final case class ServiceConfig( httpPort: Int, ledgerHost: String, ledgerPort: Int, + authUri: Option[Uri], maxInboundMessageSize: Int, minRestartInterval: FiniteDuration, maxRestartInterval: FiniteDuration, @@ -105,6 +107,13 @@ private[trigger] object ServiceConfig { .action((t, c) => c.copy(ledgerPort = t)) .text("Ledger port.") + opt[String]("auth") + .optional() + .action((t, c) => c.copy(authUri = Some(Uri(t)))) + .text("Auth middleware URI.") + // TODO[AH] Expose once the feature is fully implemented. + .hidden() + opt[Int]("max-inbound-message-size") .action((x, c) => c.copy(maxInboundMessageSize = x)) .optional() @@ -157,6 +166,7 @@ private[trigger] object ServiceConfig { httpPort = DefaultHttpPort, ledgerHost = null, ledgerPort = 0, + authUri = None, maxInboundMessageSize = DefaultMaxInboundMessageSize, minRestartInterval = DefaultMinRestartInterval, maxRestartInterval = DefaultMaxRestartInterval, diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala index 2508b7b172..188318e21f 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/ServiceMain.scala @@ -27,6 +27,7 @@ object ServiceMain { def startServer( host: String, port: Int, + authConfig: AuthConfig, ledgerConfig: LedgerConfig, restartConfig: TriggerRestartConfig, encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]], @@ -38,6 +39,7 @@ object ServiceMain { Server( host, port, + authConfig, ledgerConfig, restartConfig, encodedDar, @@ -66,6 +68,10 @@ object ServiceMain { case Success(dar) => dar } } + val authConfig: AuthConfig = config.authUri match { + case None => NoAuth + case Some(uri) => AuthMiddleware(uri) + } val ledgerConfig = LedgerConfig( config.ledgerHost, @@ -83,6 +89,7 @@ object ServiceMain { Server( config.address, config.httpPort, + authConfig, ledgerConfig, restartConfig, encodedDar, diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/TriggerRunnerImpl.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/TriggerRunnerImpl.scala index 6e379a0bbc..dfa1c48dbd 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/TriggerRunnerImpl.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/TriggerRunnerImpl.scala @@ -31,7 +31,7 @@ object TriggerRunnerImpl { server: ActorRef[Message], triggerInstance: UUID, party: Party, - // TODO(SF, 2020-06-09): Add access token field here in the presence of authentication. + token: Option[String], compiledPackages: CompiledPackages, trigger: Trigger, ledgerConfig: LedgerConfig, @@ -67,9 +67,7 @@ object TriggerRunnerImpl { commandClient = CommandClientConfiguration.default.copy( defaultDeduplicationTime = config.ledgerConfig.commandTtl), sslContext = None, - // TODO(SF, 2020-06-09): In the presence of an authorization - // service, get an access token and pass it through here! - token = None, + token = config.token, maxInboundMessageSize = config.ledgerConfig.maxInboundMessageSize, ) diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/dao/DbTriggerDao.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/dao/DbTriggerDao.scala index 9a1c7205b7..eec3fcc5d4 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/dao/DbTriggerDao.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/dao/DbTriggerDao.scala @@ -149,7 +149,8 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con triggerInstance: UUID, party: String, fullTriggerName: String): Either[String, RunningTrigger] = { - Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party))) + // TODO[AH] Persist the access and refresh token. + Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party), None)) } // Drop all tables and other objects associated with the database. diff --git a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/package.scala b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/package.scala index 1a857d62ba..7702e31af7 100644 --- a/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/package.scala +++ b/triggers/service/src/main/scala/com/digitalasset/daml/lf/engine/trigger/package.scala @@ -10,10 +10,15 @@ import com.daml.ledger.api.refinements.ApiTypes.Party import com.daml.lf.data.Ref.Identifier import com.daml.platform.services.time.TimeProviderType +import akka.http.scaladsl.model.Uri import scala.concurrent.duration.FiniteDuration package trigger { + sealed trait AuthConfig + case object NoAuth extends AuthConfig + final case class AuthMiddleware(uri: Uri) extends AuthConfig + case class LedgerConfig( host: String, port: Int, @@ -32,7 +37,6 @@ package trigger { triggerInstance: UUID, triggerName: Identifier, triggerParty: Party, - // TODO(SF, 2020-0610): Add access token field here in the - // presence of authentication. + triggerToken: Option[String], ) } diff --git a/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceFixture.scala b/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceFixture.scala index e9f23f0668..68421e5af7 100644 --- a/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceFixture.scala +++ b/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceFixture.scala @@ -7,14 +7,21 @@ import java.io.File import java.net.InetAddress import java.time.Duration +import akka.actor.ActorSystem import akka.actor.typed.{ActorSystem => TypedActorSystem} import akka.http.scaladsl.Http.ServerBinding import akka.http.scaladsl.model.Uri +import akka.http.scaladsl.model.Uri.Path import akka.stream.Materializer import com.daml.bazeltools.BazelRunfiles import com.daml.daml_lf_dev.DamlLf import com.daml.grpc.adapter.ExecutionSequencerFactory +import com.daml.jwt.domain.DecodedJwt +import com.daml.jwt.{HMAC256Verifier, JwtSigner} +import com.daml.ledger.api.auth +import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload} import com.daml.ledger.api.domain.LedgerId +import com.daml.ledger.api.refinements.ApiTypes import com.daml.ledger.api.refinements.ApiTypes.ApplicationId import com.daml.ledger.client.LedgerClient import com.daml.ledger.client.configuration.{ @@ -24,6 +31,8 @@ import com.daml.ledger.client.configuration.{ } import com.daml.lf.archive.Dar import com.daml.lf.data.Ref._ +import com.daml.oauth.middleware.{Config => MiddlewareConfig, Server => MiddlewareServer} +import com.daml.oauth.server.{Config => OAuthConfig, Server => OAuthServer} import com.daml.platform.common.LedgerIdMode import com.daml.platform.sandbox import com.daml.platform.sandbox.SandboxServer @@ -40,6 +49,13 @@ import scala.concurrent._ import scala.concurrent.duration._ import scala.sys.process.Process +private[trigger] final case class AuthTestConfig( + // HMAC256 signature secret. + jwtSecret: String, + // Grant readAs claims for these parties to the ledger client provided to test cases. + parties: List[ApiTypes.Party], +) + object TriggerServiceFixture extends StrictLogging { // Use a small initial interval so we can test restart behaviour more easily. @@ -50,10 +66,12 @@ object TriggerServiceFixture extends StrictLogging { dars: List[File], encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]], jdbcConfig: Option[JdbcConfig], + authTestConfig: Option[AuthTestConfig], )(testFn: (Uri, LedgerClient, Proxy) => Future[A])( implicit mat: Materializer, aesf: ExecutionSequencerFactory, ec: ExecutionContext, + system: ActorSystem, pos: source.Position, ): Future[A] = { logger.info(s"${pos.fileName}:${pos.lineNumber}: setting up trigger service") @@ -77,9 +95,43 @@ object TriggerServiceFixture extends StrictLogging { val ledgerId = LedgerId(testName) val applicationId = ApplicationId(testName) + val authF: Future[(AuthConfig, () => Future[Unit])] = authTestConfig match { + case None => Future((NoAuth, () => Future(()))) + case Some(AuthTestConfig(secret, _)) => + for { + oauth <- OAuthServer.start( + OAuthConfig( + port = Port.Dynamic, + ledgerId = LedgerId.unwrap(ledgerId), + // TODO[AH] Choose application ID, see https://github.com/digital-asset/daml/issues/7671 + applicationId = None, + jwtSecret = secret, + )) + serverUri = Uri() + .withScheme("http") + .withAuthority(oauth.localAddress.getHostString, oauth.localAddress.getPort) + middleware <- MiddlewareServer.start( + MiddlewareConfig( + port = Port.Dynamic, + oauthAuth = serverUri.withPath(Path./("authorize")), + oauthToken = serverUri.withPath(Path./("token")), + clientId = "oauth-middleware-id", + clientSecret = "oauth-middleware-secret", + )) + middlewareUri = Uri() + .withScheme("http") + .withAuthority(middleware.localAddress.getHostString, middleware.localAddress.getPort) + cleanup = () => + for { + _ <- middleware.unbind() + _ <- oauth.unbind() + } yield () + } yield (AuthMiddleware(middlewareUri), cleanup) + } val ledgerF = for { (_, toxiproxyClient) <- toxiproxyF - ledger <- Future(new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId), mat)) + ledger <- Future( + new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId, authTestConfig), mat)) ledgerPort <- ledger.portF ledgerProxyPort = LockedFreePort.find() ledgerProxy = toxiproxyClient.createProxy( @@ -97,12 +149,13 @@ object TriggerServiceFixture extends StrictLogging { client <- LedgerClient.singleHost( host.getHostName, ledgerPort.value, - clientConfig(applicationId), + clientConfig(applicationId, authTestConfig), ) } yield client // Configure the service with the ledger's *proxy* port. val serviceF: Future[(ServerBinding, TypedActorSystem[Message])] = for { + (authConfig, _) <- authF (_, _, ledgerProxyPort, _) <- ledgerF ledgerConfig = LedgerConfig( host.getHostName, @@ -119,6 +172,7 @@ object TriggerServiceFixture extends StrictLogging { service <- ServiceMain.startServer( host.getHostName, servicePort.port.value, + authConfig, ledgerConfig, restartConfig, encodedDar, @@ -152,7 +206,10 @@ object TriggerServiceFixture extends StrictLogging { proc.destroy() proc.exitValue() // destroy is async }) - result <- (ta.failed.toOption orElse se orElse le orElse te) + ae <- optErr(authF.flatMap { + case (_, cleanup) => cleanup() + }) + result <- (ta.failed.toOption orElse se orElse le orElse te orElse ae) .cata(Future.failed, Future fromTry ta) } yield result } @@ -164,24 +221,43 @@ object TriggerServiceFixture extends StrictLogging { private def ledgerConfig( ledgerPort: Port, dars: List[File], - ledgerId: LedgerId + ledgerId: LedgerId, + authTestConfig: Option[AuthTestConfig] ): SandboxConfig = sandbox.DefaultConfig.copy( port = ledgerPort, damlPackages = dars, timeProviderType = Some(TimeProviderType.Static), ledgerIdMode = LedgerIdMode.Static(ledgerId), - authService = None, + authService = for { + cfg <- authTestConfig + verifier <- HMAC256Verifier(cfg.jwtSecret).toOption + } yield auth.AuthServiceJWT(verifier) ) private def clientConfig[A]( applicationId: ApplicationId, - token: Option[String] = None): LedgerClientConfiguration = + authTestConfig: Option[AuthTestConfig]): LedgerClientConfiguration = LedgerClientConfiguration( applicationId = ApplicationId.unwrap(applicationId), ledgerIdRequirement = LedgerIdRequirement.none, commandClient = CommandClientConfiguration.default, sslContext = None, - token = token, + token = for { + cfg <- authTestConfig + header = """{"alg": "HS256", "typ": "JWT"}""" + payload = AuthServiceJWTPayload( + ledgerId = None, + applicationId = None, + participantId = None, + exp = None, + admin = true, + actAs = cfg.parties.map(ApiTypes.Party.unwrap), + readAs = List(), + ) + jwt <- JwtSigner.HMAC256 + .sign(DecodedJwt(header, AuthServiceJWTCodec.compactPrint(payload)), cfg.jwtSecret) + .toOption + } yield jwt.value, ) } diff --git a/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceTest.scala b/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceTest.scala index 4d17d067d2..23e0593f2c 100644 --- a/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceTest.scala +++ b/triggers/service/src/test/scala/com/digitalasset/daml/lf/engine/trigger/TriggerServiceTest.scala @@ -35,15 +35,82 @@ import com.daml.ledger.api.v1.transaction_filter.{Filters, InclusiveFilters, Tra import com.daml.ledger.client.LedgerClient import com.daml.lf.engine.trigger.dao.DbTriggerDao import com.daml.testing.postgresql.PostgresAroundAll +import com.typesafe.scalalogging.StrictLogging import eu.rekawek.toxiproxy._ -abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually with Matchers { +import scala.collection.concurrent.TrieMap +import scala.util.Success + +/** + * A test-fixture that persists cookies between http requests for each test-case. + */ +trait HttpCookies extends BeforeAndAfterEach { this: Suite => + private val cookieJar = TrieMap[String, String]() + + override protected def afterEach(): Unit = { + try super.afterEach() + finally cookieJar.clear() + } + + /** + * Adds a Cookie header for the currently stored cookies and performs the given http request. + */ + def httpRequest(request: HttpRequest)( + implicit system: ActorSystem, + ec: ExecutionContext): Future[HttpResponse] = { + Http() + .singleRequest { + if (cookieJar.nonEmpty) { + val cookies = headers.Cookie(values = cookieJar.to[Seq]: _*) + request.addHeader(cookies) + } else { + request + } + } + .andThen { + case Success(resp) => + resp.headers.foreach { + case headers.`Set-Cookie`(cookie) => + cookieJar.update(cookie.name, cookie.value) + case _ => + } + } + } + + /** + * Same as [[httpRequest]] but will follow redirections. + */ + def httpRequestFollow(request: HttpRequest, maxRedirections: Int = 10)( + implicit system: ActorSystem, + ec: ExecutionContext): Future[HttpResponse] = { + httpRequest(request).flatMap { + case resp @ HttpResponse(StatusCodes.Redirection(_), _, _, _) => + if (maxRedirections == 0) { + throw new RuntimeException("Too many redirections") + } else { + val uri = resp.header[headers.Location].get.uri + httpRequestFollow(HttpRequest(uri = uri), maxRedirections - 1) + } + case resp => Future(resp) + } + } +} + +abstract class AbstractTriggerServiceTest + extends AsyncFlatSpec + with HttpCookies + with Eventually + with Matchers + with StrictLogging { import AbstractTriggerServiceTest.CompatAssertion // Abstract member for testing with and without a database def jdbcConfig: Option[JdbcConfig] + // Abstract member for testing with and without authentication/authorization + def authTestConfig: Option[AuthTestConfig] + // Default retry config for `eventually` override implicit def patienceConfig: PatienceConfig = PatienceConfig(timeout = scaled(Span(15, Seconds)), interval = scaled(Span(1, Seconds))) @@ -77,7 +144,12 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually def withTriggerService[A](encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]])( testFn: (Uri, LedgerClient, Proxy) => Future[A])(implicit pos: source.Position): Future[A] = - TriggerServiceFixture.withTriggerService(testId, List(darPath), encodedDar, jdbcConfig)(testFn) + TriggerServiceFixture.withTriggerService( + testId, + List(darPath), + encodedDar, + jdbcConfig, + authTestConfig)(testFn) def startTrigger(uri: Uri, triggerName: String, party: Party): Future[HttpResponse] = { val req = HttpRequest( @@ -88,7 +160,7 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually s"""{"triggerName": "$triggerName", "party": "$party"}""" ) ) - Http().singleRequest(req) + httpRequestFollow(req) } def listTriggers(uri: Uri, party: Party): Future[HttpResponse] = { @@ -446,6 +518,7 @@ object AbstractTriggerServiceTest { class TriggerServiceTestInMem extends AbstractTriggerServiceTest { override def jdbcConfig: Option[JdbcConfig] = None + override def authTestConfig: Option[AuthTestConfig] = None } @@ -456,6 +529,7 @@ class TriggerServiceTestWithDb with PostgresAroundAll { override def jdbcConfig: Option[JdbcConfig] = Some(jdbcConfig_) + override def authTestConfig: Option[AuthTestConfig] = None // Lazy because the postgresDatabase is only available once the tests start private lazy val jdbcConfig_ = JdbcConfig(postgresDatabase.url, "operator", "password") @@ -530,3 +604,16 @@ class TriggerServiceTestWithDb } yield succeed) } + +// Tests for auth mode only go here +class TriggerServiceTestAuth extends AbstractTriggerServiceTest { + + override def jdbcConfig: Option[JdbcConfig] = None + override def authTestConfig: Option[AuthTestConfig] = + Some( + AuthTestConfig( + jwtSecret = "secret", + parties = List(alice, bob), + )) + +}