Handle token expiry in trigger service (#8037)

* Enable adjustable clock in trigger service tests

changelog_begin
changelog_end

* Test user side token expiry

* Test service side token refresh

* Use AccessToken wrapper in TriggerRunnerImpl

* Store refresh token in trigger DB

* add refresh token to trigger runner config

* TriggerTokenExpired message to server

* TriggerTokenRefresh message to server

* refresh trigger token and update db

* Restart trigger with fresh token

* Test second token expiry

* Refresh token on running trigger

changelog_begin
* [Triggers] UNAUTHENTICATED errors will now terminate the trigger.
  These errors are no longer available for handling in the trigger DAML
  code. Instead, they are forwarded to the trigger service for handling,
  e.g. access token refresh.
changelog_end

* todo note

* Move triggerRunnerName and getRunner into object

* Factor out token refresh

* Factor out getActiveContracts

* factor out create command

* Add logging to token refresh

* Handle token expiry in TriggerRunner

TriggerRunnerImpl throws a dedicated exception when it fails on an
expired access token (any unauthenticated error to be precise).
The TriggerRunner supervisor reacts to this child failure by
requesting a token refresh and restart on the trigger server and
stopping itself.
The trigger server requests a new access and refresh token on the auth
middleware and restarts the trigger.

This works around an issue with actor supervisors in akka-actor-typed.
A stop supervisor wrapped within a restart supervisor will not cause a
stop as expected. Instead, the restart supervisor will trigger as well
and restart the actor. The work around uses a custom behavior
interceptor to emulate the appropriate stop supervisors as closely as
possible. We cannot properly emulate ChildFailed signals this way, so
we use dedicated messages intead.

* throw --> Future.failedo

* getOrFail helper

Co-authored-by: Andreas Herrmann <andreas.herrmann@tweag.io>
This commit is contained in:
Andreas Herrmann 2020-12-02 17:17:45 +01:00 committed by GitHub
parent 6b7f714c4d
commit 8bceeb13de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 440 additions and 89 deletions

View File

@ -535,7 +535,9 @@ class Runner(
val f: Future[Empty] = client.commandClient
.submitSingleCommand(req)
f.map(_ => None).recover {
case s: StatusRuntimeException =>
case s: StatusRuntimeException if s.getStatus != io.grpc.Status.UNAUTHENTICATED =>
// Do not capture UNAUTHENTICATED errors.
// The access token may be expired, let the trigger runner handle token refresh.
Some(SingleCommandFailure(req.getCommands.commandId, s))
// any other error will cause the trigger's stream to fail
}

View File

@ -117,6 +117,7 @@ da_scala_library(
"//ledger/participant-state",
"//ledger/sandbox-classic",
"//ledger/sandbox-common",
"//libs-scala/adjustable-clock",
"//libs-scala/ports",
"//libs-scala/postgresql-testing",
"//libs-scala/resources",
@ -125,6 +126,7 @@ da_scala_library(
"//triggers/service/auth:oauth-middleware",
"//triggers/service/auth:oauth-test-server",
"@maven//:ch_qos_logback_logback_classic",
"@maven//:com_auth0_java_jwt",
"@maven//:com_typesafe_akka_akka_actor_typed_2_12",
"@maven//:com_typesafe_akka_akka_http_core_2_12",
"@maven//:com_typesafe_akka_akka_parsing_2_12",
@ -181,6 +183,7 @@ da_scala_test_suite(
"//ledger/ledger-resources",
"//ledger/sandbox-classic",
"//ledger/sandbox-common",
"//libs-scala/adjustable-clock",
"//libs-scala/flyway-testing",
"//libs-scala/ports",
"//libs-scala/postgresql-testing",

View File

@ -32,7 +32,7 @@ import scala.util.{Failure, Success, Try}
// request to /authorize are immediately redirected to the redirect_uri.
class Server(config: Config) {
private val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
private val tokenLifetimeSeconds = 24 * 60 * 60
val tokenLifetimeSeconds = 24 * 60 * 60
// None indicates that all parties are authorized, Some that only the given set of parties is authorized.
private var authorizedParties: Option[Set[Party]] = config.parties

View File

@ -0,0 +1,5 @@
-- Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
-- SPDX-License-Identifier: Apache-2.0
-- Add refresh token to running trigger table
alter table running_triggers add column refresh_token text;

View File

@ -0,0 +1 @@
b6b0f69413a5977f86066c7ed2b0e11c3151f17e147f2554f320d7fa6fa1c338

View File

@ -16,6 +16,7 @@ import akka.actor.typed.{ActorRef, Behavior, Scheduler}
import akka.http.scaladsl.Http
import akka.http.scaladsl.Http.ServerBinding
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
import akka.http.scaladsl.marshalling.Marshal
import akka.http.scaladsl.model.Uri.Path
import akka.http.scaladsl.model._
import akka.http.scaladsl.server.Directives._
@ -36,7 +37,7 @@ import com.daml.lf.data.Ref.{Identifier, PackageId}
import com.daml.lf.engine._
import com.daml.lf.engine.trigger.Request.StartParams
import com.daml.lf.engine.trigger.Response._
import com.daml.lf.engine.trigger.Tagged.AccessToken
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
import com.daml.lf.engine.trigger.TriggerRunner._
import com.daml.lf.engine.trigger.dao._
import com.daml.oauth.middleware.Request.Claims
@ -146,10 +147,16 @@ class Server(
// Note that this does not yet start the trigger.
private def addNewTrigger(
config: TriggerConfig,
token: Option[AccessToken]
auth: Option[Authorization],
)(implicit ec: ExecutionContext): Future[Either[String, (Trigger, RunningTrigger)]] = {
val runningTrigger =
RunningTrigger(config.instance, config.name, config.party, config.applicationId, token)
RunningTrigger(
config.instance,
config.name,
config.party,
config.applicationId,
auth.map(_.accessToken),
auth.flatMap(_.refreshToken))
// Validate trigger id before persisting to DB
Trigger.fromIdentifier(compiledPackages, runningTrigger.triggerName) match {
case Left(value) => Future.successful(Left(value))
@ -216,6 +223,8 @@ class Server(
// TODO[AH] Make sure this is bounded in size.
private val authCallbacks: TrieMap[UUID, Route] = TrieMap()
case class Authorization(accessToken: AccessToken, refreshToken: Option[RefreshToken])
// This directive requires authorization for the given claims via the auth middleware, if configured.
// If no auth middleware is configured, then the request will proceed without attempting authorization.
//
@ -225,14 +234,14 @@ class Server(
// to proceed once the login flow completed and authentication succeeded.
private def authorize(claims: AuthRequest.Claims)(
implicit ec: ExecutionContext,
system: ActorSystem): Directive1[Option[AccessToken]] =
system: ActorSystem): Directive1[Option[Authorization]] =
authConfig match {
case NoAuth => provide(None)
case AuthMiddleware(authUri) =>
// Attempt to obtain the access token from the middleware's /auth endpoint.
// Forwards the current request's cookies.
// Fails if the response is not OK or Unauthorized.
def auth: Directive1[Option[AccessToken]] = {
def auth: Directive1[Option[Authorization]] = {
val uri = authUri
.withPath(Path./("auth"))
.withQuery(AuthRequest.Auth(claims).toQuery)
@ -241,7 +250,10 @@ class Server(
onSuccess(Http().singleRequest(HttpRequest(uri = uri, headers = cookies))).flatMap {
case HttpResponse(StatusCodes.OK, _, entity, _) =>
onSuccess(Unmarshal(entity).to[AuthResponse.Authorize]).map { auth =>
Some(AccessToken(auth.accessToken)): Option[AccessToken]
Some(
Authorization(
AccessToken(auth.accessToken),
RefreshToken.subst(auth.refreshToken))): Option[Authorization]
}
case HttpResponse(StatusCodes.Unauthorized, _, _, _) =>
provide(None)
@ -259,7 +271,7 @@ class Server(
Directive { inner =>
auth {
// Authorization successful - pass token to continuation
case Some(token) => inner(Tuple1(Some(token)))
case Some(authorization) => inner(Tuple1(Some(authorization)))
// Authorization failed - login and retry on callback request.
case None => { ctx =>
val requestId = UUID.randomUUID()
@ -271,10 +283,10 @@ class Server(
// TODO[AH] Add WWW-Authenticate header
complete(errorResponse(StatusCodes.Unauthorized))
}
case Some(token) =>
case Some(authorization) =>
// Authorization successful after login - use old request context and pass token to continuation.
mapRequestContext(_ => ctx) {
inner(Tuple1(Some(token)))
inner(Tuple1(Some(authorization)))
}
}
}
@ -298,7 +310,7 @@ class Server(
// If the trigger does not exist, then the request will also proceed without attempting authorization.
private def authorizeForTrigger(uuid: UUID, readOnly: Boolean = false)(
implicit ec: ExecutionContext,
system: ActorSystem): Directive1[Option[AccessToken]] =
system: ActorSystem): Directive1[Option[Authorization]] =
authConfig match {
case NoAuth => provide(None)
case AuthMiddleware(_) =>
@ -349,9 +361,9 @@ class Server(
applicationId = Some(config.applicationId))
// TODO[AH] Why do we need to pass ec, system explicitly?
authorize(claims)(ec, system) {
token =>
auth =>
val instOrErr: Future[Either[String, JsValue]] =
addNewTrigger(config, token)
addNewTrigger(config, auth)
.flatMap {
case Left(value) => Future.successful(Left(value))
case Right((trigger, runningTrigger)) =>
@ -474,9 +486,26 @@ object Server {
replyTo: ActorRef[StatusReply[Unit]])
extends Message
final case class RestartTrigger(
trigger: Trigger,
runningTrigger: RunningTrigger,
compiledPackages: CompiledPackages)
extends Message
final case class GetRunner(replyTo: ActorRef[Option[ActorRef[TriggerRunner.Message]]], uuid: UUID)
extends Message
final case class TriggerTokenRefreshFailed(triggerInstance: UUID, cause: Throwable)
extends Message
// Messages passed to the server from a TriggerRunner
final case class TriggerTokenExpired(
triggerInstance: UUID,
trigger: Trigger,
compiledPackages: CompiledPackages)
extends Message
// Messages passed to the server from a TriggerRunnerImpl
final case class TriggerStarting(triggerInstance: UUID) extends Message
@ -540,37 +569,94 @@ object Server {
def logTriggerStarted(m: TriggerStarted): Unit =
server.logTriggerStatus(m.triggerInstance, "running")
def startTrigger(req: StartTrigger): Unit = {
val runningTrigger = req.runningTrigger
Try(
ctx.spawn(
TriggerRunner(
new TriggerRunner.Config(
ctx.self,
runningTrigger.triggerInstance,
runningTrigger.triggerParty,
runningTrigger.triggerApplicationId,
AccessToken.unsubst(runningTrigger.triggerToken),
req.compiledPackages,
req.trigger,
ledgerConfig,
restartConfig
),
runningTrigger.triggerInstance.toString
def spawnTrigger(
trigger: Trigger,
runningTrigger: RunningTrigger,
compiledPackages: CompiledPackages): ActorRef[TriggerRunner.Message] =
ctx.spawn(
TriggerRunner(
new TriggerRunner.Config(
ctx.self,
runningTrigger.triggerInstance,
runningTrigger.triggerParty,
runningTrigger.triggerApplicationId,
runningTrigger.triggerAccessToken,
runningTrigger.triggerRefreshToken,
compiledPackages,
trigger,
ledgerConfig,
restartConfig
),
triggerRunnerName(runningTrigger.triggerInstance)
)) match {
runningTrigger.triggerInstance.toString
),
triggerRunnerName(runningTrigger.triggerInstance)
)
def startTrigger(req: StartTrigger): Unit = {
Try(spawnTrigger(req.trigger, req.runningTrigger, req.compiledPackages)) match {
case Failure(exception) => req.replyTo ! StatusReply.error(exception)
case Success(_) => req.replyTo ! StatusReply.success(())
}
}
def restartTrigger(req: RestartTrigger): Unit = {
val _ = spawnTrigger(req.trigger, req.runningTrigger, req.compiledPackages)
}
def getRunner(req: GetRunner) = {
req.replyTo ! ctx
.child(triggerRunnerName(req.uuid))
.asInstanceOf[Option[ActorRef[TriggerRunner.Message]]]
}
def refreshAccessToken(triggerInstance: UUID): Future[RunningTrigger] = {
def getOrFail[T](result: Option[T], ex: => Throwable): Future[T] = result match {
case Some(value) => Future.successful(value)
case None => Future.failed(ex)
}
for {
// Lookup running trigger
runningTrigger <- dao
.getRunningTrigger(triggerInstance)
.flatMap(getOrFail(_, new RuntimeException(s"Unknown trigger $triggerInstance")))
// Request a token refresh
authUri <- getOrFail(authConfig match {
case AuthMiddleware(uri) => Some(uri)
case _ => None
}, new RuntimeException("Cannot refresh token without authorization service"))
refreshToken <- getOrFail(
runningTrigger.triggerRefreshToken,
new RuntimeException(s"No refresh token for $triggerInstance"))
requestEntity <- {
import AuthJsonProtocol._
Marshal(AuthRequest.Refresh(RefreshToken.unwrap(refreshToken)))
.to[RequestEntity]
}
response <- Http().singleRequest(
HttpRequest(
method = HttpMethods.POST,
uri = authUri.withPath(Path./("refresh")),
entity = requestEntity,
))
authorize <- response.status match {
case StatusCodes.OK =>
import AuthJsonProtocol._
Unmarshal(response.entity).to[AuthResponse.Authorize]
case status =>
Unmarshal(response).to[String].flatMap { msg =>
Future.failed(new RuntimeException(s"Failed to refresh token ($status): $msg"))
}
}
// Update the tokens in the trigger db
accessToken = AccessToken(authorize.accessToken)
refreshToken = RefreshToken.subst(authorize.refreshToken)
_ <- dao.updateRunningTriggerToken(triggerInstance, accessToken, refreshToken)
} yield
runningTrigger
.copy(triggerAccessToken = Some(accessToken), triggerRefreshToken = refreshToken)
}
// The server running state.
def running(binding: ServerBinding): Behavior[Message] =
Behaviors
@ -578,6 +664,9 @@ object Server {
case req: StartTrigger =>
startTrigger(req)
Behaviors.same
case req: RestartTrigger =>
restartTrigger(req)
Behaviors.same
case req: GetRunner =>
getRunner(req)
Behaviors.same
@ -607,6 +696,18 @@ object Server {
// the management of a supervision strategy).
Behaviors.same
case TriggerTokenExpired(triggerInstance, trigger, compiledPackages) =>
ctx.log.info(s"Updating token for $triggerInstance")
ctx.pipeToSelf(refreshAccessToken(triggerInstance)) {
case Success(runningTrigger) =>
RestartTrigger(trigger, runningTrigger, compiledPackages)
case Failure(cause) => TriggerTokenRefreshFailed(triggerInstance, cause)
}
Behaviors.same
case TriggerTokenRefreshFailed(triggerInstance, cause) =>
server.logTriggerStatus(triggerInstance, s"stopped: failed to refresh token: $cause")
Behaviors.same
case GetServerBinding(replyTo) =>
replyTo ! binding
Behaviors.same
@ -664,12 +765,18 @@ object Server {
case req: StartTrigger =>
startTrigger(req)
Behaviors.same
case req: RestartTrigger =>
restartTrigger(req)
Behaviors.same
case req: GetRunner =>
getRunner(req)
Behaviors.same
case _: TriggerInitializationFailure | _: TriggerRuntimeFailure =>
Behaviors.unhandled
case _: TriggerTokenExpired | _: TriggerTokenRefreshFailed =>
Behaviors.unhandled
}
// The server binding is a future that on completion will be piped

View File

@ -3,17 +3,27 @@
package com.daml.lf.engine.trigger
import akka.actor.typed.SupervisorStrategy._
import akka.actor.typed.SupervisorStrategy.restartWithBackoff
import akka.actor.typed.scaladsl.Behaviors
import akka.actor.typed.{ActorRef, Behavior, PostStop}
import akka.actor.typed.{
ActorRef,
Behavior,
BehaviorInterceptor,
PostStop,
Signal,
TypedActorContext
}
import akka.stream.Materializer
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.logging.LoggingContextOf.{label, newLoggingContext}
import com.daml.logging.{ContextualizedLogger}
import com.daml.logging.ContextualizedLogger
import spray.json._
import scala.util.control.Exception.Catcher
class InitializationHalted(s: String) extends Exception(s) {}
class InitializationException(s: String) extends Exception(s) {}
case class UnauthenticatedException(s: String) extends Exception(s) {}
object TriggerRunner {
private val logger = ContextualizedLogger.get(this.getClass)
@ -23,6 +33,7 @@ object TriggerRunner {
sealed trait Message
final case object Stop extends Message
final case class Status(replyTo: ActorRef[TriggerStatus]) extends Message
private final case class Unauthenticated(cause: UnauthenticatedException) extends Message
sealed trait TriggerStatus
final case object QueryingACS extends TriggerStatus
@ -43,6 +54,54 @@ object TriggerRunner {
}
}
// TODO[AH] Workaround for https://github.com/akka/akka/issues/29841.
// Remove once fixed upstream.
private class Interceptor(parent: ActorRef[TriggerRunner.Message])
extends BehaviorInterceptor[TriggerRunnerImpl.Message, TriggerRunnerImpl.Message] {
private def handleException(ctx: TypedActorContext[TriggerRunnerImpl.Message])
: Catcher[Behavior[TriggerRunnerImpl.Message]] = {
case e: InitializationHalted => {
// This should be a stop supervisor nested under the restart supervisor.
ctx.asScala.log.info(s"Supervisor saw failure ${e.getMessage} - stopping")
Behaviors.stopped
}
case e: UnauthenticatedException => {
// This should be a stop supervisor nested under the restart supervisor.
// The TriggerRunner should receive a ChildFailed signal when watching TriggerRunnerImpl.
// This cannot be emulated outside the akka-actor-typed implementation, so we use a dedicated message instead.
ctx.asScala.log.info(s"Supervisor saw failure ${e.getMessage} - stopping")
parent ! Unauthenticated(e)
Behaviors.stopped
}
}
override def aroundStart(
ctx: TypedActorContext[TriggerRunnerImpl.Message],
target: BehaviorInterceptor.PreStartTarget[TriggerRunnerImpl.Message])
: Behavior[TriggerRunnerImpl.Message] = {
try {
target.start(ctx)
} catch handleException(ctx)
}
override def aroundReceive(
ctx: TypedActorContext[TriggerRunnerImpl.Message],
msg: TriggerRunnerImpl.Message,
target: BehaviorInterceptor.ReceiveTarget[TriggerRunnerImpl.Message])
: Behavior[TriggerRunnerImpl.Message] = {
try {
target(ctx, msg)
} catch handleException(ctx)
}
override def aroundSignal(
ctx: TypedActorContext[TriggerRunnerImpl.Message],
signal: Signal,
target: BehaviorInterceptor.SignalTarget[TriggerRunnerImpl.Message])
: Behavior[TriggerRunnerImpl.Message] = {
try {
target(ctx, signal)
} catch handleException(ctx)
}
}
def apply(config: Config, name: String)(
implicit esf: ExecutionSequencerFactory,
mat: Materializer): Behavior[TriggerRunner.Message] =
@ -57,8 +116,7 @@ object TriggerRunner {
Behaviors
.supervise(
Behaviors
.supervise(TriggerRunnerImpl(config))
.onFailure[InitializationHalted](stop)
.intercept(() => new Interceptor(ctx.self))(TriggerRunnerImpl(config))
)
.onFailure(
restartWithBackoff(
@ -75,6 +133,14 @@ object TriggerRunner {
Behaviors.same
case Stop =>
Behaviors.stopped // Automatically stops the child actor if running.
case Unauthenticated(cause) =>
logger.warn(
s"Trigger was unauthenticated - requesting token refresh: ${cause.getMessage}")
config.server ! Server.TriggerTokenExpired(
config.triggerInstance,
config.trigger,
config.compiledPackages)
Behaviors.stopped
}
.receiveSignal {
case (_, PostStop) =>

View File

@ -25,6 +25,7 @@ import scalaz.syntax.tag._
import scala.concurrent.{ExecutionContext, Future}
import scala.util.{Failure, Success}
import TriggerRunner.{QueryingACS, Running, TriggerStatus}
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
object TriggerRunnerImpl {
@ -33,7 +34,8 @@ object TriggerRunnerImpl {
triggerInstance: UUID,
party: Party,
applicationId: ApplicationId,
token: Option[String],
accessToken: Option[AccessToken],
refreshToken: Option[RefreshToken],
compiledPackages: CompiledPackages,
trigger: Trigger,
ledgerConfig: LedgerConfig,
@ -69,7 +71,7 @@ object TriggerRunnerImpl {
commandClient = CommandClientConfiguration.default.copy(
defaultDeduplicationTime = config.ledgerConfig.commandTtl),
sslContext = None,
token = config.token,
token = AccessToken.unsubst(config.accessToken),
maxInboundMessageSize = config.ledgerConfig.maxInboundMessageSize,
)
@ -80,6 +82,9 @@ object TriggerRunnerImpl {
case Status(replyTo) =>
replyTo ! QueryingACS
Behaviors.same
case QueryACSFailed(cause: io.grpc.StatusRuntimeException)
if cause.getStatus == io.grpc.Status.UNAUTHENTICATED =>
throw new UnauthenticatedException(s"Querying ACS failed: ${cause.toString}")
case QueryACSFailed(cause) =>
// Report the failure to the server.
config.server ! Server.TriggerInitializationFailure(triggerInstance, cause.toString)
@ -138,6 +143,9 @@ object TriggerRunnerImpl {
case Status(replyTo) =>
replyTo ! Running
Behaviors.same
case Failed(cause: io.grpc.StatusRuntimeException)
if cause.getStatus == io.grpc.Status.UNAUTHENTICATED =>
throw new UnauthenticatedException(s"Querying ACS failed: ${cause.toString}")
case Failed(cause) =>
// Report the failure to the server.
config.server ! Server.TriggerRuntimeFailure(triggerInstance, cause.toString)

View File

@ -23,7 +23,7 @@ import doobie.{Fragment, Put, Transactor}
import scalaz.Tag
import java.io.{Closeable, IOException}
import com.daml.lf.engine.trigger.Tagged.AccessToken
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
import javax.sql.DataSource
import scala.concurrent.{ExecutionContext, Future}
@ -83,6 +83,10 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
implicit val accessTokenGet: Get[AccessToken] = Tag.subst(implicitly[Get[String]])
implicit val refreshTokenPut: Put[RefreshToken] = Tag.subst(implicitly[Put[String]])
implicit val refreshTokenGet: Get[RefreshToken] = Tag.subst(implicitly[Get[String]])
implicit val identifierPut: Put[Identifier] = implicitly[Put[String]].contramap(_.toString)
implicit val identifierGet: Get[Identifier] =
@ -105,24 +109,37 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
private def insertRunningTrigger(t: RunningTrigger): ConnectionIO[Unit] = {
val insert: Fragment = sql"""
insert into running_triggers
(trigger_instance, trigger_party, full_trigger_name, access_token, application_id)
(trigger_instance, trigger_party, full_trigger_name, access_token, refresh_token, application_id)
values
(${t.triggerInstance}, ${t.triggerParty}, ${t.triggerName}, ${t.triggerToken}, ${t.triggerApplicationId})
(${t.triggerInstance}, ${t.triggerParty}, ${t.triggerName}, ${t.triggerAccessToken}, ${t.triggerRefreshToken}, ${t.triggerApplicationId})
"""
insert.update.run.void
}
private def queryRunningTrigger(triggerInstance: UUID): ConnectionIO[Option[RunningTrigger]] = {
val select: Fragment = sql"""
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token from running_triggers
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token, refresh_token from running_triggers
where trigger_instance = $triggerInstance
"""
select
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken])]
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken], Option[RefreshToken])]
.map(RunningTrigger.tupled)
.option
}
private def setRunningTriggerToken(
triggerInstance: UUID,
accessToken: AccessToken,
refreshToken: Option[RefreshToken]) = {
val update: Fragment =
sql"""
update running_triggers
set access_token = $accessToken, refresh_token = $refreshToken
where trigger_instance = $triggerInstance
"""
update.update.run.void
}
// trigger_instance is the primary key on running_triggers so this deletes
// at most one row. Return whether or not it deleted.
private def deleteRunningTrigger(triggerInstance: UUID): ConnectionIO[Boolean] = {
@ -170,10 +187,10 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
private def selectAllTriggers: ConnectionIO[Vector[RunningTrigger]] = {
val select: Fragment = sql"""
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token from running_triggers order by trigger_instance
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token, refresh_token from running_triggers order by trigger_instance
"""
select
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken])]
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken], Option[RefreshToken])]
.map(RunningTrigger.tupled)
.to[Vector]
}
@ -206,6 +223,12 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
implicit ec: ExecutionContext): Future[Option[RunningTrigger]] =
run(queryRunningTrigger(triggerInstance))
override def updateRunningTriggerToken(
triggerInstance: UUID,
accessToken: AccessToken,
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit] =
run(setRunningTriggerToken(triggerInstance, accessToken, refreshToken))
override def removeRunningTrigger(triggerInstance: UUID)(
implicit ec: ExecutionContext): Future[Boolean] =
run(deleteRunningTrigger(triggerInstance))

View File

@ -10,6 +10,7 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
import com.daml.lf.archive.Dar
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.engine.trigger.RunningTrigger
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
import scala.concurrent.{ExecutionContext, Future}
@ -29,6 +30,18 @@ final class InMemoryTriggerDao extends RunningTriggerDao {
triggers.get(triggerInstance)
}
override def updateRunningTriggerToken(
triggerInstance: UUID,
accessToken: AccessToken,
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit] = Future {
triggers.get(triggerInstance) match {
case Some(t) =>
triggers += (triggerInstance -> t
.copy(triggerAccessToken = Some(accessToken), triggerRefreshToken = refreshToken))
case None => ()
}
}
override def removeRunningTrigger(triggerInstance: UUID)(
implicit ec: ExecutionContext): Future[Boolean] = Future {
triggers.get(triggerInstance) match {

View File

@ -11,6 +11,7 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
import com.daml.lf.archive.Dar
import com.daml.lf.data.Ref.PackageId
import com.daml.lf.engine.trigger.RunningTrigger
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
import scala.concurrent.{ExecutionContext, Future}
@ -18,6 +19,10 @@ trait RunningTriggerDao extends Closeable {
def addRunningTrigger(t: RunningTrigger)(implicit ec: ExecutionContext): Future[Unit]
def getRunningTrigger(triggerInstance: UUID)(
implicit ec: ExecutionContext): Future[Option[RunningTrigger]]
def updateRunningTriggerToken(
triggerInstance: UUID,
accessToken: AccessToken,
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit]
def removeRunningTrigger(triggerInstance: UUID)(implicit ec: ExecutionContext): Future[Boolean]
def listRunningTriggers(party: Party)(implicit ec: ExecutionContext): Future[Vector[UUID]]
def persistPackages(dar: Dar[(PackageId, DamlLf.ArchivePayload)])(

View File

@ -24,6 +24,10 @@ package trigger {
sealed trait AccessTokenTag
type AccessToken = String @@ AccessTokenTag
val AccessToken = Tag.of[AccessTokenTag]
sealed trait RefreshTokenTag
type RefreshToken = String @@ RefreshTokenTag
val RefreshToken = Tag.of[RefreshTokenTag]
}
import Tagged._
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
@ -47,6 +51,7 @@ package trigger {
triggerName: Identifier,
triggerParty: Party,
triggerApplicationId: ApplicationId,
triggerToken: Option[AccessToken],
triggerAccessToken: Option[AccessToken],
triggerRefreshToken: Option[RefreshToken],
)
}

View File

@ -11,6 +11,9 @@
<logger name="io.grpc.netty" level="WARN">
<appender-ref ref="stderr-appender"/>
</logger>
<logger name="com.daml.lf.engine.trigger" level="DEBUG">
<appender-ref ref="stderr-appender"/>
</logger>
<root level="INFO">
<appender-ref ref="STDOUT" />
</root>

View File

@ -5,8 +5,8 @@ package com.daml.lf.engine.trigger
import java.io.File
import java.net.InetAddress
import java.time.LocalDateTime
import java.util.UUID
import java.time.{Clock, Duration => JDuration, Instant, LocalDateTime, ZoneId}
import java.util.{Date, UUID}
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
import io.grpc.Channel
@ -21,10 +21,15 @@ import akka.http.scaladsl.Http
import akka.http.scaladsl.Http.ServerBinding
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes, Uri, headers}
import akka.http.scaladsl.model.Uri.Path
import com.auth0.jwt.JWT
import com.auth0.jwt.JWTVerifier.BaseVerification
import com.auth0.jwt.algorithms.Algorithm
import com.auth0.jwt.interfaces.{Clock => Auth0Clock}
import com.daml.bazeltools.BazelRunfiles
import com.daml.clock.AdjustableClock
import com.daml.daml_lf_dev.DamlLf
import com.daml.jwt.domain.DecodedJwt
import com.daml.jwt.{HMAC256Verifier, JwtSigner, JwtVerifierBase}
import com.daml.jwt.{JwtSigner, JwtVerifier, JwtVerifierBase}
import com.daml.ledger.api.auth
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
import com.daml.ledger.api.domain.LedgerId
@ -158,17 +163,26 @@ trait AuthMiddlewareFixture
jwt.value
}
protected def authConfig: AuthConfig = AuthMiddleware(authMiddlewareUri)
protected def authServer: OAuthServer = resource.value._1
protected def authClock: AdjustableClock = resource.value._1
protected def authServer: OAuthServer = resource.value._2
private def authVerifier: JwtVerifierBase = HMAC256Verifier(authSecret).toOption.get
private def authMiddleware: ServerBinding = resource.value._2
private def authVerifier: JwtVerifierBase = new JwtVerifier(
JWT
.require(Algorithm.HMAC256(authSecret))
.asInstanceOf[BaseVerification]
.build(new Auth0Clock {
override def getToday: Date = Date.from(authClock.instant())
})
)
private def authMiddleware: ServerBinding = resource.value._3
private def authMiddlewareUri: Uri =
Uri()
.withScheme("http")
.withAuthority(authMiddleware.localAddress.getHostString, authMiddleware.localAddress.getPort)
private val authSecret: String = "secret"
private var resource: OwnedResource[ResourceContext, (OAuthServer, ServerBinding)] = null
private var resource
: OwnedResource[ResourceContext, (AdjustableClock, OAuthServer, ServerBinding)] = null
override protected def beforeAll(): Unit = {
super.beforeAll()
@ -178,18 +192,23 @@ trait AuthMiddlewareFixture
for {
_ <- binding.unbind()
} yield ()
val oauthConfig = OAuthConfig(
port = Port.Dynamic,
ledgerId = this.getClass.getSimpleName,
jwtSecret = authSecret,
parties = authParties,
clock = None,
)
resource = new OwnedResource(new ResourceOwner[(OAuthServer, ServerBinding)] {
override def acquire()(
implicit context: ResourceContext): Resource[(OAuthServer, ServerBinding)] = {
val oauthServer = OAuthServer(oauthConfig)
val ledgerId = this.getClass.getSimpleName
resource = new OwnedResource(new ResourceOwner[(AdjustableClock, OAuthServer, ServerBinding)] {
override def acquire()(implicit context: ResourceContext)
: Resource[(AdjustableClock, OAuthServer, ServerBinding)] = {
for {
clock <- Resource(
Future(
AdjustableClock(Clock.fixed(Instant.now(), ZoneId.systemDefault()), JDuration.ZERO)))(
_ => Future(()))
oauthConfig = OAuthConfig(
port = Port.Dynamic,
ledgerId = ledgerId,
jwtSecret = authSecret,
parties = authParties,
clock = Some(clock),
)
oauthServer = OAuthServer(oauthConfig)
oauth <- Resource(oauthServer.start())(closeServerBinding)
uri = Uri()
.withScheme("http")
@ -203,7 +222,7 @@ trait AuthMiddlewareFixture
tokenVerifier = authVerifier,
)
middleware <- Resource(MiddlewareServer.start(middlewareConfig))(closeServerBinding)
} yield (oauthServer, middleware)
} yield (clock, oauthServer, middleware)
}
})
resource.setup()

View File

@ -10,6 +10,7 @@ import akka.http.scaladsl.model._
import akka.util.ByteString
import akka.stream.scaladsl.{FileIO, Sink, Source}
import java.io.File
import java.time.{Duration => JDuration}
import java.util.UUID
import akka.http.scaladsl.model.Uri.Query
@ -24,6 +25,7 @@ import com.daml.bazeltools.BazelRunfiles.requiredResource
import com.daml.ledger.api.refinements.ApiTypes
import com.daml.ledger.api.v1.commands._
import com.daml.ledger.api.v1.command_service._
import com.daml.ledger.api.v1.event.CreatedEvent
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset.LedgerBoundary.LEDGER_BEGIN
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset.Value.Boundary
@ -58,7 +60,7 @@ trait AbstractTriggerServiceTest
protected val testPkgId = dar.main._1
override protected val damlPackages: List[File] = List(darPath)
private def submitCmd(client: LedgerClient, party: String, cmd: Command) = {
protected def submitCmd(client: LedgerClient, party: String, cmd: Command) = {
val req = SubmitAndWaitRequest(
Some(
Commands(
@ -75,12 +77,12 @@ trait AbstractTriggerServiceTest
protected override def actorSystemName = testId
protected val alice: Party = Tag("Alice")
// This party is used by the test that queries the ACS.
// To avoid mixing this up with the other tests, we use a separate
// party.
protected val aliceAcs: Party = Tag("Alice_acs")
protected val bob: Party = Tag("Bob")
protected val eve: Party = Tag("Eve")
// These parties are used by tests that query the ACS.
// To avoid mixing this up with the other tests, we use a separate party.
protected val aliceAcs: Party = Tag("Alice_acs")
protected val aliceExp: Party = Tag("Alice_exp")
def startTrigger(
uri: Uri,
@ -173,6 +175,18 @@ trait AbstractTriggerServiceTest
} yield triggerIds
}
def getActiveContracts(
client: LedgerClient,
party: Party,
template: Identifier): Future[Seq[CreatedEvent]] = {
val filter = TransactionFilter(
Map(party.unwrap -> Filters(Some(InclusiveFilters(Seq(template))))))
client.activeContractSetClient
.getActiveContracts(filter)
.runWith(Sink.seq)
.map(acsPages => acsPages.flatMap(_.activeContracts))
}
def assertTriggerIds(uri: Uri, party: Party, expected: Vector[UUID]): Future[Assertion] =
for {
resp <- listTriggers(uri, party)
@ -266,18 +280,10 @@ trait AbstractTriggerServiceTest
client <- sandboxClient(
ApiTypes.ApplicationId("my-app-id"),
actAs = List(ApiTypes.Party(aliceAcs.unwrap)))
filter = TransactionFilter(
List(
(
aliceAcs.unwrap,
Filters(Some(InclusiveFilters(Seq(Identifier(testPkgId, "TestTrigger", "B"))))))).toMap)
// Make sure that no contracts exist initially to guard against accidental
// party reuse.
acs <- client.activeContractSetClient
.getActiveContracts(filter)
.runWith(Sink.seq)
.map(acsPages => acsPages.flatMap(_.activeContracts))
_ = acs shouldBe Vector()
_ <- getActiveContracts(client, aliceAcs, Identifier(testPkgId, "TestTrigger", "B"))
.map(_ shouldBe Vector())
// Start the trigger
resp <- startTrigger(
uri,
@ -302,12 +308,8 @@ trait AbstractTriggerServiceTest
}
// Query ACS until we see a B contract
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
for {
acs <- client.activeContractSetClient
.getActiveContracts(filter)
.runWith(Sink.seq)
.map(acsPages => acsPages.flatMap(_.activeContracts))
} yield assert(acs.length == 1)
getActiveContracts(client, aliceAcs, Identifier(testPkgId, "TestTrigger", "B"))
.map(_.length shouldBe 1)
}
// Read completions to make sure we set the right app id.
r <- client.commandClient
@ -512,7 +514,7 @@ trait AbstractTriggerServiceTestAuthMiddleware
extends AbstractTriggerServiceTest
with AuthMiddlewareFixture {
override protected val authParties = Some(Set(alice, aliceAcs, bob))
override protected val authParties = Some(Set(alice, aliceAcs, aliceExp, bob))
behavior of "authenticated service"
@ -559,4 +561,93 @@ trait AbstractTriggerServiceTestAuthMiddleware
_ <- resp.status shouldBe StatusCodes.Forbidden
} yield succeed
}
it should "request a fresh token after expiry on user request" in withTriggerService(Nil) {
uri: Uri =>
for {
resp <- listTriggers(uri, alice)
_ <- resp.status shouldBe StatusCodes.OK
// Expire old token and test the trigger service transparently requests a new token.
_ = authClock.fastForward(
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
resp <- listTriggers(uri, alice)
_ <- resp.status shouldBe StatusCodes.OK
} yield succeed
}
it should "refresh a token after expiry on the server side" in withTriggerService(List(dar)) {
uri: Uri =>
for {
client <- sandboxClient(
ApiTypes.ApplicationId("exp-app-id"),
actAs = List(ApiTypes.Party(aliceExp.unwrap)))
// Make sure that no contracts exist initially to guard against accidental
// party reuse.
_ <- getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
.map(_ shouldBe Vector())
// Start the trigger
resp <- startTrigger(
uri,
s"$testPkgId:TestTrigger:trigger",
aliceExp,
Some(ApplicationId("exp-app-id")))
triggerId <- parseTriggerId(resp)
// Expire old token and test that the trigger service requests a new token during trigger start-up.
// TODO[AH] Here we want to test token expiry during QueryingACS.
// For now the test relies on timing. Find a way to enforce expiry during QueryingACS.
_ = authClock.fastForward(
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
// Trigger is running, create an A contract
createACommand = { v: Long =>
Command().withCreate(
CreateCommand(
templateId = Some(Identifier(testPkgId, "TestTrigger", "A")),
createArguments = Some(
Record(
None,
Seq(
RecordField(value = Some(Value().withParty(aliceExp.unwrap))),
RecordField(value = Some(Value().withInt64(v))))))
))
}
_ <- submitCmd(client, aliceExp.unwrap, createACommand(7))
// Query ACS until we see a B contract
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
.map(_.length shouldBe 1)
}
// Expire old token and test that the trigger service requests a new token during running trigger.
_ = authClock.fastForward(
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
// Create another A contract
_ <- submitCmd(client, aliceExp.unwrap, createACommand(42))
// Query ACS until we see a second B contract
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
.map(_.length shouldBe 2)
}
// Read completions to make sure we set the right app id.
r <- client.commandClient
.completionSource(List(aliceExp.unwrap), LedgerOffset(Boundary(LEDGER_BEGIN)))
.collect({
case CompletionStreamElement.CompletionElement(completion)
if !completion.transactionId.isEmpty =>
completion
})
.take(1)
.runWith(Sink.seq)
_ = r.length shouldBe 1
status <- triggerStatus(uri, triggerId)
_ = status.status shouldBe StatusCodes.OK
body <- responseBodyToString(status)
_ = body shouldBe s"""{"result":{"party":"Alice_exp","status":"running","triggerId":"$testPkgId:TestTrigger:trigger"},"status":200}"""
resp <- stopTrigger(uri, triggerId, aliceExp)
_ <- assert(resp.status.isSuccess)
} yield succeed
}
}