mirror of
https://github.com/digital-asset/daml.git
synced 2024-11-08 21:34:22 +03:00
Handle token expiry in trigger service (#8037)
* Enable adjustable clock in trigger service tests changelog_begin changelog_end * Test user side token expiry * Test service side token refresh * Use AccessToken wrapper in TriggerRunnerImpl * Store refresh token in trigger DB * add refresh token to trigger runner config * TriggerTokenExpired message to server * TriggerTokenRefresh message to server * refresh trigger token and update db * Restart trigger with fresh token * Test second token expiry * Refresh token on running trigger changelog_begin * [Triggers] UNAUTHENTICATED errors will now terminate the trigger. These errors are no longer available for handling in the trigger DAML code. Instead, they are forwarded to the trigger service for handling, e.g. access token refresh. changelog_end * todo note * Move triggerRunnerName and getRunner into object * Factor out token refresh * Factor out getActiveContracts * factor out create command * Add logging to token refresh * Handle token expiry in TriggerRunner TriggerRunnerImpl throws a dedicated exception when it fails on an expired access token (any unauthenticated error to be precise). The TriggerRunner supervisor reacts to this child failure by requesting a token refresh and restart on the trigger server and stopping itself. The trigger server requests a new access and refresh token on the auth middleware and restarts the trigger. This works around an issue with actor supervisors in akka-actor-typed. A stop supervisor wrapped within a restart supervisor will not cause a stop as expected. Instead, the restart supervisor will trigger as well and restart the actor. The work around uses a custom behavior interceptor to emulate the appropriate stop supervisors as closely as possible. We cannot properly emulate ChildFailed signals this way, so we use dedicated messages intead. * throw --> Future.failedo * getOrFail helper Co-authored-by: Andreas Herrmann <andreas.herrmann@tweag.io>
This commit is contained in:
parent
6b7f714c4d
commit
8bceeb13de
@ -535,7 +535,9 @@ class Runner(
|
||||
val f: Future[Empty] = client.commandClient
|
||||
.submitSingleCommand(req)
|
||||
f.map(_ => None).recover {
|
||||
case s: StatusRuntimeException =>
|
||||
case s: StatusRuntimeException if s.getStatus != io.grpc.Status.UNAUTHENTICATED =>
|
||||
// Do not capture UNAUTHENTICATED errors.
|
||||
// The access token may be expired, let the trigger runner handle token refresh.
|
||||
Some(SingleCommandFailure(req.getCommands.commandId, s))
|
||||
// any other error will cause the trigger's stream to fail
|
||||
}
|
||||
|
@ -117,6 +117,7 @@ da_scala_library(
|
||||
"//ledger/participant-state",
|
||||
"//ledger/sandbox-classic",
|
||||
"//ledger/sandbox-common",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/postgresql-testing",
|
||||
"//libs-scala/resources",
|
||||
@ -125,6 +126,7 @@ da_scala_library(
|
||||
"//triggers/service/auth:oauth-middleware",
|
||||
"//triggers/service/auth:oauth-test-server",
|
||||
"@maven//:ch_qos_logback_logback_classic",
|
||||
"@maven//:com_auth0_java_jwt",
|
||||
"@maven//:com_typesafe_akka_akka_actor_typed_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_http_core_2_12",
|
||||
"@maven//:com_typesafe_akka_akka_parsing_2_12",
|
||||
@ -181,6 +183,7 @@ da_scala_test_suite(
|
||||
"//ledger/ledger-resources",
|
||||
"//ledger/sandbox-classic",
|
||||
"//ledger/sandbox-common",
|
||||
"//libs-scala/adjustable-clock",
|
||||
"//libs-scala/flyway-testing",
|
||||
"//libs-scala/ports",
|
||||
"//libs-scala/postgresql-testing",
|
||||
|
@ -32,7 +32,7 @@ import scala.util.{Failure, Success, Try}
|
||||
// request to /authorize are immediately redirected to the redirect_uri.
|
||||
class Server(config: Config) {
|
||||
private val jwtHeader = """{"alg": "HS256", "typ": "JWT"}"""
|
||||
private val tokenLifetimeSeconds = 24 * 60 * 60
|
||||
val tokenLifetimeSeconds = 24 * 60 * 60
|
||||
|
||||
// None indicates that all parties are authorized, Some that only the given set of parties is authorized.
|
||||
private var authorizedParties: Option[Set[Party]] = config.parties
|
||||
|
@ -0,0 +1,5 @@
|
||||
-- Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
|
||||
-- SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
-- Add refresh token to running trigger table
|
||||
alter table running_triggers add column refresh_token text;
|
@ -0,0 +1 @@
|
||||
b6b0f69413a5977f86066c7ed2b0e11c3151f17e147f2554f320d7fa6fa1c338
|
@ -16,6 +16,7 @@ import akka.actor.typed.{ActorRef, Behavior, Scheduler}
|
||||
import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.Http.ServerBinding
|
||||
import akka.http.scaladsl.marshallers.sprayjson.SprayJsonSupport._
|
||||
import akka.http.scaladsl.marshalling.Marshal
|
||||
import akka.http.scaladsl.model.Uri.Path
|
||||
import akka.http.scaladsl.model._
|
||||
import akka.http.scaladsl.server.Directives._
|
||||
@ -36,7 +37,7 @@ import com.daml.lf.data.Ref.{Identifier, PackageId}
|
||||
import com.daml.lf.engine._
|
||||
import com.daml.lf.engine.trigger.Request.StartParams
|
||||
import com.daml.lf.engine.trigger.Response._
|
||||
import com.daml.lf.engine.trigger.Tagged.AccessToken
|
||||
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
|
||||
import com.daml.lf.engine.trigger.TriggerRunner._
|
||||
import com.daml.lf.engine.trigger.dao._
|
||||
import com.daml.oauth.middleware.Request.Claims
|
||||
@ -146,10 +147,16 @@ class Server(
|
||||
// Note that this does not yet start the trigger.
|
||||
private def addNewTrigger(
|
||||
config: TriggerConfig,
|
||||
token: Option[AccessToken]
|
||||
auth: Option[Authorization],
|
||||
)(implicit ec: ExecutionContext): Future[Either[String, (Trigger, RunningTrigger)]] = {
|
||||
val runningTrigger =
|
||||
RunningTrigger(config.instance, config.name, config.party, config.applicationId, token)
|
||||
RunningTrigger(
|
||||
config.instance,
|
||||
config.name,
|
||||
config.party,
|
||||
config.applicationId,
|
||||
auth.map(_.accessToken),
|
||||
auth.flatMap(_.refreshToken))
|
||||
// Validate trigger id before persisting to DB
|
||||
Trigger.fromIdentifier(compiledPackages, runningTrigger.triggerName) match {
|
||||
case Left(value) => Future.successful(Left(value))
|
||||
@ -216,6 +223,8 @@ class Server(
|
||||
// TODO[AH] Make sure this is bounded in size.
|
||||
private val authCallbacks: TrieMap[UUID, Route] = TrieMap()
|
||||
|
||||
case class Authorization(accessToken: AccessToken, refreshToken: Option[RefreshToken])
|
||||
|
||||
// This directive requires authorization for the given claims via the auth middleware, if configured.
|
||||
// If no auth middleware is configured, then the request will proceed without attempting authorization.
|
||||
//
|
||||
@ -225,14 +234,14 @@ class Server(
|
||||
// to proceed once the login flow completed and authentication succeeded.
|
||||
private def authorize(claims: AuthRequest.Claims)(
|
||||
implicit ec: ExecutionContext,
|
||||
system: ActorSystem): Directive1[Option[AccessToken]] =
|
||||
system: ActorSystem): Directive1[Option[Authorization]] =
|
||||
authConfig match {
|
||||
case NoAuth => provide(None)
|
||||
case AuthMiddleware(authUri) =>
|
||||
// Attempt to obtain the access token from the middleware's /auth endpoint.
|
||||
// Forwards the current request's cookies.
|
||||
// Fails if the response is not OK or Unauthorized.
|
||||
def auth: Directive1[Option[AccessToken]] = {
|
||||
def auth: Directive1[Option[Authorization]] = {
|
||||
val uri = authUri
|
||||
.withPath(Path./("auth"))
|
||||
.withQuery(AuthRequest.Auth(claims).toQuery)
|
||||
@ -241,7 +250,10 @@ class Server(
|
||||
onSuccess(Http().singleRequest(HttpRequest(uri = uri, headers = cookies))).flatMap {
|
||||
case HttpResponse(StatusCodes.OK, _, entity, _) =>
|
||||
onSuccess(Unmarshal(entity).to[AuthResponse.Authorize]).map { auth =>
|
||||
Some(AccessToken(auth.accessToken)): Option[AccessToken]
|
||||
Some(
|
||||
Authorization(
|
||||
AccessToken(auth.accessToken),
|
||||
RefreshToken.subst(auth.refreshToken))): Option[Authorization]
|
||||
}
|
||||
case HttpResponse(StatusCodes.Unauthorized, _, _, _) =>
|
||||
provide(None)
|
||||
@ -259,7 +271,7 @@ class Server(
|
||||
Directive { inner =>
|
||||
auth {
|
||||
// Authorization successful - pass token to continuation
|
||||
case Some(token) => inner(Tuple1(Some(token)))
|
||||
case Some(authorization) => inner(Tuple1(Some(authorization)))
|
||||
// Authorization failed - login and retry on callback request.
|
||||
case None => { ctx =>
|
||||
val requestId = UUID.randomUUID()
|
||||
@ -271,10 +283,10 @@ class Server(
|
||||
// TODO[AH] Add WWW-Authenticate header
|
||||
complete(errorResponse(StatusCodes.Unauthorized))
|
||||
}
|
||||
case Some(token) =>
|
||||
case Some(authorization) =>
|
||||
// Authorization successful after login - use old request context and pass token to continuation.
|
||||
mapRequestContext(_ => ctx) {
|
||||
inner(Tuple1(Some(token)))
|
||||
inner(Tuple1(Some(authorization)))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -298,7 +310,7 @@ class Server(
|
||||
// If the trigger does not exist, then the request will also proceed without attempting authorization.
|
||||
private def authorizeForTrigger(uuid: UUID, readOnly: Boolean = false)(
|
||||
implicit ec: ExecutionContext,
|
||||
system: ActorSystem): Directive1[Option[AccessToken]] =
|
||||
system: ActorSystem): Directive1[Option[Authorization]] =
|
||||
authConfig match {
|
||||
case NoAuth => provide(None)
|
||||
case AuthMiddleware(_) =>
|
||||
@ -349,9 +361,9 @@ class Server(
|
||||
applicationId = Some(config.applicationId))
|
||||
// TODO[AH] Why do we need to pass ec, system explicitly?
|
||||
authorize(claims)(ec, system) {
|
||||
token =>
|
||||
auth =>
|
||||
val instOrErr: Future[Either[String, JsValue]] =
|
||||
addNewTrigger(config, token)
|
||||
addNewTrigger(config, auth)
|
||||
.flatMap {
|
||||
case Left(value) => Future.successful(Left(value))
|
||||
case Right((trigger, runningTrigger)) =>
|
||||
@ -474,9 +486,26 @@ object Server {
|
||||
replyTo: ActorRef[StatusReply[Unit]])
|
||||
extends Message
|
||||
|
||||
final case class RestartTrigger(
|
||||
trigger: Trigger,
|
||||
runningTrigger: RunningTrigger,
|
||||
compiledPackages: CompiledPackages)
|
||||
extends Message
|
||||
|
||||
final case class GetRunner(replyTo: ActorRef[Option[ActorRef[TriggerRunner.Message]]], uuid: UUID)
|
||||
extends Message
|
||||
|
||||
final case class TriggerTokenRefreshFailed(triggerInstance: UUID, cause: Throwable)
|
||||
extends Message
|
||||
|
||||
// Messages passed to the server from a TriggerRunner
|
||||
|
||||
final case class TriggerTokenExpired(
|
||||
triggerInstance: UUID,
|
||||
trigger: Trigger,
|
||||
compiledPackages: CompiledPackages)
|
||||
extends Message
|
||||
|
||||
// Messages passed to the server from a TriggerRunnerImpl
|
||||
|
||||
final case class TriggerStarting(triggerInstance: UUID) extends Message
|
||||
@ -540,37 +569,94 @@ object Server {
|
||||
def logTriggerStarted(m: TriggerStarted): Unit =
|
||||
server.logTriggerStatus(m.triggerInstance, "running")
|
||||
|
||||
def startTrigger(req: StartTrigger): Unit = {
|
||||
val runningTrigger = req.runningTrigger
|
||||
Try(
|
||||
ctx.spawn(
|
||||
TriggerRunner(
|
||||
new TriggerRunner.Config(
|
||||
ctx.self,
|
||||
runningTrigger.triggerInstance,
|
||||
runningTrigger.triggerParty,
|
||||
runningTrigger.triggerApplicationId,
|
||||
AccessToken.unsubst(runningTrigger.triggerToken),
|
||||
req.compiledPackages,
|
||||
req.trigger,
|
||||
ledgerConfig,
|
||||
restartConfig
|
||||
),
|
||||
runningTrigger.triggerInstance.toString
|
||||
def spawnTrigger(
|
||||
trigger: Trigger,
|
||||
runningTrigger: RunningTrigger,
|
||||
compiledPackages: CompiledPackages): ActorRef[TriggerRunner.Message] =
|
||||
ctx.spawn(
|
||||
TriggerRunner(
|
||||
new TriggerRunner.Config(
|
||||
ctx.self,
|
||||
runningTrigger.triggerInstance,
|
||||
runningTrigger.triggerParty,
|
||||
runningTrigger.triggerApplicationId,
|
||||
runningTrigger.triggerAccessToken,
|
||||
runningTrigger.triggerRefreshToken,
|
||||
compiledPackages,
|
||||
trigger,
|
||||
ledgerConfig,
|
||||
restartConfig
|
||||
),
|
||||
triggerRunnerName(runningTrigger.triggerInstance)
|
||||
)) match {
|
||||
runningTrigger.triggerInstance.toString
|
||||
),
|
||||
triggerRunnerName(runningTrigger.triggerInstance)
|
||||
)
|
||||
|
||||
def startTrigger(req: StartTrigger): Unit = {
|
||||
Try(spawnTrigger(req.trigger, req.runningTrigger, req.compiledPackages)) match {
|
||||
case Failure(exception) => req.replyTo ! StatusReply.error(exception)
|
||||
case Success(_) => req.replyTo ! StatusReply.success(())
|
||||
}
|
||||
}
|
||||
|
||||
def restartTrigger(req: RestartTrigger): Unit = {
|
||||
val _ = spawnTrigger(req.trigger, req.runningTrigger, req.compiledPackages)
|
||||
}
|
||||
|
||||
def getRunner(req: GetRunner) = {
|
||||
req.replyTo ! ctx
|
||||
.child(triggerRunnerName(req.uuid))
|
||||
.asInstanceOf[Option[ActorRef[TriggerRunner.Message]]]
|
||||
}
|
||||
|
||||
def refreshAccessToken(triggerInstance: UUID): Future[RunningTrigger] = {
|
||||
def getOrFail[T](result: Option[T], ex: => Throwable): Future[T] = result match {
|
||||
case Some(value) => Future.successful(value)
|
||||
case None => Future.failed(ex)
|
||||
}
|
||||
|
||||
for {
|
||||
// Lookup running trigger
|
||||
runningTrigger <- dao
|
||||
.getRunningTrigger(triggerInstance)
|
||||
.flatMap(getOrFail(_, new RuntimeException(s"Unknown trigger $triggerInstance")))
|
||||
// Request a token refresh
|
||||
authUri <- getOrFail(authConfig match {
|
||||
case AuthMiddleware(uri) => Some(uri)
|
||||
case _ => None
|
||||
}, new RuntimeException("Cannot refresh token without authorization service"))
|
||||
refreshToken <- getOrFail(
|
||||
runningTrigger.triggerRefreshToken,
|
||||
new RuntimeException(s"No refresh token for $triggerInstance"))
|
||||
requestEntity <- {
|
||||
import AuthJsonProtocol._
|
||||
Marshal(AuthRequest.Refresh(RefreshToken.unwrap(refreshToken)))
|
||||
.to[RequestEntity]
|
||||
}
|
||||
response <- Http().singleRequest(
|
||||
HttpRequest(
|
||||
method = HttpMethods.POST,
|
||||
uri = authUri.withPath(Path./("refresh")),
|
||||
entity = requestEntity,
|
||||
))
|
||||
authorize <- response.status match {
|
||||
case StatusCodes.OK =>
|
||||
import AuthJsonProtocol._
|
||||
Unmarshal(response.entity).to[AuthResponse.Authorize]
|
||||
case status =>
|
||||
Unmarshal(response).to[String].flatMap { msg =>
|
||||
Future.failed(new RuntimeException(s"Failed to refresh token ($status): $msg"))
|
||||
}
|
||||
}
|
||||
// Update the tokens in the trigger db
|
||||
accessToken = AccessToken(authorize.accessToken)
|
||||
refreshToken = RefreshToken.subst(authorize.refreshToken)
|
||||
_ <- dao.updateRunningTriggerToken(triggerInstance, accessToken, refreshToken)
|
||||
} yield
|
||||
runningTrigger
|
||||
.copy(triggerAccessToken = Some(accessToken), triggerRefreshToken = refreshToken)
|
||||
}
|
||||
|
||||
// The server running state.
|
||||
def running(binding: ServerBinding): Behavior[Message] =
|
||||
Behaviors
|
||||
@ -578,6 +664,9 @@ object Server {
|
||||
case req: StartTrigger =>
|
||||
startTrigger(req)
|
||||
Behaviors.same
|
||||
case req: RestartTrigger =>
|
||||
restartTrigger(req)
|
||||
Behaviors.same
|
||||
case req: GetRunner =>
|
||||
getRunner(req)
|
||||
Behaviors.same
|
||||
@ -607,6 +696,18 @@ object Server {
|
||||
// the management of a supervision strategy).
|
||||
Behaviors.same
|
||||
|
||||
case TriggerTokenExpired(triggerInstance, trigger, compiledPackages) =>
|
||||
ctx.log.info(s"Updating token for $triggerInstance")
|
||||
ctx.pipeToSelf(refreshAccessToken(triggerInstance)) {
|
||||
case Success(runningTrigger) =>
|
||||
RestartTrigger(trigger, runningTrigger, compiledPackages)
|
||||
case Failure(cause) => TriggerTokenRefreshFailed(triggerInstance, cause)
|
||||
}
|
||||
Behaviors.same
|
||||
case TriggerTokenRefreshFailed(triggerInstance, cause) =>
|
||||
server.logTriggerStatus(triggerInstance, s"stopped: failed to refresh token: $cause")
|
||||
Behaviors.same
|
||||
|
||||
case GetServerBinding(replyTo) =>
|
||||
replyTo ! binding
|
||||
Behaviors.same
|
||||
@ -664,12 +765,18 @@ object Server {
|
||||
case req: StartTrigger =>
|
||||
startTrigger(req)
|
||||
Behaviors.same
|
||||
case req: RestartTrigger =>
|
||||
restartTrigger(req)
|
||||
Behaviors.same
|
||||
case req: GetRunner =>
|
||||
getRunner(req)
|
||||
Behaviors.same
|
||||
|
||||
case _: TriggerInitializationFailure | _: TriggerRuntimeFailure =>
|
||||
Behaviors.unhandled
|
||||
|
||||
case _: TriggerTokenExpired | _: TriggerTokenRefreshFailed =>
|
||||
Behaviors.unhandled
|
||||
}
|
||||
|
||||
// The server binding is a future that on completion will be piped
|
||||
|
@ -3,17 +3,27 @@
|
||||
|
||||
package com.daml.lf.engine.trigger
|
||||
|
||||
import akka.actor.typed.SupervisorStrategy._
|
||||
import akka.actor.typed.SupervisorStrategy.restartWithBackoff
|
||||
import akka.actor.typed.scaladsl.Behaviors
|
||||
import akka.actor.typed.{ActorRef, Behavior, PostStop}
|
||||
import akka.actor.typed.{
|
||||
ActorRef,
|
||||
Behavior,
|
||||
BehaviorInterceptor,
|
||||
PostStop,
|
||||
Signal,
|
||||
TypedActorContext
|
||||
}
|
||||
import akka.stream.Materializer
|
||||
import com.daml.grpc.adapter.ExecutionSequencerFactory
|
||||
import com.daml.logging.LoggingContextOf.{label, newLoggingContext}
|
||||
import com.daml.logging.{ContextualizedLogger}
|
||||
import com.daml.logging.ContextualizedLogger
|
||||
import spray.json._
|
||||
|
||||
import scala.util.control.Exception.Catcher
|
||||
|
||||
class InitializationHalted(s: String) extends Exception(s) {}
|
||||
class InitializationException(s: String) extends Exception(s) {}
|
||||
case class UnauthenticatedException(s: String) extends Exception(s) {}
|
||||
|
||||
object TriggerRunner {
|
||||
private val logger = ContextualizedLogger.get(this.getClass)
|
||||
@ -23,6 +33,7 @@ object TriggerRunner {
|
||||
sealed trait Message
|
||||
final case object Stop extends Message
|
||||
final case class Status(replyTo: ActorRef[TriggerStatus]) extends Message
|
||||
private final case class Unauthenticated(cause: UnauthenticatedException) extends Message
|
||||
|
||||
sealed trait TriggerStatus
|
||||
final case object QueryingACS extends TriggerStatus
|
||||
@ -43,6 +54,54 @@ object TriggerRunner {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO[AH] Workaround for https://github.com/akka/akka/issues/29841.
|
||||
// Remove once fixed upstream.
|
||||
private class Interceptor(parent: ActorRef[TriggerRunner.Message])
|
||||
extends BehaviorInterceptor[TriggerRunnerImpl.Message, TriggerRunnerImpl.Message] {
|
||||
private def handleException(ctx: TypedActorContext[TriggerRunnerImpl.Message])
|
||||
: Catcher[Behavior[TriggerRunnerImpl.Message]] = {
|
||||
case e: InitializationHalted => {
|
||||
// This should be a stop supervisor nested under the restart supervisor.
|
||||
ctx.asScala.log.info(s"Supervisor saw failure ${e.getMessage} - stopping")
|
||||
Behaviors.stopped
|
||||
}
|
||||
case e: UnauthenticatedException => {
|
||||
// This should be a stop supervisor nested under the restart supervisor.
|
||||
// The TriggerRunner should receive a ChildFailed signal when watching TriggerRunnerImpl.
|
||||
// This cannot be emulated outside the akka-actor-typed implementation, so we use a dedicated message instead.
|
||||
ctx.asScala.log.info(s"Supervisor saw failure ${e.getMessage} - stopping")
|
||||
parent ! Unauthenticated(e)
|
||||
Behaviors.stopped
|
||||
}
|
||||
}
|
||||
override def aroundStart(
|
||||
ctx: TypedActorContext[TriggerRunnerImpl.Message],
|
||||
target: BehaviorInterceptor.PreStartTarget[TriggerRunnerImpl.Message])
|
||||
: Behavior[TriggerRunnerImpl.Message] = {
|
||||
try {
|
||||
target.start(ctx)
|
||||
} catch handleException(ctx)
|
||||
}
|
||||
override def aroundReceive(
|
||||
ctx: TypedActorContext[TriggerRunnerImpl.Message],
|
||||
msg: TriggerRunnerImpl.Message,
|
||||
target: BehaviorInterceptor.ReceiveTarget[TriggerRunnerImpl.Message])
|
||||
: Behavior[TriggerRunnerImpl.Message] = {
|
||||
try {
|
||||
target(ctx, msg)
|
||||
} catch handleException(ctx)
|
||||
}
|
||||
override def aroundSignal(
|
||||
ctx: TypedActorContext[TriggerRunnerImpl.Message],
|
||||
signal: Signal,
|
||||
target: BehaviorInterceptor.SignalTarget[TriggerRunnerImpl.Message])
|
||||
: Behavior[TriggerRunnerImpl.Message] = {
|
||||
try {
|
||||
target(ctx, signal)
|
||||
} catch handleException(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
def apply(config: Config, name: String)(
|
||||
implicit esf: ExecutionSequencerFactory,
|
||||
mat: Materializer): Behavior[TriggerRunner.Message] =
|
||||
@ -57,8 +116,7 @@ object TriggerRunner {
|
||||
Behaviors
|
||||
.supervise(
|
||||
Behaviors
|
||||
.supervise(TriggerRunnerImpl(config))
|
||||
.onFailure[InitializationHalted](stop)
|
||||
.intercept(() => new Interceptor(ctx.self))(TriggerRunnerImpl(config))
|
||||
)
|
||||
.onFailure(
|
||||
restartWithBackoff(
|
||||
@ -75,6 +133,14 @@ object TriggerRunner {
|
||||
Behaviors.same
|
||||
case Stop =>
|
||||
Behaviors.stopped // Automatically stops the child actor if running.
|
||||
case Unauthenticated(cause) =>
|
||||
logger.warn(
|
||||
s"Trigger was unauthenticated - requesting token refresh: ${cause.getMessage}")
|
||||
config.server ! Server.TriggerTokenExpired(
|
||||
config.triggerInstance,
|
||||
config.trigger,
|
||||
config.compiledPackages)
|
||||
Behaviors.stopped
|
||||
}
|
||||
.receiveSignal {
|
||||
case (_, PostStop) =>
|
||||
|
@ -25,6 +25,7 @@ import scalaz.syntax.tag._
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
import scala.util.{Failure, Success}
|
||||
import TriggerRunner.{QueryingACS, Running, TriggerStatus}
|
||||
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
|
||||
|
||||
object TriggerRunnerImpl {
|
||||
|
||||
@ -33,7 +34,8 @@ object TriggerRunnerImpl {
|
||||
triggerInstance: UUID,
|
||||
party: Party,
|
||||
applicationId: ApplicationId,
|
||||
token: Option[String],
|
||||
accessToken: Option[AccessToken],
|
||||
refreshToken: Option[RefreshToken],
|
||||
compiledPackages: CompiledPackages,
|
||||
trigger: Trigger,
|
||||
ledgerConfig: LedgerConfig,
|
||||
@ -69,7 +71,7 @@ object TriggerRunnerImpl {
|
||||
commandClient = CommandClientConfiguration.default.copy(
|
||||
defaultDeduplicationTime = config.ledgerConfig.commandTtl),
|
||||
sslContext = None,
|
||||
token = config.token,
|
||||
token = AccessToken.unsubst(config.accessToken),
|
||||
maxInboundMessageSize = config.ledgerConfig.maxInboundMessageSize,
|
||||
)
|
||||
|
||||
@ -80,6 +82,9 @@ object TriggerRunnerImpl {
|
||||
case Status(replyTo) =>
|
||||
replyTo ! QueryingACS
|
||||
Behaviors.same
|
||||
case QueryACSFailed(cause: io.grpc.StatusRuntimeException)
|
||||
if cause.getStatus == io.grpc.Status.UNAUTHENTICATED =>
|
||||
throw new UnauthenticatedException(s"Querying ACS failed: ${cause.toString}")
|
||||
case QueryACSFailed(cause) =>
|
||||
// Report the failure to the server.
|
||||
config.server ! Server.TriggerInitializationFailure(triggerInstance, cause.toString)
|
||||
@ -138,6 +143,9 @@ object TriggerRunnerImpl {
|
||||
case Status(replyTo) =>
|
||||
replyTo ! Running
|
||||
Behaviors.same
|
||||
case Failed(cause: io.grpc.StatusRuntimeException)
|
||||
if cause.getStatus == io.grpc.Status.UNAUTHENTICATED =>
|
||||
throw new UnauthenticatedException(s"Querying ACS failed: ${cause.toString}")
|
||||
case Failed(cause) =>
|
||||
// Report the failure to the server.
|
||||
config.server ! Server.TriggerRuntimeFailure(triggerInstance, cause.toString)
|
||||
|
@ -23,7 +23,7 @@ import doobie.{Fragment, Put, Transactor}
|
||||
import scalaz.Tag
|
||||
import java.io.{Closeable, IOException}
|
||||
|
||||
import com.daml.lf.engine.trigger.Tagged.AccessToken
|
||||
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
|
||||
import javax.sql.DataSource
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
@ -83,6 +83,10 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
|
||||
|
||||
implicit val accessTokenGet: Get[AccessToken] = Tag.subst(implicitly[Get[String]])
|
||||
|
||||
implicit val refreshTokenPut: Put[RefreshToken] = Tag.subst(implicitly[Put[String]])
|
||||
|
||||
implicit val refreshTokenGet: Get[RefreshToken] = Tag.subst(implicitly[Get[String]])
|
||||
|
||||
implicit val identifierPut: Put[Identifier] = implicitly[Put[String]].contramap(_.toString)
|
||||
|
||||
implicit val identifierGet: Get[Identifier] =
|
||||
@ -105,24 +109,37 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
|
||||
private def insertRunningTrigger(t: RunningTrigger): ConnectionIO[Unit] = {
|
||||
val insert: Fragment = sql"""
|
||||
insert into running_triggers
|
||||
(trigger_instance, trigger_party, full_trigger_name, access_token, application_id)
|
||||
(trigger_instance, trigger_party, full_trigger_name, access_token, refresh_token, application_id)
|
||||
values
|
||||
(${t.triggerInstance}, ${t.triggerParty}, ${t.triggerName}, ${t.triggerToken}, ${t.triggerApplicationId})
|
||||
(${t.triggerInstance}, ${t.triggerParty}, ${t.triggerName}, ${t.triggerAccessToken}, ${t.triggerRefreshToken}, ${t.triggerApplicationId})
|
||||
"""
|
||||
insert.update.run.void
|
||||
}
|
||||
|
||||
private def queryRunningTrigger(triggerInstance: UUID): ConnectionIO[Option[RunningTrigger]] = {
|
||||
val select: Fragment = sql"""
|
||||
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token from running_triggers
|
||||
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token, refresh_token from running_triggers
|
||||
where trigger_instance = $triggerInstance
|
||||
"""
|
||||
select
|
||||
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken])]
|
||||
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken], Option[RefreshToken])]
|
||||
.map(RunningTrigger.tupled)
|
||||
.option
|
||||
}
|
||||
|
||||
private def setRunningTriggerToken(
|
||||
triggerInstance: UUID,
|
||||
accessToken: AccessToken,
|
||||
refreshToken: Option[RefreshToken]) = {
|
||||
val update: Fragment =
|
||||
sql"""
|
||||
update running_triggers
|
||||
set access_token = $accessToken, refresh_token = $refreshToken
|
||||
where trigger_instance = $triggerInstance
|
||||
"""
|
||||
update.update.run.void
|
||||
}
|
||||
|
||||
// trigger_instance is the primary key on running_triggers so this deletes
|
||||
// at most one row. Return whether or not it deleted.
|
||||
private def deleteRunningTrigger(triggerInstance: UUID): ConnectionIO[Boolean] = {
|
||||
@ -170,10 +187,10 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
|
||||
|
||||
private def selectAllTriggers: ConnectionIO[Vector[RunningTrigger]] = {
|
||||
val select: Fragment = sql"""
|
||||
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token from running_triggers order by trigger_instance
|
||||
select trigger_instance, full_trigger_name, trigger_party, application_id, access_token, refresh_token from running_triggers order by trigger_instance
|
||||
"""
|
||||
select
|
||||
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken])]
|
||||
.query[(UUID, Identifier, Party, ApplicationId, Option[AccessToken], Option[RefreshToken])]
|
||||
.map(RunningTrigger.tupled)
|
||||
.to[Vector]
|
||||
}
|
||||
@ -206,6 +223,12 @@ final class DbTriggerDao private (dataSource: DataSource with Closeable, xa: Con
|
||||
implicit ec: ExecutionContext): Future[Option[RunningTrigger]] =
|
||||
run(queryRunningTrigger(triggerInstance))
|
||||
|
||||
override def updateRunningTriggerToken(
|
||||
triggerInstance: UUID,
|
||||
accessToken: AccessToken,
|
||||
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit] =
|
||||
run(setRunningTriggerToken(triggerInstance, accessToken, refreshToken))
|
||||
|
||||
override def removeRunningTrigger(triggerInstance: UUID)(
|
||||
implicit ec: ExecutionContext): Future[Boolean] =
|
||||
run(deleteRunningTrigger(triggerInstance))
|
||||
|
@ -10,6 +10,7 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
|
||||
import com.daml.lf.archive.Dar
|
||||
import com.daml.lf.data.Ref.PackageId
|
||||
import com.daml.lf.engine.trigger.RunningTrigger
|
||||
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
|
||||
@ -29,6 +30,18 @@ final class InMemoryTriggerDao extends RunningTriggerDao {
|
||||
triggers.get(triggerInstance)
|
||||
}
|
||||
|
||||
override def updateRunningTriggerToken(
|
||||
triggerInstance: UUID,
|
||||
accessToken: AccessToken,
|
||||
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit] = Future {
|
||||
triggers.get(triggerInstance) match {
|
||||
case Some(t) =>
|
||||
triggers += (triggerInstance -> t
|
||||
.copy(triggerAccessToken = Some(accessToken), triggerRefreshToken = refreshToken))
|
||||
case None => ()
|
||||
}
|
||||
}
|
||||
|
||||
override def removeRunningTrigger(triggerInstance: UUID)(
|
||||
implicit ec: ExecutionContext): Future[Boolean] = Future {
|
||||
triggers.get(triggerInstance) match {
|
||||
|
@ -11,6 +11,7 @@ import com.daml.ledger.api.refinements.ApiTypes.Party
|
||||
import com.daml.lf.archive.Dar
|
||||
import com.daml.lf.data.Ref.PackageId
|
||||
import com.daml.lf.engine.trigger.RunningTrigger
|
||||
import com.daml.lf.engine.trigger.Tagged.{AccessToken, RefreshToken}
|
||||
|
||||
import scala.concurrent.{ExecutionContext, Future}
|
||||
|
||||
@ -18,6 +19,10 @@ trait RunningTriggerDao extends Closeable {
|
||||
def addRunningTrigger(t: RunningTrigger)(implicit ec: ExecutionContext): Future[Unit]
|
||||
def getRunningTrigger(triggerInstance: UUID)(
|
||||
implicit ec: ExecutionContext): Future[Option[RunningTrigger]]
|
||||
def updateRunningTriggerToken(
|
||||
triggerInstance: UUID,
|
||||
accessToken: AccessToken,
|
||||
refreshToken: Option[RefreshToken])(implicit ec: ExecutionContext): Future[Unit]
|
||||
def removeRunningTrigger(triggerInstance: UUID)(implicit ec: ExecutionContext): Future[Boolean]
|
||||
def listRunningTriggers(party: Party)(implicit ec: ExecutionContext): Future[Vector[UUID]]
|
||||
def persistPackages(dar: Dar[(PackageId, DamlLf.ArchivePayload)])(
|
||||
|
@ -24,6 +24,10 @@ package trigger {
|
||||
sealed trait AccessTokenTag
|
||||
type AccessToken = String @@ AccessTokenTag
|
||||
val AccessToken = Tag.of[AccessTokenTag]
|
||||
|
||||
sealed trait RefreshTokenTag
|
||||
type RefreshToken = String @@ RefreshTokenTag
|
||||
val RefreshToken = Tag.of[RefreshTokenTag]
|
||||
}
|
||||
import Tagged._
|
||||
import com.daml.ledger.api.refinements.ApiTypes.ApplicationId
|
||||
@ -47,6 +51,7 @@ package trigger {
|
||||
triggerName: Identifier,
|
||||
triggerParty: Party,
|
||||
triggerApplicationId: ApplicationId,
|
||||
triggerToken: Option[AccessToken],
|
||||
triggerAccessToken: Option[AccessToken],
|
||||
triggerRefreshToken: Option[RefreshToken],
|
||||
)
|
||||
}
|
||||
|
@ -11,6 +11,9 @@
|
||||
<logger name="io.grpc.netty" level="WARN">
|
||||
<appender-ref ref="stderr-appender"/>
|
||||
</logger>
|
||||
<logger name="com.daml.lf.engine.trigger" level="DEBUG">
|
||||
<appender-ref ref="stderr-appender"/>
|
||||
</logger>
|
||||
<root level="INFO">
|
||||
<appender-ref ref="STDOUT" />
|
||||
</root>
|
||||
|
@ -5,8 +5,8 @@ package com.daml.lf.engine.trigger
|
||||
|
||||
import java.io.File
|
||||
import java.net.InetAddress
|
||||
import java.time.LocalDateTime
|
||||
import java.util.UUID
|
||||
import java.time.{Clock, Duration => JDuration, Instant, LocalDateTime, ZoneId}
|
||||
import java.util.{Date, UUID}
|
||||
import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
|
||||
|
||||
import io.grpc.Channel
|
||||
@ -21,10 +21,15 @@ import akka.http.scaladsl.Http
|
||||
import akka.http.scaladsl.Http.ServerBinding
|
||||
import akka.http.scaladsl.model.{HttpRequest, HttpResponse, StatusCodes, Uri, headers}
|
||||
import akka.http.scaladsl.model.Uri.Path
|
||||
import com.auth0.jwt.JWT
|
||||
import com.auth0.jwt.JWTVerifier.BaseVerification
|
||||
import com.auth0.jwt.algorithms.Algorithm
|
||||
import com.auth0.jwt.interfaces.{Clock => Auth0Clock}
|
||||
import com.daml.bazeltools.BazelRunfiles
|
||||
import com.daml.clock.AdjustableClock
|
||||
import com.daml.daml_lf_dev.DamlLf
|
||||
import com.daml.jwt.domain.DecodedJwt
|
||||
import com.daml.jwt.{HMAC256Verifier, JwtSigner, JwtVerifierBase}
|
||||
import com.daml.jwt.{JwtSigner, JwtVerifier, JwtVerifierBase}
|
||||
import com.daml.ledger.api.auth
|
||||
import com.daml.ledger.api.auth.{AuthServiceJWTCodec, AuthServiceJWTPayload}
|
||||
import com.daml.ledger.api.domain.LedgerId
|
||||
@ -158,17 +163,26 @@ trait AuthMiddlewareFixture
|
||||
jwt.value
|
||||
}
|
||||
protected def authConfig: AuthConfig = AuthMiddleware(authMiddlewareUri)
|
||||
protected def authServer: OAuthServer = resource.value._1
|
||||
protected def authClock: AdjustableClock = resource.value._1
|
||||
protected def authServer: OAuthServer = resource.value._2
|
||||
|
||||
private def authVerifier: JwtVerifierBase = HMAC256Verifier(authSecret).toOption.get
|
||||
private def authMiddleware: ServerBinding = resource.value._2
|
||||
private def authVerifier: JwtVerifierBase = new JwtVerifier(
|
||||
JWT
|
||||
.require(Algorithm.HMAC256(authSecret))
|
||||
.asInstanceOf[BaseVerification]
|
||||
.build(new Auth0Clock {
|
||||
override def getToday: Date = Date.from(authClock.instant())
|
||||
})
|
||||
)
|
||||
private def authMiddleware: ServerBinding = resource.value._3
|
||||
private def authMiddlewareUri: Uri =
|
||||
Uri()
|
||||
.withScheme("http")
|
||||
.withAuthority(authMiddleware.localAddress.getHostString, authMiddleware.localAddress.getPort)
|
||||
|
||||
private val authSecret: String = "secret"
|
||||
private var resource: OwnedResource[ResourceContext, (OAuthServer, ServerBinding)] = null
|
||||
private var resource
|
||||
: OwnedResource[ResourceContext, (AdjustableClock, OAuthServer, ServerBinding)] = null
|
||||
|
||||
override protected def beforeAll(): Unit = {
|
||||
super.beforeAll()
|
||||
@ -178,18 +192,23 @@ trait AuthMiddlewareFixture
|
||||
for {
|
||||
_ <- binding.unbind()
|
||||
} yield ()
|
||||
val oauthConfig = OAuthConfig(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = this.getClass.getSimpleName,
|
||||
jwtSecret = authSecret,
|
||||
parties = authParties,
|
||||
clock = None,
|
||||
)
|
||||
resource = new OwnedResource(new ResourceOwner[(OAuthServer, ServerBinding)] {
|
||||
override def acquire()(
|
||||
implicit context: ResourceContext): Resource[(OAuthServer, ServerBinding)] = {
|
||||
val oauthServer = OAuthServer(oauthConfig)
|
||||
val ledgerId = this.getClass.getSimpleName
|
||||
resource = new OwnedResource(new ResourceOwner[(AdjustableClock, OAuthServer, ServerBinding)] {
|
||||
override def acquire()(implicit context: ResourceContext)
|
||||
: Resource[(AdjustableClock, OAuthServer, ServerBinding)] = {
|
||||
for {
|
||||
clock <- Resource(
|
||||
Future(
|
||||
AdjustableClock(Clock.fixed(Instant.now(), ZoneId.systemDefault()), JDuration.ZERO)))(
|
||||
_ => Future(()))
|
||||
oauthConfig = OAuthConfig(
|
||||
port = Port.Dynamic,
|
||||
ledgerId = ledgerId,
|
||||
jwtSecret = authSecret,
|
||||
parties = authParties,
|
||||
clock = Some(clock),
|
||||
)
|
||||
oauthServer = OAuthServer(oauthConfig)
|
||||
oauth <- Resource(oauthServer.start())(closeServerBinding)
|
||||
uri = Uri()
|
||||
.withScheme("http")
|
||||
@ -203,7 +222,7 @@ trait AuthMiddlewareFixture
|
||||
tokenVerifier = authVerifier,
|
||||
)
|
||||
middleware <- Resource(MiddlewareServer.start(middlewareConfig))(closeServerBinding)
|
||||
} yield (oauthServer, middleware)
|
||||
} yield (clock, oauthServer, middleware)
|
||||
}
|
||||
})
|
||||
resource.setup()
|
||||
|
@ -10,6 +10,7 @@ import akka.http.scaladsl.model._
|
||||
import akka.util.ByteString
|
||||
import akka.stream.scaladsl.{FileIO, Sink, Source}
|
||||
import java.io.File
|
||||
import java.time.{Duration => JDuration}
|
||||
import java.util.UUID
|
||||
|
||||
import akka.http.scaladsl.model.Uri.Query
|
||||
@ -24,6 +25,7 @@ import com.daml.bazeltools.BazelRunfiles.requiredResource
|
||||
import com.daml.ledger.api.refinements.ApiTypes
|
||||
import com.daml.ledger.api.v1.commands._
|
||||
import com.daml.ledger.api.v1.command_service._
|
||||
import com.daml.ledger.api.v1.event.CreatedEvent
|
||||
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
|
||||
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset.LedgerBoundary.LEDGER_BEGIN
|
||||
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset.Value.Boundary
|
||||
@ -58,7 +60,7 @@ trait AbstractTriggerServiceTest
|
||||
protected val testPkgId = dar.main._1
|
||||
override protected val damlPackages: List[File] = List(darPath)
|
||||
|
||||
private def submitCmd(client: LedgerClient, party: String, cmd: Command) = {
|
||||
protected def submitCmd(client: LedgerClient, party: String, cmd: Command) = {
|
||||
val req = SubmitAndWaitRequest(
|
||||
Some(
|
||||
Commands(
|
||||
@ -75,12 +77,12 @@ trait AbstractTriggerServiceTest
|
||||
protected override def actorSystemName = testId
|
||||
|
||||
protected val alice: Party = Tag("Alice")
|
||||
// This party is used by the test that queries the ACS.
|
||||
// To avoid mixing this up with the other tests, we use a separate
|
||||
// party.
|
||||
protected val aliceAcs: Party = Tag("Alice_acs")
|
||||
protected val bob: Party = Tag("Bob")
|
||||
protected val eve: Party = Tag("Eve")
|
||||
// These parties are used by tests that query the ACS.
|
||||
// To avoid mixing this up with the other tests, we use a separate party.
|
||||
protected val aliceAcs: Party = Tag("Alice_acs")
|
||||
protected val aliceExp: Party = Tag("Alice_exp")
|
||||
|
||||
def startTrigger(
|
||||
uri: Uri,
|
||||
@ -173,6 +175,18 @@ trait AbstractTriggerServiceTest
|
||||
} yield triggerIds
|
||||
}
|
||||
|
||||
def getActiveContracts(
|
||||
client: LedgerClient,
|
||||
party: Party,
|
||||
template: Identifier): Future[Seq[CreatedEvent]] = {
|
||||
val filter = TransactionFilter(
|
||||
Map(party.unwrap -> Filters(Some(InclusiveFilters(Seq(template))))))
|
||||
client.activeContractSetClient
|
||||
.getActiveContracts(filter)
|
||||
.runWith(Sink.seq)
|
||||
.map(acsPages => acsPages.flatMap(_.activeContracts))
|
||||
}
|
||||
|
||||
def assertTriggerIds(uri: Uri, party: Party, expected: Vector[UUID]): Future[Assertion] =
|
||||
for {
|
||||
resp <- listTriggers(uri, party)
|
||||
@ -266,18 +280,10 @@ trait AbstractTriggerServiceTest
|
||||
client <- sandboxClient(
|
||||
ApiTypes.ApplicationId("my-app-id"),
|
||||
actAs = List(ApiTypes.Party(aliceAcs.unwrap)))
|
||||
filter = TransactionFilter(
|
||||
List(
|
||||
(
|
||||
aliceAcs.unwrap,
|
||||
Filters(Some(InclusiveFilters(Seq(Identifier(testPkgId, "TestTrigger", "B"))))))).toMap)
|
||||
// Make sure that no contracts exist initially to guard against accidental
|
||||
// party reuse.
|
||||
acs <- client.activeContractSetClient
|
||||
.getActiveContracts(filter)
|
||||
.runWith(Sink.seq)
|
||||
.map(acsPages => acsPages.flatMap(_.activeContracts))
|
||||
_ = acs shouldBe Vector()
|
||||
_ <- getActiveContracts(client, aliceAcs, Identifier(testPkgId, "TestTrigger", "B"))
|
||||
.map(_ shouldBe Vector())
|
||||
// Start the trigger
|
||||
resp <- startTrigger(
|
||||
uri,
|
||||
@ -302,12 +308,8 @@ trait AbstractTriggerServiceTest
|
||||
}
|
||||
// Query ACS until we see a B contract
|
||||
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
|
||||
for {
|
||||
acs <- client.activeContractSetClient
|
||||
.getActiveContracts(filter)
|
||||
.runWith(Sink.seq)
|
||||
.map(acsPages => acsPages.flatMap(_.activeContracts))
|
||||
} yield assert(acs.length == 1)
|
||||
getActiveContracts(client, aliceAcs, Identifier(testPkgId, "TestTrigger", "B"))
|
||||
.map(_.length shouldBe 1)
|
||||
}
|
||||
// Read completions to make sure we set the right app id.
|
||||
r <- client.commandClient
|
||||
@ -512,7 +514,7 @@ trait AbstractTriggerServiceTestAuthMiddleware
|
||||
extends AbstractTriggerServiceTest
|
||||
with AuthMiddlewareFixture {
|
||||
|
||||
override protected val authParties = Some(Set(alice, aliceAcs, bob))
|
||||
override protected val authParties = Some(Set(alice, aliceAcs, aliceExp, bob))
|
||||
|
||||
behavior of "authenticated service"
|
||||
|
||||
@ -559,4 +561,93 @@ trait AbstractTriggerServiceTestAuthMiddleware
|
||||
_ <- resp.status shouldBe StatusCodes.Forbidden
|
||||
} yield succeed
|
||||
}
|
||||
|
||||
it should "request a fresh token after expiry on user request" in withTriggerService(Nil) {
|
||||
uri: Uri =>
|
||||
for {
|
||||
resp <- listTriggers(uri, alice)
|
||||
_ <- resp.status shouldBe StatusCodes.OK
|
||||
// Expire old token and test the trigger service transparently requests a new token.
|
||||
_ = authClock.fastForward(
|
||||
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
|
||||
resp <- listTriggers(uri, alice)
|
||||
_ <- resp.status shouldBe StatusCodes.OK
|
||||
} yield succeed
|
||||
}
|
||||
|
||||
it should "refresh a token after expiry on the server side" in withTriggerService(List(dar)) {
|
||||
uri: Uri =>
|
||||
for {
|
||||
client <- sandboxClient(
|
||||
ApiTypes.ApplicationId("exp-app-id"),
|
||||
actAs = List(ApiTypes.Party(aliceExp.unwrap)))
|
||||
// Make sure that no contracts exist initially to guard against accidental
|
||||
// party reuse.
|
||||
_ <- getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
|
||||
.map(_ shouldBe Vector())
|
||||
// Start the trigger
|
||||
resp <- startTrigger(
|
||||
uri,
|
||||
s"$testPkgId:TestTrigger:trigger",
|
||||
aliceExp,
|
||||
Some(ApplicationId("exp-app-id")))
|
||||
triggerId <- parseTriggerId(resp)
|
||||
|
||||
// Expire old token and test that the trigger service requests a new token during trigger start-up.
|
||||
// TODO[AH] Here we want to test token expiry during QueryingACS.
|
||||
// For now the test relies on timing. Find a way to enforce expiry during QueryingACS.
|
||||
_ = authClock.fastForward(
|
||||
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
|
||||
|
||||
// Trigger is running, create an A contract
|
||||
createACommand = { v: Long =>
|
||||
Command().withCreate(
|
||||
CreateCommand(
|
||||
templateId = Some(Identifier(testPkgId, "TestTrigger", "A")),
|
||||
createArguments = Some(
|
||||
Record(
|
||||
None,
|
||||
Seq(
|
||||
RecordField(value = Some(Value().withParty(aliceExp.unwrap))),
|
||||
RecordField(value = Some(Value().withInt64(v))))))
|
||||
))
|
||||
}
|
||||
_ <- submitCmd(client, aliceExp.unwrap, createACommand(7))
|
||||
// Query ACS until we see a B contract
|
||||
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
|
||||
getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
|
||||
.map(_.length shouldBe 1)
|
||||
}
|
||||
|
||||
// Expire old token and test that the trigger service requests a new token during running trigger.
|
||||
_ = authClock.fastForward(
|
||||
JDuration.ofSeconds(authServer.tokenLifetimeSeconds.asInstanceOf[Long] + 1))
|
||||
|
||||
// Create another A contract
|
||||
_ <- submitCmd(client, aliceExp.unwrap, createACommand(42))
|
||||
// Query ACS until we see a second B contract
|
||||
_ <- RetryStrategy.constant(5, 1.seconds) { (_, _) =>
|
||||
getActiveContracts(client, aliceExp, Identifier(testPkgId, "TestTrigger", "B"))
|
||||
.map(_.length shouldBe 2)
|
||||
}
|
||||
|
||||
// Read completions to make sure we set the right app id.
|
||||
r <- client.commandClient
|
||||
.completionSource(List(aliceExp.unwrap), LedgerOffset(Boundary(LEDGER_BEGIN)))
|
||||
.collect({
|
||||
case CompletionStreamElement.CompletionElement(completion)
|
||||
if !completion.transactionId.isEmpty =>
|
||||
completion
|
||||
})
|
||||
.take(1)
|
||||
.runWith(Sink.seq)
|
||||
_ = r.length shouldBe 1
|
||||
status <- triggerStatus(uri, triggerId)
|
||||
_ = status.status shouldBe StatusCodes.OK
|
||||
body <- responseBodyToString(status)
|
||||
_ = body shouldBe s"""{"result":{"party":"Alice_exp","status":"running","triggerId":"$testPkgId:TestTrigger:trigger"},"status":200}"""
|
||||
resp <- stopTrigger(uri, triggerId, aliceExp)
|
||||
_ <- assert(resp.status.isSuccess)
|
||||
} yield succeed
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user