ContractsService#search to return stream of ActiveContracts (#3625)

* Changing search so it returns `Source[ActiveContract]`

* wip

* function to convert Source[JsVal] into result Source[ByteString]

* compiles now

* cleanup

* tests passing

* test case for contracts search streaming error handling

* in memory search relying on acsFollowingAndBoundary, so we fetch all `following` updates

* Using ContractsFetch.acsFollowingAndBoundary

* Addressing code review comments, thanks @S11001001
This commit is contained in:
Leonid Shlyapnikov 2019-11-27 16:36:20 -05:00 committed by mergify[bot]
parent ae9bb35b85
commit f9f9672fd5
5 changed files with 294 additions and 115 deletions

View File

@ -37,7 +37,7 @@ import scalaz.syntax.show._
import scalaz.syntax.tag._
import scalaz.syntax.functor._
import scalaz.syntax.std.option._
import scalaz.{-\/, \/, \/-}
import scalaz.{-\/, Liskov, \/, \/-}
import spray.json.JsValue
import com.typesafe.scalalogging.StrictLogging
import scalaz.Liskov.<~<
@ -190,7 +190,7 @@ private class ContractsFetch(
}
}
private object ContractsFetch {
private[http] object ContractsFetch {
type Contract = domain.Contract[lav1.value.Value]
type PreInsertContract = DBContract[TemplateId.RequiredPkg, JsValue, Seq[domain.Party]]
@ -261,7 +261,7 @@ private object ContractsFetch {
* after the ACS's last offset, terminating with the last offset of the last transaction,
* or the ACS's last offset if there were no transactions.
*/
private def acsFollowingAndBoundary(
private[http] def acsFollowingAndBoundary(
transactionsSince: lav1.ledger_offset.LedgerOffset => Source[Transaction, NotUsed]): Graph[
FanOutShape2[
lav1.active_contracts_service.GetActiveContractsResponse,
@ -380,11 +380,16 @@ private object ContractsFetch {
}
final case class InsertDeleteStep[+C](inserts: Vector[C], deletes: Set[String]) {
@SuppressWarnings(Array("org.wartremover.warts.Any"))
def append[CC >: C](o: InsertDeleteStep[CC])(
implicit cid: CC <~< DBContract[Any, Any, Any]): InsertDeleteStep[CC] =
appendWithCid(o)(
Liskov.contra1_2[Function1, DBContract[Any, Any, Any], CC, String](cid)(_.contractId))
def appendWithCid[CC >: C](o: InsertDeleteStep[CC])(cid: CC => String): InsertDeleteStep[CC] =
InsertDeleteStep(
(if (o.deletes.isEmpty) inserts
else inserts.filter(c => !o.deletes.contains(cid(c).contractId))) ++ o.inserts,
else inserts.filter(c => !o.deletes.contains(cid(c)))) ++ o.inserts,
deletes union o.deletes)
}

View File

@ -3,18 +3,25 @@
package com.digitalasset.http
import akka.stream.Materializer
import akka.stream.scaladsl.Sink
import akka.NotUsed
import akka.stream.scaladsl._
import akka.stream.{Materializer, SourceShape}
import com.digitalasset.daml.lf
import com.digitalasset.daml.lf.value.{Value => V}
import com.digitalasset.http.ContractsFetch.{InsertDeleteStep, OffsetBookmark}
import com.digitalasset.http.domain.{GetActiveContractsRequest, JwtPayload, TemplateId}
import com.digitalasset.http.query.ValuePredicate
import com.digitalasset.http.query.ValuePredicate.LfV
import com.digitalasset.http.util.FutureUtil.toFuture
import com.digitalasset.http.util.IdentifierConverters.apiIdentifier
import com.digitalasset.http.util.{ApiValueToLfValueConverter, FutureUtil}
import com.digitalasset.jwt.domain.Jwt
import com.digitalasset.ledger.api.refinements.{ApiTypes => lar}
import com.digitalasset.ledger.api.{v1 => lav1}
import com.digitalasset.ledger.api.{v1 => api}
import com.typesafe.scalalogging.StrictLogging
import scalaz.syntax.show._
import scalaz.syntax.std.option._
import scalaz.syntax.traverse._
import scalaz.{-\/, Show, \/, \/-}
import spray.json.JsValue
@ -26,22 +33,19 @@ class ContractsService(
getCreatesAndArchivesSince: LedgerClientJwt.GetCreatesAndArchivesSince,
lookupType: query.ValuePredicate.TypeLookup,
contractDao: Option[dbbackend.ContractDao],
parallelism: Int = 8)(implicit ec: ExecutionContext, mat: Materializer) {
parallelism: Int = 8)(implicit ec: ExecutionContext, mat: Materializer)
extends StrictLogging {
import ContractsService._
type Result = (Seq[ActiveContract], CompiledPredicates)
type CompiledPredicates = Map[domain.TemplateId.RequiredPkg, query.ValuePredicate]
private val contractsFetch = contractDao.map { dao =>
new ContractsFetch(getActiveContracts, getCreatesAndArchivesSince, lookupType)(dao.logHandler)
}
def lookup(
jwt: Jwt,
jwtPayload: JwtPayload,
request: domain.ContractLookupRequest[lav1.value.Value])
: Future[Option[domain.ActiveContract[lav1.value.Value]]] =
def lookup(jwt: Jwt, jwtPayload: JwtPayload, request: domain.ContractLookupRequest[ApiValue])
: Future[Option[domain.ActiveContract[LfValue]]] =
request.id match {
case -\/((templateId, contractKey)) =>
lookup(jwt, jwtPayload.party, templateId, contractKey)
@ -53,64 +57,166 @@ class ContractsService(
jwt: Jwt,
party: lar.Party,
templateId: TemplateId.OptionalPkg,
contractKey: lav1.value.Value): Future[Option[ActiveContract]] =
contractKey: api.value.Value): Future[Option[domain.ActiveContract[LfValue]]] =
for {
(as, _) <- search(jwt, party, Set(templateId), Map.empty)
a = findByContractKey(contractKey)(as)
} yield a
private def findByContractKey(k: lav1.value.Value)(
as: Seq[ActiveContract]): Option[domain.ActiveContract[lav1.value.Value]] =
as.view.find(isContractKey(k))
lfKey <- FutureUtil.toFuture(apiValueToLfValue(contractKey)): Future[LfValue]
private def isContractKey(k: lav1.value.Value)(
a: domain.ActiveContract[lav1.value.Value]): Boolean =
errorOrAc <- search(jwt, party, Set(templateId), Map.empty)
.collect {
case e @ -\/(_) => e
case a @ \/-(ac) if isContractKey(lfKey)(ac) => a
}
.runWith(Sink.headOption): Future[Option[Error \/ domain.ActiveContract[LfValue]]]
result <- lookupResult(errorOrAc)
} yield result
private def isContractKey(k: LfValue)(a: domain.ActiveContract[LfValue]): Boolean =
a.key.fold(false)(_ == k)
def lookup(
jwt: Jwt,
party: lar.Party,
templateId: Option[TemplateId.OptionalPkg],
contractId: domain.ContractId): Future[Option[ActiveContract]] =
contractId: domain.ContractId): Future[Option[domain.ActiveContract[LfValue]]] =
for {
(as, _) <- search(jwt, party, templateIds(templateId), Map.empty)
a = findByContractId(contractId)(as)
} yield a
errorOrAc <- search(jwt, party, templateIds(templateId), Map.empty)
.collect {
case e @ -\/(_) => e
case a @ \/-(ac) if isContractId(contractId)(ac) => a
}
.runWith(Sink.headOption): Future[Option[Error \/ domain.ActiveContract[LfValue]]]
result <- lookupResult(errorOrAc)
} yield result
private def lookupResult(errorOrAc: Option[Error \/ domain.ActiveContract[LfValue]])
: Future[Option[domain.ActiveContract[LfValue]]] = {
errorOrAc.cata(x => toFuture(x).map(Some(_)), Future.successful(None))
}
private def templateIds(a: Option[TemplateId.OptionalPkg]): Set[TemplateId.OptionalPkg] =
a.toList.toSet
private def findByContractId(k: domain.ContractId)(
as: Seq[ActiveContract]): Option[ActiveContract] =
as.find(x => (x.contractId: domain.ContractId) == k)
private def isContractId(k: domain.ContractId)(a: domain.ActiveContract[LfValue]): Boolean =
(a.contractId: domain.ContractId) == k
def search(jwt: Jwt, jwtPayload: JwtPayload, request: GetActiveContractsRequest): Future[Result] =
def search(jwt: Jwt, jwtPayload: JwtPayload, request: GetActiveContractsRequest)
: Source[Error \/ domain.ActiveContract[LfValue], NotUsed] =
search(jwt, jwtPayload.party, request.templateIds, request.query)
def search(
jwt: Jwt,
party: lar.Party,
templateIds: Set[domain.TemplateId.OptionalPkg],
queryParams: Map[String, JsValue]): Future[Result] =
for {
templateIds <- toFuture(resolveTemplateIds(templateIds))
_ <- fetchAndPersistContracts(jwt, party, templateIds)
allActiveContracts <- getActiveContracts(jwt, transactionFilter(party, templateIds), true)
.mapAsyncUnordered(parallelism)(gacr => toFuture(activeContracts(gacr)))
.runWith(Sink.seq)
.map(_.flatten): Future[Seq[ActiveContract]]
predicates = templateIds.iterator.map(a => (a, valuePredicate(a, queryParams))).toMap
} yield (allActiveContracts, predicates)
queryParams: Map[String, JsValue])
: Source[Error \/ domain.ActiveContract[LfValue], NotUsed] = {
resolveTemplateIds(templateIds) match {
case -\/(e) => Source.single(-\/(Error('search, e.shows)))
case \/-(resolvedTemplateIds) =>
// TODO(Leo/Stephen): fetchAndPersistContracts should be removed once we have SQL search ready
val persistF: Future[Option[Unit]] =
fetchAndPersistContracts(jwt, party, resolvedTemplateIds)
// return source after we fetch and persist
val resultSourceF = persistF.map { _ =>
Source(resolvedTemplateIds)
.flatMapConcat(tpId => searchInMemory(jwt, party, tpId, queryParams))
}
Source.fromFutureSource(resultSourceF).mapMaterializedValue(_ => NotUsed)
}
}
private def searchInMemory(
jwt: Jwt,
party: lar.Party,
templateId: domain.TemplateId.RequiredPkg,
queryParams: Map[String, JsValue])
: Source[Error \/ domain.ActiveContract[LfValue], NotUsed] = {
val empty = InsertDeleteStep[api.event.CreatedEvent](Vector.empty, Set.empty)
def cid(a: api.event.CreatedEvent): String = a.contractId
def append(
a: InsertDeleteStep[api.event.CreatedEvent],
b: InsertDeleteStep[api.event.CreatedEvent]) = a.appendWithCid(b)(cid)
val predicate: ValuePredicate = valuePredicate(templateId, queryParams)
val funPredicate: LfV => Boolean = predicate.toFunPredicate
insertDeleteStepSource(jwt, party, templateId)
.fold(empty)(append)
.mapConcat(_.inserts)
.map { apiEvent =>
domain.ActiveContract
.fromLedgerApi(apiEvent)
.leftMap(e => Error('searchInMemory, e.shows))
.flatMap(apiAcToLfAc): Error \/ domain.ActiveContract[LfValue]
}
.collect(collectActiveContracts(funPredicate))
}
private def insertDeleteStepSource(
jwt: Jwt,
party: lar.Party,
templateId: domain.TemplateId.RequiredPkg)
: Source[InsertDeleteStep[api.event.CreatedEvent], NotUsed] = {
val graph = GraphDSL.create() { implicit b =>
import GraphDSL.Implicits._
val source = getActiveContracts(jwt, transactionFilter(party, List(templateId)), true)
val transactionsSince
: api.ledger_offset.LedgerOffset => Source[api.transaction.Transaction, NotUsed] =
getCreatesAndArchivesSince(
jwt,
transactionFilter(party, List(templateId)),
_: api.ledger_offset.LedgerOffset)
val contractsAndBoundary = b add ContractsFetch.acsFollowingAndBoundary(transactionsSince)
val offsetSink = b add Sink.foreach[OffsetBookmark[String]] { a =>
logger.debug(s"contracts fetch completed at: ${a.toString}")
}
source ~> contractsAndBoundary.in
contractsAndBoundary.out1 ~> offsetSink
new SourceShape(contractsAndBoundary.out0)
}
Source.fromGraph(graph)
}
@SuppressWarnings(Array("org.wartremover.warts.Any"))
private def apiAcToLfAc(
ac: domain.ActiveContract[ApiValue]): Error \/ domain.ActiveContract[LfValue] =
ac.traverse(ApiValueToLfValueConverter.apiValueToLfValue)
.leftMap(e => Error('apiAcToLfAc, e.shows))
private def apiValueToLfValue(a: ApiValue): Error \/ LfValue =
ApiValueToLfValueConverter.apiValueToLfValue(a).leftMap(e => Error('apiValueToLfValue, e.shows))
private def collectActiveContracts(predicate: LfValue => Boolean): PartialFunction[
Error \/ domain.ActiveContract[LfValue],
Error \/ domain.ActiveContract[LfValue]
] = {
case e @ -\/(_) => e
case a @ \/-(ac) if predicate(ac.argument) => a
}
private def fetchAndPersistContracts(
jwt: Jwt,
party: lar.Party,
templateIds: List[domain.TemplateId.RequiredPkg]): Future[Option[Unit]] = {
import scalaz.syntax.applicative._
import scalaz.syntax.traverse._
import scalaz.std.option._
import scalaz.std.scalaFuture._
import scalaz.syntax.applicative._
import scalaz.syntax.traverse._
val option: Option[Future[Unit]] = ^(contractDao, contractsFetch)((dao, fetch) =>
fetchAndPersistContracts(dao, fetch)(jwt, party, templateIds))
@ -124,10 +230,6 @@ class ContractsService(
templateIds: List[domain.TemplateId.RequiredPkg]): Future[Unit] =
dao.transact(fetch.contractsIo2(jwt, party, templateIds)).unsafeToFuture().map(_ => ())
private def activeContracts(gacr: lav1.active_contracts_service.GetActiveContractsResponse)
: Error \/ List[ActiveContract] =
domain.ActiveContract.fromLedgerApi(gacr).leftMap(e => Error('activeContracts, e.shows))
def filterSearch(
compiledPredicates: CompiledPredicates,
activeContracts: Seq[domain.ActiveContract[V[V.AbsoluteContractId]]])
@ -143,19 +245,21 @@ class ContractsService(
private def transactionFilter(
party: lar.Party,
templateIds: List[TemplateId.RequiredPkg]): lav1.transaction_filter.TransactionFilter = {
import lav1.transaction_filter._
templateIds: List[TemplateId.RequiredPkg]): api.transaction_filter.TransactionFilter = {
import api.transaction_filter._
val filters =
if (templateIds.isEmpty) Filters.defaultInstance
else Filters(Some(lav1.transaction_filter.InclusiveFilters(templateIds.map(apiIdentifier))))
else Filters(Some(api.transaction_filter.InclusiveFilters(templateIds.map(apiIdentifier))))
TransactionFilter(Map(lar.Party.unwrap(party) -> filters))
}
}
object ContractsService {
type ActiveContract = domain.ActiveContract[lav1.value.Value]
private type ApiValue = api.value.Value
private type LfValue = lf.value.Value[lf.value.Value.AbsoluteContractId]
case class Error(id: Symbol, message: String)

View File

@ -5,13 +5,14 @@ package com.digitalasset.http
import akka.http.scaladsl.model.HttpMethods.{GET, POST}
import akka.http.scaladsl.model._
import akka.NotUsed
import akka.http.scaladsl.model.headers.{Authorization, OAuth2BearerToken}
import akka.stream.Materializer
import akka.util.ByteString
import com.digitalasset.daml.lf
import com.digitalasset.http.Statement.discard
import com.digitalasset.http.domain.JwtPayload
import com.digitalasset.http.json.ResponseFormats._
import com.digitalasset.http.json.ResponseFormats
import com.digitalasset.http.json.{DomainJsonDecoder, DomainJsonEncoder, SprayJson}
import com.digitalasset.http.util.FutureUtil.{either, eitherT}
import com.digitalasset.http.util.{ApiValueToLfValueConverter, FutureUtil}
@ -31,6 +32,9 @@ import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal
import akka.stream.Materializer
import akka.stream.scaladsl.{Source, Flow}
@SuppressWarnings(Array("org.wartremover.warts.Any"))
class Endpoints(
ledgerId: lar.LedgerId,
@ -114,6 +118,15 @@ class Endpoints(
-\/(ServerError(e.getMessage))
}
private def handleSourceFailure[E: Show, A]: Flow[E \/ A, ServerError \/ A, NotUsed] =
Flow
.fromFunction((_: E \/ A).leftMap(e => ServerError(e.shows)))
.recover {
case NonFatal(e) =>
logger.error("Source failed", e)
-\/(ServerError(e.getMessage))
}
private def encodeList(as: Seq[JsValue]): ServerError \/ JsValue =
SprayJson.encode(as).leftMap(e => ServerError(e.shows))
@ -131,17 +144,14 @@ class Endpoints(
decoder
.decodeV[domain.ContractLookupRequest](reqBody)
.leftMap(e => InvalidUserInput(e.shows))
): ET[domain.ContractLookupRequest[lav1.value.Value]]
): ET[domain.ContractLookupRequest[ApiValue]]
ac <- eitherT(
handleFutureFailure(contractsService.lookup(jwt, jwtPayload, cmd))
): ET[Option[domain.ActiveContract[lav1.value.Value]]]
): ET[Option[domain.ActiveContract[LfValue]]]
jsVal <- either(
ac match {
case None => \/-(JsObject())
case Some(x) => encoder.encodeV(x).leftMap(e => ServerError(e.shows))
}
ac.cata(x => lfAcToJsValue(x).leftMap(e => ServerError(e.shows)), \/-(JsObject()))
): ET[JsValue]
} yield jsVal
@ -149,58 +159,41 @@ class Endpoints(
httpResponse(et)
case req @ HttpRequest(GET, Uri.Path("/contracts/search"), _, _, _) =>
val et: ET[JsValue] = for {
input <- FutureUtil.eitherT(input(req)): ET[(Jwt, JwtPayload, String)]
val sourceF: Future[Error \/ Source[JsValue, NotUsed]] = input(req).map {
_.map {
case (jwt, jwtPayload, _) =>
contractsService
.search(jwt, jwtPayload, emptyGetActiveContractsRequest)
.via(handleSourceFailure)
.map {
_.flatMap(lfAcToJsValue)
.fold(errorToJsValue, identity): JsValue
}: Source[JsValue, NotUsed]
}
}
(jwt, jwtPayload, _) = input
as <- eitherT(
handleFutureFailure(contractsService
.search(jwt, jwtPayload, emptyGetActiveContractsRequest))): ET[contractsService.Result]
jsVal <- either(
as._1.toList
.traverse(a => encoder.encodeV(a))
.leftMap(e => ServerError(e.shows))
.flatMap(js => encodeList(js))
): ET[JsValue]
} yield jsVal
httpResponse(et)
httpResponse(sourceF)
case req @ HttpRequest(POST, Uri.Path("/contracts/search"), _, _, _) =>
val et: ET[JsValue] = for {
input <- FutureUtil.eitherT(input(req)): ET[(Jwt, JwtPayload, String)]
val sourceF: Future[Error \/ Source[JsValue, NotUsed]] = input(req).map {
_.flatMap {
case (jwt, jwtPayload, reqBody) =>
SprayJson
.decode[domain.GetActiveContractsRequest](reqBody)
.leftMap(e => InvalidUserInput(e.shows))
.map { cmd =>
contractsService
.search(jwt, jwtPayload, cmd)
.via(handleSourceFailure)
.map {
_.flatMap(lfAcToJsValue)
.fold(errorToJsValue, identity): JsValue
}: Source[JsValue, NotUsed]
}
}
}
(jwt, jwtPayload, reqBody) = input
cmd <- either(
SprayJson
.decode[domain.GetActiveContractsRequest](reqBody)
.leftMap(e => InvalidUserInput(e.shows))
): ET[domain.GetActiveContractsRequest]
as <- eitherT(
handleFutureFailure(contractsService.search(jwt, jwtPayload, cmd))
): ET[contractsService.Result]
xs <- either(
as._1.toList.traverse(_.traverse(v => apValueToLfValue(v)))
): ET[List[domain.ActiveContract[LfValue]]]
ys = contractsService
.filterSearch(as._2, xs): Seq[domain.ActiveContract[LfValue]]
js <- either(
ys.toList.traverse(_.traverse(v => lfValueToJsValue(v)))
): ET[Seq[domain.ActiveContract[JsValue]]]
j <- either(SprayJson.encode(js).leftMap(e => ServerError(e.shows))): ET[JsValue]
} yield j
httpResponse(et)
httpResponse(sourceF)
}
lazy val parties: PartialFunction[HttpRequest, Future[HttpResponse]] = {
@ -214,13 +207,31 @@ class Endpoints(
httpResponse(et)
}
private def apValueToLfValue(a: ApiValue): Error \/ LfValue =
private def apiValueToLfValue(a: ApiValue): Error \/ LfValue =
ApiValueToLfValueConverter.apiValueToLfValue(a).leftMap(e => ServerError(e.shows))
private def lfValueToJsValue(a: LfValue): Error \/ JsValue =
\/.fromTryCatchNonFatal(LfValueCodec.apiValueToJsValue(a)).leftMap(e =>
ServerError(e.getMessage))
private def collectActiveContracts(
predicates: Map[domain.TemplateId.RequiredPkg, LfValue => Boolean]): PartialFunction[
Error \/ domain.ActiveContract[LfValue],
Error \/ domain.ActiveContract[LfValue]
] = {
case e @ -\/(_) => e
case a @ \/-(ac) if predicates.get(ac.templateId).forall(f => f(ac.argument)) => a
}
private def errorToJsValue(e: Error): JsValue = errorsJsObject(e)._2
private def lfAcToJsValue(a: domain.ActiveContract[LfValue]): Error \/ JsValue = {
for {
b <- a.traverse(lfValueToJsValue): Error \/ domain.ActiveContract[JsValue]
c <- SprayJson.encode(b).leftMap(e => ServerError(e.shows))
} yield c
}
private def httpResponse(output: ET[JsValue]): Future[HttpResponse] = {
val fa: Future[Error \/ JsValue] = output.run
fa.map {
@ -232,18 +243,33 @@ class Endpoints(
}
}
private def httpResponse(
output: Future[Error \/ Source[JsValue, NotUsed]]): Future[HttpResponse] =
output
.map {
case \/-(source) => httpResponseFromSource(StatusCodes.OK, source)
case -\/(e) => httpResponseError(e)
}
.recover {
case NonFatal(e) => httpResponseError(ServerError(e.getMessage))
}
private def httpResponseOk(data: JsValue): HttpResponse =
httpResponse(StatusCodes.OK, resultJsObject(data))
httpResponse(StatusCodes.OK, ResponseFormats.resultJsObject(data))
private def httpResponseError(error: Error): HttpResponse = {
val (status, jsObject) = errorsJsObject(error)
httpResponse(status, jsObject)
}
private def errorsJsObject(error: Error): (StatusCode, JsObject) = {
val (status, errorMsg): (StatusCode, String) = error match {
case InvalidUserInput(e) => StatusCodes.BadRequest -> e
case ServerError(e) => StatusCodes.InternalServerError -> e
case Unauthorized(e) => StatusCodes.Unauthorized -> e
case NotFound(e) => StatusCodes.NotFound -> e
}
httpResponse(status, errorsJsObject(status, errorMsg))
(status, ResponseFormats.errorsJsObject(status, errorMsg))
}
private def httpResponse(status: StatusCode, data: JsValue): HttpResponse = {
@ -252,6 +278,15 @@ class Endpoints(
entity = HttpEntity.Strict(ContentTypes.`application/json`, format(data)))
}
private def httpResponseFromSource(
status: StatusCode,
data: Source[JsValue, NotUsed]): HttpResponse =
HttpResponse(
status = status,
entity = HttpEntity
.CloseDelimited(ContentTypes.`application/json`, ResponseFormats.resultJsObject(data))
)
lazy val notFound: PartialFunction[HttpRequest, Future[HttpResponse]] = {
case HttpRequest(method, uri, _, _, _) =>
Future.successful(httpResponseError(NotFound(s"${method: HttpMethod}, uri: ${uri: Uri}")))
@ -296,6 +331,8 @@ object Endpoints {
private type LfValue = lf.value.Value[lf.value.Value.AbsoluteContractId]
private type ActiveContractStream[A] = Source[A, NotUsed]
type ValidateJwt = Jwt => Unauthorized \/ DecodedJwt[String]
sealed abstract class Error(message: String) extends Product with Serializable

View File

@ -3,7 +3,10 @@
package com.digitalasset.http.json
import akka.NotUsed
import akka.http.scaladsl.model._
import akka.stream.scaladsl.{Concat, Source}
import akka.util.ByteString
import spray.json._
import spray.json.DefaultJsonProtocol._
@ -21,6 +24,21 @@ private[http] object ResponseFormats {
JsObject(statusField(StatusCodes.OK), ("result", a))
}
private val start: Source[ByteString, NotUsed] =
Source.single(ByteString("""{"status":200,"result":["""))
private val end: Source[ByteString, NotUsed] = Source.single(ByteString("]}"))
def resultJsObject(jsVals: Source[JsValue, NotUsed]): Source[ByteString, NotUsed] = {
val csv: Source[ByteString, NotUsed] = jsVals.zipWithIndex.map {
case (a, i) =>
if (i == 0L) ByteString(a.compactPrint)
else ByteString("," + a.compactPrint)
}
Source.combine(start, csv, end)(Concat.apply)
}
def statusField(status: StatusCode): (String, JsNumber) =
("status", JsNumber(status.intValue()))
}

View File

@ -168,6 +168,15 @@ abstract class AbstractHttpServiceIntegrationTest
}
}
"contracts/search with invalid JSON query should return error" in withHttpService { (uri, _, _) =>
postJsonStringRequest(uri.withPath(Uri.Path("/contracts/search")), "{NOT A VALID JSON OBJECT")
.flatMap {
case (status, output) =>
status shouldBe StatusCodes.BadRequest
assertStatus(output, StatusCodes.BadRequest)
}: Future[Assertion]
}
protected def jsObject(s: String): JsObject = {
val r: JsonError \/ JsObject = for {
jsVal <- SprayJson.parse(s).leftMap(e => JsonError(e.shows))
@ -491,18 +500,18 @@ abstract class AbstractHttpServiceIntegrationTest
domain.ExerciseCommand(templateId, contractId, choice, arg, None)
}
private def postJsonRequest(
private def postJsonStringRequest(
uri: Uri,
json: JsValue,
jsonString: String,
headers: List[HttpHeader] = headersWithAuth): Future[(StatusCode, JsValue)] = {
logger.info(s"postJson: $uri json: $json")
logger.info(s"postJson: ${uri.toString} json: ${jsonString: String}")
Http()
.singleRequest(
HttpRequest(
method = HttpMethods.POST,
uri = uri,
headers = headers,
entity = HttpEntity(ContentTypes.`application/json`, json.prettyPrint))
entity = HttpEntity(ContentTypes.`application/json`, jsonString))
)
.flatMap { resp =>
val bodyF: Future[String] = getResponseDataBytes(resp, debug = true)
@ -510,6 +519,12 @@ abstract class AbstractHttpServiceIntegrationTest
}
}
private def postJsonRequest(
uri: Uri,
json: JsValue,
headers: List[HttpHeader] = headersWithAuth): Future[(StatusCode, JsValue)] =
postJsonStringRequest(uri, json.prettyPrint, headers)
private def getRequest(
uri: Uri,
headers: List[HttpHeader] = headersWithAuth): Future[(StatusCode, JsValue)] = {