mirror of
https://github.com/digital-asset/daml.git
synced 2024-09-20 01:07:18 +03:00
Use auth middleware in trigger service /v1/start
endpoint (#7654)
* Authorize trigger service on middleware changelog_begin changelog_end * Trigger service auth callback handler * Forward token * Do not pin the application ID in the access token The trigger service will assign an individual application ID to each trigger based on its UUID. Requiring tokens on the granularity of application IDs would break the idea of storing the token in a cookie to be able to use it across multiple requests. changelog_begin changelog_end * todo persist trigger token * Add a state parameter to middleware login * add documentation comments * typo * fmt * Align Party type between middleware and trigger service The middleware was using `com.daml.lf.data.Ref.Party` while the trigger service is using `com.daml.ledger.api.refinements.ApiTypes.Party` which requires conversions. This aligns the types to avoid such conversions. * optional application id in oauth2 test server * align party types * configure auth middleware in trigger service tests * handle empty cookie header * follow redirects in trigger service tests * keep track of cookies * keep track of cookies * Replace any previous Cookie header Otherwise on old daml-ledger-token cookie might persist and be preferred over a newly added instance. * DEBUG * Configure test ledger client readAs claims * fmt * docstrings * remove debug output * Avoid endless redirect loops When the replay still fails to authorize on the middleware then we do not want to attempt another login flow. * Store callback routes in authCallbacks * fmt * Push AuthTestConfig into test target https://github.com/digital-asset/daml/pull/7654#discussion_r506510193 * Unbind oauth2 server after middleware https://github.com/digital-asset/daml/pull/7654/files#r506513251 Co-authored-by: Andreas Herrmann <andreas.herrmann@tweag.io>
This commit is contained in:
parent
f025dc3065
commit
60fe244e1b
@ -45,6 +45,7 @@ da_scala_library(
|
||||
"//libs-scala/contextualized-logging",
|
||||
"//libs-scala/scala-utils",
|
||||
"//triggers/runner:trigger-runner-lib",
|
||||
"//triggers/service/auth:oauth-middleware", # TODO[AH] Extract request/response types
|
||||
"@maven//:com_chuusai_shapeless_2_12",
|
||||
"@maven//:com_github_scopt_scopt_2_12",
|
||||
"@maven//:com_lihaoyi_sourcecode_2_12",
|
||||
@ -53,6 +54,7 @@ da_scala_library(
|
||||
"@maven//:com_typesafe_akka_akka_http_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_core_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_spray_json_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_parsing_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_stream_2_12",
|
||||
"@maven//:com_typesafe_scala_logging_scala_logging_2_12",
|
||||
"@maven//:com_zaxxer_HikariCP",
|
||||
@ -101,7 +103,9 @@ da_scala_test(
|
||||
"//language-support/scala/bindings-akka",
|
||||
"//ledger-api/rs-grpc-bridge",
|
||||
"//ledger-service/cli-opts",
|
||||
"//ledger-service/jwt",
|
||||
"//ledger/caching",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//ledger/ledger-api-common",
|
||||
"//ledger/participant-integration-api",
|
||||
"//ledger/participant-state",
|
||||
@ -110,9 +114,12 @@ da_scala_test(
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/postgresql-testing",
|
||||
"//libs-scala/timer-utils",
|
||||
"//triggers/service/auth:oauth-middleware",
|
||||
"//triggers/service/auth:oauth-test-server",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:com_typesafe_akka_akka_actor_typed_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_core_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_parsing_2_12",
|
||||
"@maven//:eu_rekawek_toxiproxy_toxiproxy_java_2_1_3",
|
||||
"@maven//:io_spray_spray_json_2_12",
|
||||
"@maven//:org_scalatest_scalatest_2_12",
|
||||
|
@ -18,6 +18,7 @@ da_scala_library(
|
||||
deps = [
|
||||
":oauth-test-server", # TODO[AH] Extract OAuth2 request/response types
|
||||
"//daml-lf/data",
|
||||
"//language-support/scala/bindings",
|
||||
"//ledger-service/jwt",
|
||||
"//ledger/ledger-api-auth",
|
||||
"//libs-scala/ports",
|
||||
@ -112,6 +113,7 @@ da_scala_test(
|
||||
":oauth-middleware",
|
||||
":oauth-test-server",
|
||||
"//daml-lf/data",
|
||||
"//language-support/scala/bindings",
|
||||
"//ledger-api/rs-grpc-bridge",
|
||||
"//ledger-api/testing-utils",
|
||||
"//ledger-service/jwt",
|
||||
|
@ -6,7 +6,7 @@ package com.daml.oauth.middleware
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import com.daml.ports.Port
|
||||
|
||||
private[middleware] case class Config(
|
||||
case class Config(
|
||||
// Port the middleware listens on
|
||||
port: Port,
|
||||
// OAuth2 server endpoints
|
||||
@ -17,7 +17,7 @@ private[middleware] case class Config(
|
||||
clientSecret: String,
|
||||
)
|
||||
|
||||
private[middleware] object Config {
|
||||
object Config {
|
||||
private val Empty =
|
||||
Config(
|
||||
port = Port.Dynamic,
|
||||
|
@ -51,7 +51,6 @@ repository](https://github.com/digital-asset/ex-secure-daml-infra).
|
||||
context.accessToken[namespace] = {
|
||||
// NOTE change the ledger ID to match your deployment.
|
||||
"ledgerId": "2D105384-CE61-4CCC-8E0E-37248BA935A3",
|
||||
"applicationId": context.clientName,
|
||||
"actAs": actAs,
|
||||
"readAs": readAs,
|
||||
"admin": admin
|
||||
|
@ -6,7 +6,7 @@ package com.daml.oauth.middleware
|
||||
import akka.http.scaladsl.marshalling.Marshaller
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import akka.http.scaladsl.unmarshalling.Unmarshaller
|
||||
import com.daml.lf.data.Ref.Party
|
||||
import com.daml.ledger.api.refinements.ApiTypes.Party
|
||||
import spray.json.{
|
||||
DefaultJsonProtocol,
|
||||
JsString,
|
||||
@ -48,9 +48,9 @@ object Request {
|
||||
if (w == "admin") {
|
||||
admin = true
|
||||
} else if (w.startsWith("actAs:")) {
|
||||
actAs.append(Party.assertFromString(w.stripPrefix("actAs:")))
|
||||
actAs.append(Party(w.stripPrefix("actAs:")))
|
||||
} else if (w.startsWith("readAs:")) {
|
||||
readAs.append(Party.assertFromString(w.stripPrefix("readAs:")))
|
||||
readAs.append(Party(w.stripPrefix("readAs:")))
|
||||
} else {
|
||||
throw new IllegalArgumentException(s"Expected claim but got $w")
|
||||
}
|
||||
@ -62,14 +62,26 @@ object Request {
|
||||
|
||||
/** Auth endpoint query parameters
|
||||
*/
|
||||
case class Auth(claims: Claims)
|
||||
case class Auth(claims: Claims) {
|
||||
def toQuery: Uri.Query = Uri.Query("claims" -> claims.toQueryString())
|
||||
}
|
||||
|
||||
/** Login endpoint query parameters
|
||||
*
|
||||
* @param redirectUri Redirect target after the login flow completed. I.e. the original request URI on the trigger service.
|
||||
* @param claims Required ledger claims.
|
||||
* @param state State that will be forwarded to the callback URI after authentication and authorization.
|
||||
*/
|
||||
case class Login(redirectUri: Uri, claims: Claims)
|
||||
case class Login(redirectUri: Uri, claims: Claims, state: Option[String]) {
|
||||
def toQuery: Uri.Query = {
|
||||
var params = Seq(
|
||||
"redirect_uri" -> redirectUri.toString,
|
||||
"claims" -> claims.toQueryString(),
|
||||
)
|
||||
state.foreach(x => params ++= Seq("state" -> x))
|
||||
Uri.Query(params: _*)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
|
@ -95,11 +95,15 @@ object Server extends StrictLogging {
|
||||
}
|
||||
|
||||
private def login(config: Config, requests: TrieMap[UUID, Uri]) =
|
||||
parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims]))
|
||||
parameters(('redirect_uri.as[Uri], 'claims.as[Request.Claims], 'state ?))
|
||||
.as[Request.Login](Request.Login) { login =>
|
||||
extractRequest { request =>
|
||||
val requestId = UUID.randomUUID
|
||||
requests += (requestId -> login.redirectUri)
|
||||
requests += (requestId -> {
|
||||
var query = login.redirectUri.query().to[Seq]
|
||||
login.state.foreach(x => query ++= Seq("state" -> x))
|
||||
login.redirectUri.withQuery(Uri.Query(query: _*))
|
||||
})
|
||||
val authorize = OAuthRequest.Authorize(
|
||||
responseType = "code",
|
||||
clientId = config.clientId,
|
||||
|
@ -11,14 +11,14 @@ case class Config(
|
||||
// Ledger ID of issued tokens
|
||||
ledgerId: String,
|
||||
// Application ID of issued tokens
|
||||
applicationId: String,
|
||||
applicationId: Option[String],
|
||||
// Secret used to sign JWTs
|
||||
jwtSecret: String
|
||||
)
|
||||
|
||||
object Config {
|
||||
private val Empty =
|
||||
Config(port = Port.Dynamic, ledgerId = null, applicationId = null, jwtSecret = null)
|
||||
Config(port = Port.Dynamic, ledgerId = null, applicationId = None, jwtSecret = null)
|
||||
|
||||
def parseConfig(args: Seq[String]): Option[Config] =
|
||||
configParser.parse(args, Empty)
|
||||
@ -37,7 +37,7 @@ object Config {
|
||||
.action((x, c) => c.copy(ledgerId = x))
|
||||
|
||||
opt[String]("application-id")
|
||||
.action((x, c) => c.copy(applicationId = x))
|
||||
.action((x, c) => c.copy(applicationId = Some(x)))
|
||||
|
||||
opt[String]("secret")
|
||||
.action((x, c) => c.copy(jwtSecret = x))
|
||||
|
@ -41,7 +41,7 @@ object Server {
|
||||
}
|
||||
AuthServiceJWTPayload(
|
||||
ledgerId = Some(config.ledgerId),
|
||||
applicationId = Some(config.applicationId),
|
||||
applicationId = config.applicationId,
|
||||
// Not required by the default auth service
|
||||
participantId = None,
|
||||
// Only for testing, expire never.
|
||||
|
@ -12,8 +12,8 @@ import akka.http.scaladsl.unmarshalling.Unmarshal
|
||||
import com.daml.jwt.JwtSigner
|
||||
import com.daml.jwt.domain.DecodedJwt
|
||||
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
|
||||
import com.daml.ledger.api.refinements.ApiTypes
|
||||
import com.daml.ledger.api.testing.utils.SuiteResourceManagementAroundAll
|
||||
import com.daml.lf.data.Ref.Party
|
||||
import com.daml.oauth.server.{Response => OAuthResponse}
|
||||
import org.scalatest.AsyncWordSpec
|
||||
|
||||
@ -33,8 +33,8 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
|
||||
participantId = None,
|
||||
exp = None,
|
||||
admin = claims.admin,
|
||||
actAs = claims.actAs,
|
||||
readAs = claims.readAs
|
||||
actAs = claims.actAs.map(ApiTypes.Party.unwrap(_)),
|
||||
readAs = claims.readAs.map(ApiTypes.Party.unwrap(_))
|
||||
)
|
||||
OAuthResponse.Token(
|
||||
accessToken = JwtSigner.HMAC256
|
||||
@ -63,7 +63,7 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
|
||||
}
|
||||
}
|
||||
"return the token from a cookie" in {
|
||||
val claims = Request.Claims(actAs = List(Party.assertFromString("Alice")))
|
||||
val claims = Request.Claims(actAs = List(ApiTypes.Party("Alice")))
|
||||
val token = makeToken(claims)
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
val req = HttpRequest(
|
||||
@ -83,13 +83,13 @@ class Test extends AsyncWordSpec with TestFixture with SuiteResourceManagementAr
|
||||
}
|
||||
}
|
||||
"return unauthorized on insufficient claims" in {
|
||||
val token = makeToken(Request.Claims(actAs = List(Party.assertFromString("Alice"))))
|
||||
val token = makeToken(Request.Claims(actAs = List(ApiTypes.Party("Alice"))))
|
||||
val cookieHeader = Cookie("daml-ledger-token", token.toCookieValue)
|
||||
val req = HttpRequest(
|
||||
uri = middlewareUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(Query(
|
||||
("claims", Request.Claims(actAs = List(Party.assertFromString("Bob"))).toQueryString))),
|
||||
.withQuery(
|
||||
Query(("claims", Request.Claims(actAs = List(ApiTypes.Party("Bob"))).toQueryString))),
|
||||
headers = List(cookieHeader)
|
||||
)
|
||||
for {
|
||||
|
@ -30,7 +30,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin
|
||||
OAuthConfig(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = ledgerId,
|
||||
applicationId = applicationId,
|
||||
applicationId = Some(applicationId),
|
||||
jwtSecret = jwtSecret))
|
||||
serverUri = Uri()
|
||||
.withScheme("http")
|
||||
|
@ -29,7 +29,7 @@ trait TestFixture extends AkkaBeforeAndAfterAll with SuiteResource[(ServerBindin
|
||||
Config(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = ledgerId,
|
||||
applicationId = applicationId,
|
||||
applicationId = Some(applicationId),
|
||||
jwtSecret = jwtSecret))
|
||||
client <- Resources.authClient(
|
||||
Client.Config(
|
||||
|
@ -4,19 +4,28 @@
|
||||
package com.daml.lf
|
||||
package engine.trigger
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.actor.typed.{ActorRef, Behavior}
|
||||
import akka.actor.typed.scaladsl.{ActorContext, Behaviors}
|
||||
import akka.actor.typed.scaladsl.adapter._
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.Http.ServerBinding
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.model.Uri.Path
|
||||
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import akka.http.scaladsl.server.{Directive, Directive1}
|
||||
import akka.http.scaladsl.server.Directives._
|
||||
import akka.http.scaladsl.server.Route
|
||||
import akka.http.scaladsl.unmarshalling.Unmarshal
|
||||
import akka.stream.Materializer
|
||||
import akka.util.ByteString
|
||||
import spray.json.DefaultJsonProtocol._
|
||||
import spray.json._
|
||||
import com.daml.oauth.middleware.{
|
||||
JsonProtocol => AuthJsonProtocol,
|
||||
Request => AuthRequest,
|
||||
Response => AuthResponse
|
||||
}
|
||||
import com.daml.ledger.api.refinements.ApiTypes.Party
|
||||
import com.daml.lf.archive.{Dar, DarReader, Decode}
|
||||
import com.daml.lf.archive.Reader.ParseError
|
||||
@ -43,7 +52,10 @@ import java.util.UUID
|
||||
import java.util.zip.ZipInputStream
|
||||
import java.time.LocalDateTime
|
||||
|
||||
import scala.collection.concurrent.TrieMap
|
||||
|
||||
class Server(
|
||||
authConfig: AuthConfig,
|
||||
ledgerConfig: LedgerConfig,
|
||||
restartConfig: TriggerRestartConfig,
|
||||
triggerDao: RunningTriggerDao)(
|
||||
@ -96,7 +108,8 @@ class Server(
|
||||
|
||||
private def restartTriggers(triggers: Vector[RunningTrigger]): Either[String, Unit] = {
|
||||
import cats.implicits._ // needed for traverse
|
||||
triggers.traverse_(t => startTrigger(t.triggerParty, t.triggerName, Some(t.triggerInstance)))
|
||||
triggers.traverse_(t =>
|
||||
startTrigger(t.triggerParty, t.triggerName, t.triggerToken, Some(t.triggerInstance)))
|
||||
}
|
||||
|
||||
private def triggerRunnerName(triggerInstance: UUID): String =
|
||||
@ -110,6 +123,7 @@ class Server(
|
||||
private def startTrigger(
|
||||
party: Party,
|
||||
triggerName: Identifier,
|
||||
token: Option[String],
|
||||
existingInstance: Option[UUID] = None): Either[String, JsValue] = {
|
||||
for {
|
||||
trigger <- Trigger.fromIdentifier(compiledPackages, triggerName)
|
||||
@ -117,7 +131,7 @@ class Server(
|
||||
case None =>
|
||||
val newInstance = UUID.randomUUID
|
||||
triggerDao
|
||||
.addRunningTrigger(RunningTrigger(newInstance, triggerName, party))
|
||||
.addRunningTrigger(RunningTrigger(newInstance, triggerName, party, token))
|
||||
.map(_ => newInstance)
|
||||
case Some(instance) => Right(instance)
|
||||
}
|
||||
@ -127,6 +141,7 @@ class Server(
|
||||
ctx.self,
|
||||
triggerInstance,
|
||||
party,
|
||||
token,
|
||||
compiledPackages,
|
||||
trigger,
|
||||
ledgerConfig,
|
||||
@ -171,20 +186,112 @@ class Server(
|
||||
private def getTriggerStatus(uuid: UUID): Vector[(LocalDateTime, String)] =
|
||||
triggerLog.getOrDefault(uuid, Vector.empty)
|
||||
|
||||
private val route = concat(
|
||||
// TODO[AH] Make sure this is bounded in size.
|
||||
private val authCallbacks: TrieMap[UUID, Route] = TrieMap()
|
||||
|
||||
// This directive requires authorization for the given claims via the auth middleware, if configured.
|
||||
// If no auth middleware is configured, then the request will proceed without attempting authorization.
|
||||
//
|
||||
// Authorization follows the steps defined in `triggers/service/authentication.md`.
|
||||
// First asking for a token on the `/auth` endpoint and redirecting to `/login` if none was returned.
|
||||
// If a login is required then this will store the current continuation in `authCallbacks`
|
||||
// to proceed once the login flow completed and authentication succeeded.
|
||||
private def authorize(claims: AuthRequest.Claims)(
|
||||
implicit ec: ExecutionContext,
|
||||
system: ActorSystem): Directive1[Option[String]] =
|
||||
authConfig match {
|
||||
case NoAuth => provide(None)
|
||||
case AuthMiddleware(authUri) =>
|
||||
// Attempt to obtain the access token from the middleware's /auth endpoint.
|
||||
// Forwards the current request's cookies.
|
||||
// Fails if the response is not OK or Unauthorized.
|
||||
def auth: Directive1[Option[String]] = {
|
||||
val uri = authUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(AuthRequest.Auth(claims).toQuery)
|
||||
import AuthJsonProtocol._
|
||||
extract(_.request.headers[headers.Cookie]).flatMap { cookies =>
|
||||
onSuccess(Http().singleRequest(HttpRequest(uri = uri, headers = cookies))).flatMap {
|
||||
case HttpResponse(StatusCodes.OK, _, entity, _) =>
|
||||
onSuccess(Unmarshal(entity).to[AuthResponse.Authorize]).map { auth =>
|
||||
Some(auth.accessToken): Option[String]
|
||||
}
|
||||
case HttpResponse(StatusCodes.Unauthorized, _, _, _) =>
|
||||
provide(None)
|
||||
case resp @ HttpResponse(code, _, _, _) =>
|
||||
onSuccess(Unmarshal(resp).to[String]).flatMap { msg =>
|
||||
logger.error(s"Failed to authorize with middleware ($code): $msg")
|
||||
complete(
|
||||
errorResponse(
|
||||
StatusCodes.InternalServerError,
|
||||
"Failed to authorize with middleware"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Directive { inner =>
|
||||
auth {
|
||||
// Authorization successful - pass token to continuation
|
||||
case Some(token) => inner(Tuple1(Some(token)))
|
||||
// Authorization failed - login and retry on callback request.
|
||||
case None => { ctx =>
|
||||
val requestId = UUID.randomUUID()
|
||||
authCallbacks.update(
|
||||
requestId, {
|
||||
auth {
|
||||
case None => {
|
||||
// Authorization failed after login - respond with 401
|
||||
// TODO[AH] Add WWW-Authenticate header
|
||||
complete(errorResponse(StatusCodes.Unauthorized))
|
||||
}
|
||||
case Some(token) =>
|
||||
// Authorization successful after login - use old request context and pass token to continuation.
|
||||
mapRequestContext(_ => ctx) {
|
||||
inner(Tuple1(Some(token)))
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
// TODO[AH] Make the redirect URI configurable, especially the authority. E.g. when running behind nginx.
|
||||
val callbackUri = Uri()
|
||||
.withScheme(ctx.request.uri.scheme)
|
||||
.withAuthority(ctx.request.uri.authority)
|
||||
.withPath(Path./("cb"))
|
||||
val uri = authUri
|
||||
.withPath(Path./("login"))
|
||||
.withQuery(AuthRequest.Login(callbackUri, claims, Some(requestId.toString)).toQuery)
|
||||
ctx.redirect(uri, StatusCodes.Found)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def authCallback(requestId: UUID): Route =
|
||||
authCallbacks.remove(requestId) match {
|
||||
case None => complete(StatusCodes.NotFound)
|
||||
case Some(callback) => callback
|
||||
}
|
||||
|
||||
private def route(implicit ec: ExecutionContext, system: ActorSystem) = concat(
|
||||
post {
|
||||
concat(
|
||||
// Start a new trigger given its identifier and the party it
|
||||
// should be running as. Returns a UUID for the newly
|
||||
// started trigger.
|
||||
path("v1" / "start") {
|
||||
entity(as[StartParams]) { params =>
|
||||
startTrigger(params.party, params.triggerName) match {
|
||||
case Left(err) =>
|
||||
complete(errorResponse(StatusCodes.UnprocessableEntity, err))
|
||||
case Right(triggerInstance) =>
|
||||
complete(successResponse(triggerInstance))
|
||||
}
|
||||
entity(as[StartParams]) {
|
||||
params =>
|
||||
val claims =
|
||||
AuthRequest.Claims(actAs = List(params.party))
|
||||
// TODO[AH] Why do we need to pass ec, system explicitly?
|
||||
authorize(claims)(ec, system) { token =>
|
||||
startTrigger(params.party, params.triggerName, token) match {
|
||||
case Left(err) =>
|
||||
complete(errorResponse(StatusCodes.UnprocessableEntity, err))
|
||||
case Right(triggerInstance) =>
|
||||
complete(successResponse(triggerInstance))
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
// upload a DAR as a multi-part form request with a single field called
|
||||
@ -257,6 +364,14 @@ class Server(
|
||||
}
|
||||
}
|
||||
},
|
||||
// Authorization callback endpoint
|
||||
path("cb") {
|
||||
get {
|
||||
parameters('state.as[UUID]) { requestId =>
|
||||
authCallback(requestId)
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@ -265,6 +380,7 @@ object Server {
|
||||
def apply(
|
||||
host: String,
|
||||
port: Int,
|
||||
authConfig: AuthConfig,
|
||||
ledgerConfig: LedgerConfig,
|
||||
restartConfig: TriggerRestartConfig,
|
||||
initialDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
|
||||
@ -299,11 +415,11 @@ object Server {
|
||||
val (dao, server): (RunningTriggerDao, Server) = jdbcConfig match {
|
||||
case None =>
|
||||
val dao = InMemoryTriggerDao()
|
||||
val server = new Server(ledgerConfig, restartConfig, dao)
|
||||
val server = new Server(authConfig, ledgerConfig, restartConfig, dao)
|
||||
(dao, server)
|
||||
case Some(c) =>
|
||||
val dao = DbTriggerDao(c)
|
||||
val server = new Server(ledgerConfig, restartConfig, dao)
|
||||
val server = new Server(authConfig, ledgerConfig, restartConfig, dao)
|
||||
val recovery: Either[String, Unit] = for {
|
||||
_ <- dao.initialize
|
||||
packages <- dao.readPackages
|
||||
|
@ -6,6 +6,7 @@ package com.daml.lf.engine.trigger
|
||||
import java.nio.file.{Path, Paths}
|
||||
import java.time.Duration
|
||||
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import com.daml.cliopts
|
||||
import com.daml.platform.services.time.TimeProviderType
|
||||
import scalaz.Show
|
||||
@ -21,6 +22,7 @@ private[trigger] final case class ServiceConfig(
|
||||
httpPort: Int,
|
||||
ledgerHost: String,
|
||||
ledgerPort: Int,
|
||||
authUri: Option[Uri],
|
||||
maxInboundMessageSize: Int,
|
||||
minRestartInterval: FiniteDuration,
|
||||
maxRestartInterval: FiniteDuration,
|
||||
@ -105,6 +107,13 @@ private[trigger] object ServiceConfig {
|
||||
.action((t, c) => c.copy(ledgerPort = t))
|
||||
.text("Ledger port.")
|
||||
|
||||
opt[String]("auth")
|
||||
.optional()
|
||||
.action((t, c) => c.copy(authUri = Some(Uri(t))))
|
||||
.text("Auth middleware URI.")
|
||||
// TODO[AH] Expose once the feature is fully implemented.
|
||||
.hidden()
|
||||
|
||||
opt[Int]("max-inbound-message-size")
|
||||
.action((x, c) => c.copy(maxInboundMessageSize = x))
|
||||
.optional()
|
||||
@ -157,6 +166,7 @@ private[trigger] object ServiceConfig {
|
||||
httpPort = DefaultHttpPort,
|
||||
ledgerHost = null,
|
||||
ledgerPort = 0,
|
||||
authUri = None,
|
||||
maxInboundMessageSize = DefaultMaxInboundMessageSize,
|
||||
minRestartInterval = DefaultMinRestartInterval,
|
||||
maxRestartInterval = DefaultMaxRestartInterval,
|
||||
|
@ -27,6 +27,7 @@ object ServiceMain {
|
||||
def startServer(
|
||||
host: String,
|
||||
port: Int,
|
||||
authConfig: AuthConfig,
|
||||
ledgerConfig: LedgerConfig,
|
||||
restartConfig: TriggerRestartConfig,
|
||||
encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
|
||||
@ -38,6 +39,7 @@ object ServiceMain {
|
||||
Server(
|
||||
host,
|
||||
port,
|
||||
authConfig,
|
||||
ledgerConfig,
|
||||
restartConfig,
|
||||
encodedDar,
|
||||
@ -66,6 +68,10 @@ object ServiceMain {
|
||||
case Success(dar) => dar
|
||||
}
|
||||
}
|
||||
val authConfig: AuthConfig = config.authUri match {
|
||||
case None => NoAuth
|
||||
case Some(uri) => AuthMiddleware(uri)
|
||||
}
|
||||
val ledgerConfig =
|
||||
LedgerConfig(
|
||||
config.ledgerHost,
|
||||
@ -83,6 +89,7 @@ object ServiceMain {
|
||||
Server(
|
||||
config.address,
|
||||
config.httpPort,
|
||||
authConfig,
|
||||
ledgerConfig,
|
||||
restartConfig,
|
||||
encodedDar,
|
||||
|
@ -31,7 +31,7 @@ object TriggerRunnerImpl {
|
||||
server: ActorRef[Message],
|
||||
triggerInstance: UUID,
|
||||
party: Party,
|
||||
// TODO(SF, 2020-06-09): Add access token field here in the presence of authentication.
|
||||
token: Option[String],
|
||||
compiledPackages: CompiledPackages,
|
||||
trigger: Trigger,
|
||||
ledgerConfig: LedgerConfig,
|
||||
@ -67,9 +67,7 @@ object TriggerRunnerImpl {
|
||||
commandClient = CommandClientConfiguration.default.copy(
|
||||
defaultDeduplicationTime = config.ledgerConfig.commandTtl),
|
||||
sslContext = None,
|
||||
// TODO(SF, 2020-06-09): In the presence of an authorization
|
||||
// service, get an access token and pass it through here!
|
||||
token = None,
|
||||
token = config.token,
|
||||
maxInboundMessageSize = config.ledgerConfig.maxInboundMessageSize,
|
||||
)
|
||||
|
||||
|
@ -149,7 +149,8 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
|
||||
triggerInstance: UUID,
|
||||
party: String,
|
||||
fullTriggerName: String): Either[String, RunningTrigger] = {
|
||||
Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party)))
|
||||
// TODO[AH] Persist the access and refresh token.
|
||||
Identifier.fromString(fullTriggerName).map(RunningTrigger(triggerInstance, _, Tag(party), None))
|
||||
}
|
||||
|
||||
// Drop all tables and other objects associated with the database.
|
||||
|
@ -10,10 +10,15 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
|
||||
import com.daml.lf.data.Ref.Identifier
|
||||
import com.daml.platform.services.time.TimeProviderType
|
||||
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import scala.concurrent.duration.FiniteDuration
|
||||
|
||||
package trigger {
|
||||
|
||||
sealed trait AuthConfig
|
||||
case object NoAuth extends AuthConfig
|
||||
final case class AuthMiddleware(uri: Uri) extends AuthConfig
|
||||
|
||||
case class LedgerConfig(
|
||||
host: String,
|
||||
port: Int,
|
||||
@ -32,7 +37,6 @@ package trigger {
|
||||
triggerInstance: UUID,
|
||||
triggerName: Identifier,
|
||||
triggerParty: Party,
|
||||
// TODO(SF, 2020-0610): Add access token field here in the
|
||||
// presence of authentication.
|
||||
triggerToken: Option[String],
|
||||
)
|
||||
}
|
||||
|
@ -7,14 +7,21 @@ import java.io.File
|
||||
import java.net.InetAddress
|
||||
import java.time.Duration
|
||||
|
||||
import akka.actor.ActorSystem
|
||||
import akka.actor.typed.{ActorSystem => TypedActorSystem}
|
||||
import akka.http.scaladsl.Http.ServerBinding
|
||||
import akka.http.scaladsl.model.Uri
|
||||
import akka.http.scaladsl.model.Uri.Path
|
||||
import akka.stream.Materializer
|
||||
import com.daml.bazeltools.BazelRunfiles
|
||||
import com.daml.daml_lf_dev.DamlLf
|
||||
import com.daml.grpc.adapter.ExecutionSequencerFactory
|
||||
import com.daml.jwt.domain.DecodedJwt
|
||||
import com.daml.jwt.{HMAC256Verifier, JwtSigner}
|
||||
import com.daml.ledger.api.auth
|
||||
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
|
||||
import com.daml.ledger.api.domain.LedgerId
|
||||
import com.daml.ledger.api.refinements.ApiTypes
|
||||
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
|
||||
import com.daml.ledger.client.LedgerClient
|
||||
import com.daml.ledger.client.configuration.{
|
||||
@ -24,6 +31,8 @@ import com.daml.ledger.client.configuration.{
|
||||
}
|
||||
import com.daml.lf.archive.Dar
|
||||
import com.daml.lf.data.Ref._
|
||||
import com.daml.oauth.middleware.{Config => MiddlewareConfig, Server => MiddlewareServer}
|
||||
import com.daml.oauth.server.{Config => OAuthConfig, Server => OAuthServer}
|
||||
import com.daml.platform.common.LedgerIdMode
|
||||
import com.daml.platform.sandbox
|
||||
import com.daml.platform.sandbox.SandboxServer
|
||||
@ -40,6 +49,13 @@ import scala.concurrent._
|
||||
import scala.concurrent.duration._
|
||||
import scala.sys.process.Process
|
||||
|
||||
private[trigger] final case class AuthTestConfig(
|
||||
// HMAC256 signature secret.
|
||||
jwtSecret: String,
|
||||
// Grant readAs claims for these parties to the ledger client provided to test cases.
|
||||
parties: List[ApiTypes.Party],
|
||||
)
|
||||
|
||||
object TriggerServiceFixture extends StrictLogging {
|
||||
|
||||
// Use a small initial interval so we can test restart behaviour more easily.
|
||||
@ -50,10 +66,12 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
dars: List[File],
|
||||
encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]],
|
||||
jdbcConfig: Option[JdbcConfig],
|
||||
authTestConfig: Option[AuthTestConfig],
|
||||
)(testFn: (Uri, LedgerClient, Proxy) => Future[A])(
|
||||
implicit mat: Materializer,
|
||||
aesf: ExecutionSequencerFactory,
|
||||
ec: ExecutionContext,
|
||||
system: ActorSystem,
|
||||
pos: source.Position,
|
||||
): Future[A] = {
|
||||
logger.info(s"${pos.fileName}:${pos.lineNumber}: setting up trigger service")
|
||||
@ -77,9 +95,43 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
|
||||
val ledgerId = LedgerId(testName)
|
||||
val applicationId = ApplicationId(testName)
|
||||
val authF: Future[(AuthConfig, () => Future[Unit])] = authTestConfig match {
|
||||
case None => Future((NoAuth, () => Future(())))
|
||||
case Some(AuthTestConfig(secret, _)) =>
|
||||
for {
|
||||
oauth <- OAuthServer.start(
|
||||
OAuthConfig(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = LedgerId.unwrap(ledgerId),
|
||||
// TODO[AH] Choose application ID, see https://github.com/digital-asset/daml/issues/7671
|
||||
applicationId = None,
|
||||
jwtSecret = secret,
|
||||
))
|
||||
serverUri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(oauth.localAddress.getHostString, oauth.localAddress.getPort)
|
||||
middleware <- MiddlewareServer.start(
|
||||
MiddlewareConfig(
|
||||
port = Port.Dynamic,
|
||||
oauthAuth = serverUri.withPath(Path./("authorize")),
|
||||
oauthToken = serverUri.withPath(Path./("token")),
|
||||
clientId = "oauth-middleware-id",
|
||||
clientSecret = "oauth-middleware-secret",
|
||||
))
|
||||
middlewareUri = Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(middleware.localAddress.getHostString, middleware.localAddress.getPort)
|
||||
cleanup = () =>
|
||||
for {
|
||||
_ <- middleware.unbind()
|
||||
_ <- oauth.unbind()
|
||||
} yield ()
|
||||
} yield (AuthMiddleware(middlewareUri), cleanup)
|
||||
}
|
||||
val ledgerF = for {
|
||||
(_, toxiproxyClient) <- toxiproxyF
|
||||
ledger <- Future(new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId), mat))
|
||||
ledger <- Future(
|
||||
new SandboxServer(ledgerConfig(Port.Dynamic, dars, ledgerId, authTestConfig), mat))
|
||||
ledgerPort <- ledger.portF
|
||||
ledgerProxyPort = LockedFreePort.find()
|
||||
ledgerProxy = toxiproxyClient.createProxy(
|
||||
@ -97,12 +149,13 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
client <- LedgerClient.singleHost(
|
||||
host.getHostName,
|
||||
ledgerPort.value,
|
||||
clientConfig(applicationId),
|
||||
clientConfig(applicationId, authTestConfig),
|
||||
)
|
||||
} yield client
|
||||
|
||||
// Configure the service with the ledger's *proxy* port.
|
||||
val serviceF: Future[(ServerBinding, TypedActorSystem[Message])] = for {
|
||||
(authConfig, _) <- authF
|
||||
(_, _, ledgerProxyPort, _) <- ledgerF
|
||||
ledgerConfig = LedgerConfig(
|
||||
host.getHostName,
|
||||
@ -119,6 +172,7 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
service <- ServiceMain.startServer(
|
||||
host.getHostName,
|
||||
servicePort.port.value,
|
||||
authConfig,
|
||||
ledgerConfig,
|
||||
restartConfig,
|
||||
encodedDar,
|
||||
@ -152,7 +206,10 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
proc.destroy()
|
||||
proc.exitValue() // destroy is async
|
||||
})
|
||||
result <- (ta.failed.toOption orElse se orElse le orElse te)
|
||||
ae <- optErr(authF.flatMap {
|
||||
case (_, cleanup) => cleanup()
|
||||
})
|
||||
result <- (ta.failed.toOption orElse se orElse le orElse te orElse ae)
|
||||
.cata(Future.failed, Future fromTry ta)
|
||||
} yield result
|
||||
}
|
||||
@ -164,24 +221,43 @@ object TriggerServiceFixture extends StrictLogging {
|
||||
private def ledgerConfig(
|
||||
ledgerPort: Port,
|
||||
dars: List[File],
|
||||
ledgerId: LedgerId
|
||||
ledgerId: LedgerId,
|
||||
authTestConfig: Option[AuthTestConfig]
|
||||
): SandboxConfig =
|
||||
sandbox.DefaultConfig.copy(
|
||||
port = ledgerPort,
|
||||
damlPackages = dars,
|
||||
timeProviderType = Some(TimeProviderType.Static),
|
||||
ledgerIdMode = LedgerIdMode.Static(ledgerId),
|
||||
authService = None,
|
||||
authService = for {
|
||||
cfg <- authTestConfig
|
||||
verifier <- HMAC256Verifier(cfg.jwtSecret).toOption
|
||||
} yield auth.AuthServiceJWT(verifier)
|
||||
)
|
||||
|
||||
private def clientConfig[A](
|
||||
applicationId: ApplicationId,
|
||||
token: Option[String] = None): LedgerClientConfiguration =
|
||||
authTestConfig: Option[AuthTestConfig]): LedgerClientConfiguration =
|
||||
LedgerClientConfiguration(
|
||||
applicationId = ApplicationId.unwrap(applicationId),
|
||||
ledgerIdRequirement = LedgerIdRequirement.none,
|
||||
commandClient = CommandClientConfiguration.default,
|
||||
sslContext = None,
|
||||
token = token,
|
||||
token = for {
|
||||
cfg <- authTestConfig
|
||||
header = """{"alg": "HS256", "typ": "JWT"}"""
|
||||
payload = AuthServiceJWTPayload(
|
||||
ledgerId = None,
|
||||
applicationId = None,
|
||||
participantId = None,
|
||||
exp = None,
|
||||
admin = true,
|
||||
actAs = cfg.parties.map(ApiTypes.Party.unwrap),
|
||||
readAs = List(),
|
||||
)
|
||||
jwt <- JwtSigner.HMAC256
|
||||
.sign(DecodedJwt(header, AuthServiceJWTCodec.compactPrint(payload)), cfg.jwtSecret)
|
||||
.toOption
|
||||
} yield jwt.value,
|
||||
)
|
||||
}
|
||||
|
@ -35,15 +35,82 @@ import com.daml.ledger.api.v1.transaction_filter.{Filters, InclusiveFilters, Tra
|
||||
import com.daml.ledger.client.LedgerClient
|
||||
import com.daml.lf.engine.trigger.dao.DbTriggerDao
|
||||
import com.daml.testing.postgresql.PostgresAroundAll
|
||||
import com.typesafe.scalalogging.StrictLogging
|
||||
import eu.rekawek.toxiproxy._
|
||||
|
||||
abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually with Matchers {
|
||||
import scala.collection.concurrent.TrieMap
|
||||
import scala.util.Success
|
||||
|
||||
/**
|
||||
* A test-fixture that persists cookies between http requests for each test-case.
|
||||
*/
|
||||
trait HttpCookies extends BeforeAndAfterEach { this: Suite =>
|
||||
private val cookieJar = TrieMap[String, String]()
|
||||
|
||||
override protected def afterEach(): Unit = {
|
||||
try super.afterEach()
|
||||
finally cookieJar.clear()
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds a Cookie header for the currently stored cookies and performs the given http request.
|
||||
*/
|
||||
def httpRequest(request: HttpRequest)(
|
||||
implicit system: ActorSystem,
|
||||
ec: ExecutionContext): Future[HttpResponse] = {
|
||||
Http()
|
||||
.singleRequest {
|
||||
if (cookieJar.nonEmpty) {
|
||||
val cookies = headers.Cookie(values = cookieJar.to[Seq]: _*)
|
||||
request.addHeader(cookies)
|
||||
} else {
|
||||
request
|
||||
}
|
||||
}
|
||||
.andThen {
|
||||
case Success(resp) =>
|
||||
resp.headers.foreach {
|
||||
case headers.`Set-Cookie`(cookie) =>
|
||||
cookieJar.update(cookie.name, cookie.value)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Same as [[httpRequest]] but will follow redirections.
|
||||
*/
|
||||
def httpRequestFollow(request: HttpRequest, maxRedirections: Int = 10)(
|
||||
implicit system: ActorSystem,
|
||||
ec: ExecutionContext): Future[HttpResponse] = {
|
||||
httpRequest(request).flatMap {
|
||||
case resp @ HttpResponse(StatusCodes.Redirection(_), _, _, _) =>
|
||||
if (maxRedirections == 0) {
|
||||
throw new RuntimeException("Too many redirections")
|
||||
} else {
|
||||
val uri = resp.header[headers.Location].get.uri
|
||||
httpRequestFollow(HttpRequest(uri = uri), maxRedirections - 1)
|
||||
}
|
||||
case resp => Future(resp)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
abstract class AbstractTriggerServiceTest
|
||||
extends AsyncFlatSpec
|
||||
with HttpCookies
|
||||
with Eventually
|
||||
with Matchers
|
||||
with StrictLogging {
|
||||
|
||||
import AbstractTriggerServiceTest.CompatAssertion
|
||||
|
||||
// Abstract member for testing with and without a database
|
||||
def jdbcConfig: Option[JdbcConfig]
|
||||
|
||||
// Abstract member for testing with and without authentication/authorization
|
||||
def authTestConfig: Option[AuthTestConfig]
|
||||
|
||||
// Default retry config for `eventually`
|
||||
override implicit def patienceConfig: PatienceConfig =
|
||||
PatienceConfig(timeout = scaled(Span(15, Seconds)), interval = scaled(Span(1, Seconds)))
|
||||
@ -77,7 +144,12 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually
|
||||
|
||||
def withTriggerService[A](encodedDar: Option[Dar[(PackageId, DamlLf.ArchivePayload)]])(
|
||||
testFn: (Uri, LedgerClient, Proxy) => Future[A])(implicit pos: source.Position): Future[A] =
|
||||
TriggerServiceFixture.withTriggerService(testId, List(darPath), encodedDar, jdbcConfig)(testFn)
|
||||
TriggerServiceFixture.withTriggerService(
|
||||
testId,
|
||||
List(darPath),
|
||||
encodedDar,
|
||||
jdbcConfig,
|
||||
authTestConfig)(testFn)
|
||||
|
||||
def startTrigger(uri: Uri, triggerName: String, party: Party): Future[HttpResponse] = {
|
||||
val req = HttpRequest(
|
||||
@ -88,7 +160,7 @@ abstract class AbstractTriggerServiceTest extends AsyncFlatSpec with Eventually
|
||||
s"""{"triggerName": "$triggerName", "party": "$party"}"""
|
||||
)
|
||||
)
|
||||
Http().singleRequest(req)
|
||||
httpRequestFollow(req)
|
||||
}
|
||||
|
||||
def listTriggers(uri: Uri, party: Party): Future[HttpResponse] = {
|
||||
@ -446,6 +518,7 @@ object AbstractTriggerServiceTest {
|
||||
class TriggerServiceTestInMem extends AbstractTriggerServiceTest {
|
||||
|
||||
override def jdbcConfig: Option[JdbcConfig] = None
|
||||
override def authTestConfig: Option[AuthTestConfig] = None
|
||||
|
||||
}
|
||||
|
||||
@ -456,6 +529,7 @@ class TriggerServiceTestWithDb
|
||||
with PostgresAroundAll {
|
||||
|
||||
override def jdbcConfig: Option[JdbcConfig] = Some(jdbcConfig_)
|
||||
override def authTestConfig: Option[AuthTestConfig] = None
|
||||
|
||||
// Lazy because the postgresDatabase is only available once the tests start
|
||||
private lazy val jdbcConfig_ = JdbcConfig(postgresDatabase.url, "operator", "password")
|
||||
@ -530,3 +604,16 @@ class TriggerServiceTestWithDb
|
||||
} yield succeed)
|
||||
|
||||
}
|
||||
|
||||
// Tests for auth mode only go here
|
||||
class TriggerServiceTestAuth extends AbstractTriggerServiceTest {
|
||||
|
||||
override def jdbcConfig: Option[JdbcConfig] = None
|
||||
override def authTestConfig: Option[AuthTestConfig] =
|
||||
Some(
|
||||
AuthTestConfig(
|
||||
jwtSecret = "secret",
|
||||
parties = List(alice, bob),
|
||||
))
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user