diff --git a/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala b/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala index 3931f585c2..edd0e35ff2 100644 --- a/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala +++ b/ledger-service/http-json/src/it/scala/http/WebsocketServiceIntegrationTest.scala @@ -450,6 +450,83 @@ class WebsocketServiceIntegrationTest } yield resumes.foldLeft(1 shouldBe 1)((_, a) => a) } + "fetch multiple keys should work" in withHttpService { (uri, encoder, _) => + def matches(expected: Seq[JsValue], actual: Seq[JsValue]): Boolean = + expected.length == actual.length && (expected, actual).zipped.forall { + case (exp, act) => matchesJs(exp, act) + } + // matches if all the values specified in expected appear with the same + // value in actual; actual is allowed to have extra fields. Arrays must + // have the same length. + def matchesJs(expected: spray.json.JsValue, actual: spray.json.JsValue): Boolean = { + import spray.json._ + (expected, actual) match { + case (JsArray(expected), JsArray(actual)) => + expected.length == actual.length && matches(expected, actual) + case (JsObject(expected), JsObject(actual)) => + expected.keys.forall(k => matchesJs(expected(k), actual(k))) + case (JsString(expected), JsString(actual)) => expected == actual + case (JsNumber(expected), JsNumber(actual)) => expected == actual + case (JsBoolean(expected), JsBoolean(actual)) => expected == actual + case (JsNull, JsNull) => true + case _ => false + } + } + def create(account: String): Future[domain.ContractId] = + for { + r <- postCreateCommand(accountCreateCommand(domain.Party("Alice"), account), encoder, uri) + } yield { + assert(r._1.isSuccess) + getContractId(getResult(r._2)) + } + def archive(id: domain.ContractId): Future[Assertion] = + for { + r <- postArchiveCommand(domain.TemplateId(None, "Account", "Account"), id, encoder, uri) + } yield { + assert(r._1.isSuccess) + } + val req = + """ + |[{"templateId": "Account:Account", "key": ["Alice", "abc123"]}, + | {"templateId": "Account:Account", "key": ["Alice", "def456"]}] + |""".stripMargin + val futureResults = + singleClientFetchStream(jwt, uri, req).via(parseResp).runWith(Sink.seq[JsValue]) + + for { + cid1 <- create("abc123") + _ <- create("abc124") + _ <- create("abc125") + cid2 <- create("def456") + _ <- archive(cid2) + _ <- archive(cid1) + results <- futureResults + } yield { + val expected: Seq[JsValue] = { + import spray.json._ + Seq( + """ + |{"events": []} + |""".stripMargin.parseJson, + """ + |{"events":[{"created":{"payload":{"number":"abc123"}}}]} + |""".stripMargin.parseJson, + """ + |{"events":[{"created":{"payload":{"number":"def456"}}}]} + |""".stripMargin.parseJson, + """ + |{"events":[{"archived":{}}]} + |""".stripMargin.parseJson, + """ + |{"events":[{"archived":{}}]} + |""".stripMargin.parseJson + ) + } + assert(matches(expected, results)) + + } + } + "fetch should should return an error if empty list of (templateId, key) pairs is passed" in withHttpService { (uri, _, _) => singleClientFetchStream(jwt, uri, "[]") diff --git a/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala b/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala index 2c5eb744a1..a06e008047 100644 --- a/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala +++ b/ledger-service/http-json/src/main/scala/com/digitalasset/http/WebSocketService.scala @@ -33,6 +33,7 @@ import com.daml.http.util.FlowUtil.allowOnlyFirstInput import spray.json.{JsArray, JsObject, JsValue, JsonReader} import scala.collection.compat._ +import scala.collection.mutable.HashSet import scala.concurrent.{ExecutionContext, Future} import scala.util.{Failure, Success} @@ -257,11 +258,18 @@ object WebSocketService { .toLeftDisjunction(x.ekey.templateId) } - val q: Map[domain.TemplateId.RequiredPkg, LfV] = resolvedWithKey.toMap + val q: Map[domain.TemplateId.RequiredPkg, HashSet[LfV]] = + resolvedWithKey.foldLeft(Map.empty[domain.TemplateId.RequiredPkg, HashSet[LfV]])( + (acc, el) => + acc.get(el._1) match { + case Some(v) => acc.updated(el._1, v += el._2) + case None => acc.updated(el._1, HashSet(el._2)) + }) val fn: domain.ActiveContract[LfV] => Option[Positive] = { a => - if (q.get(a.templateId).exists(k => domain.ActiveContract.matchesKey(k)(a))) - Some(()) - else None + a.key match { + case None => None + case Some(k) => if (q.getOrElse(a.templateId, HashSet()).contains(k)) Some(()) else None + } } (q.keySet, unresolved, fn) }