diff --git a/ledger-service/http-json/src/it/scala/http/AbstractHttpServiceIntegrationTest.scala b/ledger-service/http-json/src/it/scala/http/AbstractHttpServiceIntegrationTest.scala index 1f59f7447c..e4c3f2e373 100644 --- a/ledger-service/http-json/src/it/scala/http/AbstractHttpServiceIntegrationTest.scala +++ b/ledger-service/http-json/src/it/scala/http/AbstractHttpServiceIntegrationTest.scala @@ -82,6 +82,8 @@ trait AbstractHttpServiceIntegrationTestFuns extends StrictLogging { def useTls: UseTls + def wsConfig: Option[WebsocketConfig] + protected def testId: String = this.getClass.getSimpleName protected val metdata2: MetadataReader.LfMetadata = @@ -123,7 +125,8 @@ trait AbstractHttpServiceIntegrationTestFuns extends StrictLogging { List(dar1, dar2), jdbcConfig, staticContentConfig, - useTls = useTls) + useTls = useTls, + wsConfig = wsConfig) protected def withHttpService[A]( f: (Uri, DomainJsonEncoder, DomainJsonDecoder) => Future[A]): Future[A] = @@ -436,6 +439,21 @@ trait AbstractHttpServiceIntegrationTestFuns extends StrictLogging { case \/-(x) => x } } + + protected def initialIouCreate(serviceUri: Uri): Future[(StatusCode, JsValue)] = { + val payload = TestUtil.readFile("it/iouCreateCommand.json") + TestUtil.postJsonStringRequest( + serviceUri.withPath(Uri.Path("/v1/create")), + payload, + headersWithAuth) + } + + protected def initialAccountCreate( + serviceUri: Uri, + encoder: DomainJsonEncoder): Future[(StatusCode, JsValue)] = { + val command = accountCreateCommand(domain.Party("Alice"), "abc123") + postCreateCommand(command, encoder, serviceUri) + } } @SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements")) diff --git a/ledger-service/http-json/src/it/scala/http/HttpServiceIntegrationTest.scala b/ledger-service/http-json/src/it/scala/http/HttpServiceIntegrationTest.scala index 2a42dd6e7f..a8f493a64e 100644 --- a/ledger-service/http-json/src/it/scala/http/HttpServiceIntegrationTest.scala +++ b/ledger-service/http-json/src/it/scala/http/HttpServiceIntegrationTest.scala @@ -27,6 +27,8 @@ class HttpServiceIntegrationTest extends AbstractHttpServiceIntegrationTest with override def jdbcConfig: Option[JdbcConfig] = None + override def wsConfig: Option[WebsocketConfig] = None + private val expectedDummyContent: String = Gen .listOfN(100, Gen.identifier) .map(_.mkString(" ")) diff --git a/ledger-service/http-json/src/it/scala/http/HttpServiceTestFixture.scala b/ledger-service/http-json/src/it/scala/http/HttpServiceTestFixture.scala index bc4f4873cb..a0b3cb177e 100644 --- a/ledger-service/http-json/src/it/scala/http/HttpServiceTestFixture.scala +++ b/ledger-service/http-json/src/it/scala/http/HttpServiceTestFixture.scala @@ -48,7 +48,8 @@ object HttpServiceTestFixture { jdbcConfig: Option[JdbcConfig], staticContentConfig: Option[StaticContentConfig], leakPasswords: LeakPasswords = LeakPasswords.FiresheepStyle, - useTls: UseTls = UseTls.NoTls + useTls: UseTls = UseTls.NoTls, + wsConfig: Option[WebsocketConfig] = None, )(testFn: (Uri, DomainJsonEncoder, DomainJsonDecoder, LedgerClient) => Future[A])( implicit asys: ActorSystem, mat: Materializer, @@ -77,7 +78,7 @@ object HttpServiceTestFixture { httpPort = 0, portFile = None, tlsConfig = if (useTls) clientTlsConfig else noTlsConfig, - wsConfig = Some(Config.DefaultWsConfig), + wsConfig = wsConfig, accessTokenFile = None, allowNonHttps = leakPasswords, staticContentConfig = staticContentConfig, diff --git a/ledger-service/http-json/src/it/scala/http/HttpServiceWithPostgresIntTest.scala b/ledger-service/http-json/src/it/scala/http/HttpServiceWithPostgresIntTest.scala index 16aba05e3a..0a5daeecb9 100644 --- a/ledger-service/http-json/src/it/scala/http/HttpServiceWithPostgresIntTest.scala +++ b/ledger-service/http-json/src/it/scala/http/HttpServiceWithPostgresIntTest.scala @@ -18,6 +18,8 @@ class HttpServiceWithPostgresIntTest override def staticContentConfig: Option[StaticContentConfig] = None + override def wsConfig: Option[WebsocketConfig] = None + // has to be lazy because postgresFixture is NOT initialized yet private lazy val jdbcConfig_ = JdbcConfig( driver = "org.postgresql.Driver", diff --git a/ledger-service/http-json/src/it/scala/http/TlsTest.scala b/ledger-service/http-json/src/it/scala/http/TlsTest.scala index 96caf1f468..6979e5f8e0 100644 --- a/ledger-service/http-json/src/it/scala/http/TlsTest.scala +++ b/ledger-service/http-json/src/it/scala/http/TlsTest.scala @@ -23,6 +23,8 @@ class TlsTest override def useTls = UseTls.Tls + override def wsConfig: Option[WebsocketConfig] = None + "connect normally with tls on" in withHttpService { (uri: Uri, _, _) => getRequest(uri = uri.withPath(Uri.Path("/v1/query"))) .flatMap { diff --git a/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala b/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala index e52cc94ed3..3931f585c2 100644 --- a/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala +++ b/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala @@ -6,11 +6,13 @@ package com.daml.http import akka.NotUsed import akka.http.scaladsl.Http import akka.http.scaladsl.model.ws.{BinaryMessage, Message, TextMessage, WebSocketRequest} -import akka.http.scaladsl.model.{StatusCode, StatusCodes, Uri} +import akka.http.scaladsl.model.{StatusCodes, Uri} import akka.stream.scaladsl.{Flow, Keep, Sink, Source} -import com.daml.http.json.{DomainJsonEncoder, SprayJson} +import com.daml.http.json.SprayJson import com.daml.http.util.TestUtil import HttpServiceTestFixture.UseTls +import akka.actor.ActorSystem +import com.daml.jwt.domain.Jwt import com.typesafe.scalalogging.StrictLogging import org.scalacheck.Gen import org.scalatest._ @@ -36,7 +38,6 @@ class WebsocketServiceIntegrationTest with BeforeAndAfterAll { import WebsocketServiceIntegrationTest._ - import WebsocketEndpoints._ override def jdbcConfig: Option[JdbcConfig] = None @@ -44,6 +45,8 @@ class WebsocketServiceIntegrationTest override def useTls = UseTls.NoTls + override def wsConfig: Option[WebsocketConfig] = Some(Config.DefaultWsConfig) + private val baseQueryInput: Source[Message, NotUsed] = Source.single(TextMessage.Strict("""{"templateIds": ["Account:Account"]}""")) @@ -53,8 +56,6 @@ class WebsocketServiceIntegrationTest private val baseFetchInput: Source[Message, NotUsed] = Source.single(TextMessage.Strict(fetchRequest)) - private val validSubprotocol = Option(s"""$tokenPrefix${jwt.value},$wsProtocol""") - List( SimpleScenario("query", Uri.Path("/v1/stream/query"), baseQueryInput), SimpleScenario("fetch", Uri.Path("/v1/stream/fetch"), baseFetchInput) @@ -63,7 +64,7 @@ class WebsocketServiceIntegrationTest (uri, _, _) => wsConnectRequest( uri.copy(scheme = "ws").withPath(scenario.path), - validSubprotocol, + validSubprotocol(jwt), scenario.input)._1 flatMap (x => x.response.status shouldBe StatusCodes.SwitchingProtocols) } @@ -92,7 +93,7 @@ class WebsocketServiceIntegrationTest Http().webSocketClientFlow( WebSocketRequest( uri = uri.copy(scheme = "ws").withPath(scenario.path), - subprotocol = validSubprotocol)) + subprotocol = validSubprotocol(jwt))) input .via(webSocketFlow) .runWith(collectResultsAsTextMessageSkipOffsetTicks) @@ -126,7 +127,7 @@ class WebsocketServiceIntegrationTest Http().webSocketClientFlow( WebSocketRequest( uri = uri.copy(scheme = "ws").withPath(scenario.path), - subprotocol = validSubprotocol)) + subprotocol = validSubprotocol(jwt))) scenario.input .via(webSocketFlow) .runWith(collectResultsAsTextMessageSkipOffsetTicks) @@ -149,71 +150,12 @@ class WebsocketServiceIntegrationTest } } - private val collectResultsAsTextMessageSkipOffsetTicks: Sink[Message, Future[Seq[String]]] = - Flow[Message] - .collect { case m: TextMessage => m.getStrictText } - .filterNot(isOffsetTick) - .toMat(Sink.seq)(Keep.right) - - private val collectResultsAsTextMessage: Sink[Message, Future[Seq[String]]] = - Flow[Message] - .collect { case m: TextMessage => m.getStrictText } - .toMat(Sink.seq)(Keep.right) - - private def singleClientWSStream( - path: String, - serviceUri: Uri, - query: String, - offset: Option[domain.Offset]): Source[Message, NotUsed] = { - import spray.json._, json.JsonProtocol._ - val uri = serviceUri.copy(scheme = "ws").withPath(Uri.Path(s"/v1/stream/$path")) - logger.info( - s"---- singleClientWSStream uri: ${uri.toString}, query: $query, offset: ${offset.toString}") - val webSocketFlow = - Http().webSocketClientFlow(WebSocketRequest(uri = uri, subprotocol = validSubprotocol)) - offset - .cata( - off => - Source.fromIterator(() => - Seq(Map("offset" -> off.unwrap).toJson.compactPrint, query).iterator), - Source single query) - .map(TextMessage(_)) - .via(webSocketFlow) - } - - private def singleClientQueryStream( - serviceUri: Uri, - query: String, - offset: Option[domain.Offset] = None): Source[Message, NotUsed] = - singleClientWSStream("query", serviceUri, query, offset) - - private def singleClientFetchStream( - serviceUri: Uri, - request: String, - offset: Option[domain.Offset] = None): Source[Message, NotUsed] = - singleClientWSStream("fetch", serviceUri, request, offset) - - private def initialIouCreate(serviceUri: Uri) = { - val payload = TestUtil.readFile("it/iouCreateCommand.json") - TestUtil.postJsonStringRequest( - serviceUri.withPath(Uri.Path("/v1/create")), - payload, - headersWithAuth) - } - - private def initialAccountCreate( - serviceUri: Uri, - encoder: DomainJsonEncoder): Future[(StatusCode, JsValue)] = { - val command = accountCreateCommand(domain.Party("Alice"), "abc123") - postCreateCommand(command, encoder, serviceUri) - } - "query endpoint should publish transactions when command create is completed" in withHttpService { (uri, _, _) => for { _ <- initialIouCreate(uri) - clientMsg <- singleClientQueryStream(uri, """{"templateIds": ["Iou:Iou"]}""") + clientMsg <- singleClientQueryStream(jwt, uri, """{"templateIds": ["Iou:Iou"]}""") .runWith(collectResultsAsTextMessage) } yield inside(clientMsg) { @@ -229,7 +171,7 @@ class WebsocketServiceIntegrationTest for { _ <- initialAccountCreate(uri, encoder) - clientMsg <- singleClientFetchStream(uri, fetchRequest) + clientMsg <- singleClientFetchStream(jwt, uri, fetchRequest) .runWith(collectResultsAsTextMessage) } yield inside(clientMsg) { @@ -246,6 +188,7 @@ class WebsocketServiceIntegrationTest _ <- initialIouCreate(uri) clientMsg <- singleClientQueryStream( + jwt, uri, """{"templateIds": ["Iou:Iou", "Unknown:Template"]}""") .runWith(collectResultsAsTextMessage) @@ -263,6 +206,7 @@ class WebsocketServiceIntegrationTest _ <- initialAccountCreate(uri, encoder) clientMsg <- singleClientFetchStream( + jwt, uri, """[{"templateId": "Account:Account", "key": ["Alice", "abc123"]}, {"templateId": "Unknown:Template", "key": ["Alice", "abc123"]}]""") .runWith(collectResultsAsTextMessage) @@ -278,7 +222,7 @@ class WebsocketServiceIntegrationTest "query endpoint should send error msg when receiving malformed message" in withHttpService { (uri, _, _) => - val clientMsg = singleClientQueryStream(uri, "{}") + val clientMsg = singleClientQueryStream(jwt, uri, "{}") .runWith(collectResultsAsTextMessageSkipOffsetTicks) val result = Await.result(clientMsg, 10.seconds) @@ -291,7 +235,7 @@ class WebsocketServiceIntegrationTest "fetch endpoint should send error msg when receiving malformed message" in withHttpService { (uri, _, _) => - val clientMsg = singleClientFetchStream(uri, """[abcdefg!]""") + val clientMsg = singleClientFetchStream(jwt, uri, """[abcdefg!]""") .runWith(collectResultsAsTextMessageSkipOffsetTicks) val result = Await.result(clientMsg, 10.seconds) @@ -399,13 +343,13 @@ class WebsocketServiceIntegrationTest creation <- initialCreate _ = creation._1 shouldBe 'success iouCid = getContractId(getResult(creation._2)) - lastState <- singleClientQueryStream(uri, query) via parseResp runWith resp(iouCid) + lastState <- singleClientQueryStream(jwt, uri, query) via parseResp runWith resp(iouCid) liveOffset = inside(lastState) { case ShouldHaveEnded(liveStart, 2, lastSeen) => lastSeen.unwrap should be > liveStart.unwrap liveStart } - rescan <- (singleClientQueryStream(uri, query, Some(liveOffset)) + rescan <- (singleClientQueryStream(jwt, uri, query, Some(liveOffset)) via parseResp runWith remainingDeltas) } yield inside(rescan) { @@ -482,7 +426,7 @@ class WebsocketServiceIntegrationTest _ = r2._1 shouldBe 'success cid2 = getContractId(getResult(r2._2)) - lastState <- singleClientFetchStream(uri, fetchRequest()) + lastState <- singleClientFetchStream(jwt, uri, fetchRequest()) .via(parseResp) runWith resp(cid1, cid2) liveOffset = inside(lastState) { @@ -494,7 +438,7 @@ class WebsocketServiceIntegrationTest // check contractIdAtOffsets' effects on phantom filtering resumes <- Future.traverse(Seq((None, 2L), (Some(None), 0L), (Some(Some(cid1)), 1L))) { case (abcHint, expectArchives) => - (singleClientFetchStream(uri, fetchRequest(abcHint), Some(liveOffset)) + (singleClientFetchStream(jwt, uri, fetchRequest(abcHint), Some(liveOffset)) via parseResp runWith remainingDeltas) .map { case (creates, archives, _) => @@ -508,7 +452,7 @@ class WebsocketServiceIntegrationTest "fetch should should return an error if empty list of (templateId, key) pairs is passed" in withHttpService { (uri, _, _) => - singleClientFetchStream(uri, "[]") + singleClientFetchStream(jwt, uri, "[]") .runWith(collectResultsAsTextMessageSkipOffsetTicks) .map { clientMsgs => inside(clientMsgs) { @@ -530,7 +474,7 @@ class WebsocketServiceIntegrationTest """[ {"templateIds": ["Iou:Iou"]} ]""" - singleClientQueryStream(uri, query) + singleClientQueryStream(jwt, uri, query) .via(parseResp) .map(iouSplitResult) .filterNot(_ == \/-((Vector(), Vector()))) // liveness marker/heartbeat @@ -682,27 +626,13 @@ class WebsocketServiceIntegrationTest case \/-(eventsBlock) => eventsBlock.events shouldBe Vector.empty[JsValue] inside(eventsBlock.offset) { - case JsString(offset) => + case Some(JsString(offset)) => offset.length should be > 0 - case JsNull => + case Some(JsNull) => Succeeded } } - private def isOffsetTick(str: String): Boolean = - SprayJson - .decode[EventsBlock](str) - .map { b => - val isEmpty: Boolean = (b.events: Vector[JsValue]) == Vector.empty[JsValue] - val hasOffset: Boolean = b.offset match { - case JsString(offset) => offset.length > 0 - case JsNull => true - case _ => false - } - isEmpty && hasOffset - } - .valueOr(_ => false) - private def decodeErrorResponse(str: String): domain.ErrorResponse = { import json.JsonProtocol._ inside(SprayJson.decode[domain.ErrorResponse](str)) { @@ -718,8 +648,11 @@ class WebsocketServiceIntegrationTest } } -object WebsocketServiceIntegrationTest { +private[http] object WebsocketServiceIntegrationTest extends StrictLogging { import spray.json._ + import WebsocketEndpoints._ + + private def validSubprotocol(jwt: Jwt) = Option(s"""$tokenPrefix${jwt.value},$wsProtocol""") def dummyFlow[A](source: Source[A, NotUsed]): Flow[A, A, NotUsed] = Flow.fromSinkAndSource(Sink.foreach(println), source) @@ -793,10 +726,21 @@ object WebsocketServiceIntegrationTest { private object Archived extends JsoField("archived") private object MatchedQueries extends JsoField("matchedQueries") - private final case class EventsBlock(events: Vector[JsValue], offset: JsValue) - private object EventsBlock { + private[http] final case class EventsBlock(events: Vector[JsValue], offset: Option[JsValue]) + private[http] object EventsBlock { + import spray.json._ import DefaultJsonProtocol._ - implicit val EventsBlockFormat: RootJsonFormat[EventsBlock] = jsonFormat2(EventsBlock.apply) + + // cannot rely on default reader, offset: JsNull gets read as None, I want Some(JsNull) for LedgerBegin + implicit val EventsBlockReader: RootJsonReader[EventsBlock] = (json: JsValue) => { + val obj = json.asJsObject + val events = obj.fields("events").convertTo[Vector[JsValue]] + val offset: Option[JsValue] = obj.fields.get("offset").collect { + case x: JsString => x + case JsNull => JsNull + } + EventsBlock(events, offset) + } } type IouSplitResult = @@ -839,4 +783,85 @@ object WebsocketServiceIntegrationTest { ) else Gen const Leaf(x) } + + def singleClientQueryStream( + jwt: Jwt, + serviceUri: Uri, + query: String, + offset: Option[domain.Offset] = None)(implicit asys: ActorSystem): Source[Message, NotUsed] = + singleClientWSStream(jwt, "query", serviceUri, query, offset) + + def singleClientFetchStream( + jwt: Jwt, + serviceUri: Uri, + request: String, + offset: Option[domain.Offset] = None)(implicit asys: ActorSystem): Source[Message, NotUsed] = + singleClientWSStream(jwt, "fetch", serviceUri, request, offset) + + def singleClientWSStream( + jwt: Jwt, + path: String, + serviceUri: Uri, + query: String, + offset: Option[domain.Offset])(implicit asys: ActorSystem): Source[Message, NotUsed] = { + + import spray.json._, json.JsonProtocol._ + val uri = serviceUri.copy(scheme = "ws").withPath(Uri.Path(s"/v1/stream/$path")) + logger.info( + s"---- singleClientWSStream uri: ${uri.toString}, query: $query, offset: ${offset.toString}") + val webSocketFlow = + Http().webSocketClientFlow(WebSocketRequest(uri = uri, subprotocol = validSubprotocol(jwt))) + offset + .cata( + off => + Source.fromIterator(() => + Seq(Map("offset" -> off.unwrap).toJson.compactPrint, query).iterator), + Source single query) + .map(TextMessage(_)) + .via(webSocketFlow) + } + + val collectResultsAsTextMessageSkipOffsetTicks: Sink[Message, Future[Seq[String]]] = + Flow[Message] + .collect { case m: TextMessage => m.getStrictText } + .filterNot(isOffsetTick) + .toMat(Sink.seq)(Keep.right) + + val collectResultsAsTextMessage: Sink[Message, Future[Seq[String]]] = + Flow[Message] + .collect { case m: TextMessage => m.getStrictText } + .toMat(Sink.seq)(Keep.right) + + private def isOffsetTick(str: String): Boolean = + SprayJson + .decode[EventsBlock](str) + .map(isOffsetTick) + .valueOr(_ => false) + + def isOffsetTick(x: EventsBlock): Boolean = { + val hasOffset = x.offset + .collect { + case JsString(offset) => offset.length > 0 + case JsNull => true // JsNull is for LedgerBegin + } + .getOrElse(false) + + x.events.isEmpty && hasOffset + } + + def isAbsoluteOffsetTick(x: EventsBlock): Boolean = { + val hasAbsoluteOffset = x.offset + .collect { + case JsString(offset) => offset.length > 0 + } + .getOrElse(false) + + x.events.isEmpty && hasAbsoluteOffset + } + + def isAcs(x: EventsBlock): Boolean = + x.events.nonEmpty && x.offset.isEmpty + + def eventsBlockVector(msgs: Vector[String]): SprayJson.JsonReaderError \/ Vector[EventsBlock] = + msgs.traverse(SprayJson.decode[EventsBlock]) } diff --git a/ledger-service/http-json/src/it/scala/http/WebsocketServiceOffsetTickIntTest.scala b/ledger-service/http-json/src/it/scala/http/WebsocketServiceOffsetTickIntTest.scala new file mode 100644 index 0000000000..8f3caf58d5 --- /dev/null +++ b/ledger-service/http-json/src/it/scala/http/WebsocketServiceOffsetTickIntTest.scala @@ -0,0 +1,66 @@ +// Copyright (c) 2020 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +package com.daml.http + +import com.daml.http.HttpServiceTestFixture.UseTls +import com.typesafe.scalalogging.StrictLogging +import org.scalatest._ +import scalaz.\/- + +import scala.concurrent.duration._ + +@SuppressWarnings(Array("org.wartremover.warts.NonUnitStatements")) +class WebsocketServiceOffsetTickIntTest + extends AsyncFreeSpec + with Matchers + with Inside + with StrictLogging + with AbstractHttpServiceIntegrationTestFuns + with BeforeAndAfterAll { + + override def jdbcConfig: Option[JdbcConfig] = None + + override def staticContentConfig: Option[StaticContentConfig] = None + + override def useTls: UseTls = UseTls.NoTls + + // make sure websocket heartbeats non-stop, DO NOT CHANGE `0.second` + override def wsConfig: Option[WebsocketConfig] = + Some(Config.DefaultWsConfig.copy(heartBeatPer = 0.second)) + + import WebsocketServiceIntegrationTest._ + + "Given empty ACS, JSON API should emit only offset ticks" in withHttpService { (uri, _, _) => + for { + msgs <- singleClientQueryStream(jwt, uri, """{"templateIds": ["Iou:Iou"]}""") + .take(10) + .runWith(collectResultsAsTextMessage) + } yield { + inside(eventsBlockVector(msgs.toVector)) { + case \/-(offsetTicks) => + offsetTicks.forall(isOffsetTick) shouldBe true + offsetTicks should have length 10 + } + } + } + + "Given non-empty ACS, JSON API should emit ACS block and after it only absolute offset ticks" in withHttpService { + (uri, _, _) => + for { + _ <- initialIouCreate(uri) + + msgs <- singleClientQueryStream(jwt, uri, """{"templateIds": ["Iou:Iou"]}""") + .take(10) + .runWith(collectResultsAsTextMessage) + } yield { + inside(eventsBlockVector(msgs.toVector)) { + case \/-(acs +: offsetTicks) => + isAcs(acs) shouldBe true + acs.events should have length 1 + offsetTicks.forall(isAbsoluteOffsetTick) shouldBe true + offsetTicks should have length 9 + } + } + } +} diff --git a/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala b/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala index 4a7327527b..2c5eb744a1 100644 --- a/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala +++ b/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala @@ -4,16 +4,16 @@ package com.daml.http import akka.NotUsed -import akka.http.scaladsl.model.ws.{Message, TextMessage, BinaryMessage} -import akka.stream.scaladsl.{Flow, Source, Sink} +import akka.http.scaladsl.model.ws.{BinaryMessage, Message, TextMessage} +import akka.stream.scaladsl.{Flow, Sink, Source} import akka.stream.Materializer import com.daml.http.EndpointsCompanion._ import com.daml.http.domain.{JwtPayload, SearchForeverRequest} import com.daml.http.json.{DomainJsonDecoder, JsonProtocol, SprayJson} import com.daml.http.LedgerClientJwt.Terminates import util.ApiValueToLfValueConverter.apiValueToLfValue -import util.{AbsoluteBookmark, ContractStreamStep, InsertDeleteStep, LedgerBegin} -import ContractStreamStep.{Acs, LiveBegin, Txn} +import util.{BeginBookmark, ContractStreamStep, InsertDeleteStep} +import ContractStreamStep.LiveBegin import json.JsonProtocol.LfValueCodec.{apiValueToJsValue => lfValueToJsValue} import query.ValuePredicate.{LfV, TypeLookup} import com.daml.jwt.domain.Jwt @@ -281,6 +281,10 @@ object WebSocketService { request traverse (_.contractIdAtOffset) map NelO.toSet } } + + private abstract sealed class TickTriggerOrStep[+A] extends Product with Serializable + private final case object TickTrigger extends TickTriggerOrStep[Nothing] + private final case class Step[A](payload: StepAndErrors[A, JsValue]) extends TickTriggerOrStep[A] } class WebSocketService( @@ -388,7 +392,7 @@ class WebSocketService( contractsService .insertDeleteStepSource(jwt, party, resolved.toList, offPrefix, Terminates.Never) .via(convertFilterContracts(fn)) - .via(emitOffsetTicksAndFilterOutEmptySteps(offPrefix)) + .via(emitOffsetTicksAndFilterOutEmptySteps) .via(removePhantomArchives(remove = Q.removePhantomArchives(request))) .map(_.mapPos(Q.renderCreatedMetadata).render) .prepend(reportUnresolvedTemplateIds(unresolved)) @@ -400,41 +404,32 @@ class WebSocketService( } } - private def emitOffsetTicksAndFilterOutEmptySteps[Pos](startFrom: Option[domain.StartingOffset]) + private def emitOffsetTicksAndFilterOutEmptySteps[Pos] : Flow[StepAndErrors[Pos, JsValue], StepAndErrors[Pos, JsValue], NotUsed] = { - type TickTriggerOrStep = Unit \/ StepAndErrors[Pos, JsValue] + val zero = (Option.empty[BeginBookmark[domain.Offset]], TickTrigger: TickTriggerOrStep[Pos]) - val tickTrigger: TickTriggerOrStep = -\/(()) - val zeroState: StepAndErrors[Pos, JsValue] = startFrom.cata( - x => StepAndErrors(Seq(), LiveBegin(AbsoluteBookmark(x.offset))), - StepAndErrors(Seq(), LiveBegin(LedgerBegin)) - ) Flow[StepAndErrors[Pos, JsValue]] - .map(a => \/-(a): TickTriggerOrStep) - .keepAlive(config.heartBeatPer, () => tickTrigger) - .scan((zeroState, tickTrigger)) { - case ((state, _), -\/(())) => - // convert tick trigger into a tick message, get the last seen offset from the state - state.step match { - case Acs(_) => (ledgerBeginTick, \/-(ledgerBeginTick)) - case LiveBegin(LedgerBegin) => (ledgerBeginTick, \/-(ledgerBeginTick)) - case LiveBegin(AbsoluteBookmark(offset)) => (state, \/-(offsetTick(offset))) - case Txn(_, offset) => (state, \/-(offsetTick(offset))) - } - case ((_, _), x @ \/-(step)) => - // filter out empty steps, capture the current step, so we keep the last seen offset for the next tick - val nonEmptyStep: TickTriggerOrStep = if (step.nonEmpty) x else tickTrigger - (step, nonEmptyStep) + .map(a => Step(a)) + .keepAlive(config.heartBeatPer, () => TickTrigger) + .scan(zero) { + case ((None, _), TickTrigger) => + // skip all ticks we don't have the offset yet + (None, TickTrigger) + case ((Some(offset), _), TickTrigger) => + // emit an offset tick + (Some(offset), Step(offsetTick(offset))) + case ((_, _), msg @ Step(_)) => + // capture the new offset and emit the current step + val newOffset = msg.payload.step.bookmark + (newOffset, msg) } - .collect { case (_, \/-(x)) => x } + // filter non-empty Steps, we don't want to spam client with empty events + .collect { case (_, Step(x)) if x.nonEmpty => x } } - private def ledgerBeginTick[Pos] = - StepAndErrors[Pos, JsValue](Seq(), LiveBegin(LedgerBegin)) - - private def offsetTick[Pos](offset: domain.Offset) = - StepAndErrors[Pos, JsValue](Seq(), Txn(InsertDeleteStep.Empty, offset)) + private def offsetTick[Pos](offset: BeginBookmark[domain.Offset]) = + StepAndErrors[Pos, JsValue](Seq.empty, LiveBegin(offset)) private def removePhantomArchives[A, B](remove: Option[Set[domain.ContractId]]) = remove cata (removePhantomArchives_[A, B], Flow[StepAndErrors[A, B]])