Drop gRPC server side custom codegen (#15763)

* Drop gRPC server side custom codegen

[CHANGELOG_BEGIN]
[CHANGELOG_END]

* Do not generate maven jar for ledger-api-akka and more cleanup

* Extract by-name source evaluation outside synchronized

* Addressed Martino's review comment

Co-authored-by: Tudor Voicu <tudor.voicu@digitalasset.com>
This commit is contained in:
Marton Nagy 2022-12-12 08:35:56 +01:00 committed by GitHub
parent 393dda5578
commit 6d27505010
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 108 additions and 472 deletions

View File

@ -52,8 +52,4 @@ object ServerAdapter {
promise.future
})
}
/** Used in [[com.daml.protoc.plugins.akka.AkkaGrpcServicePrinter]] */
def closingError(): StatusRuntimeException =
CommonErrors.ServerIsShuttingDown.Reject()(errorLogger).asGrpcError
}

View File

@ -6,7 +6,10 @@ package com.daml.grpc.adapter.server.akka
import akka.NotUsed
import akka.stream.scaladsl.{Keep, Source}
import akka.stream.{KillSwitch, KillSwitches, Materializer}
import com.daml.error.ContextualizedErrorLogger
import com.daml.error.definitions.CommonErrors
import com.daml.grpc.adapter.ExecutionSequencerFactory
import io.grpc.StatusRuntimeException
import io.grpc.stub.StreamObserver
import scala.collection.concurrent.TrieMap
@ -16,31 +19,33 @@ trait StreamingServiceLifecycleManagement extends AutoCloseable {
@volatile private var _closed = false
private val _killSwitches = TrieMap.empty[KillSwitch, Object]
protected val contextualizedErrorLogger: ContextualizedErrorLogger
def close(): Unit = synchronized {
if (!_closed) {
_closed = true
_killSwitches.keySet.foreach(_.abort(ServerAdapter.closingError()))
_killSwitches.keySet.foreach(_.abort(closingError(contextualizedErrorLogger)))
_killSwitches.clear()
}
}
protected def registerStream[RespT](
buildSource: () => Source[RespT, NotUsed],
responseObserver: StreamObserver[RespT],
)(implicit
responseObserver: StreamObserver[RespT]
)(createSource: => Source[RespT, NotUsed])(implicit
materializer: Materializer,
executionSequencerFactory: ExecutionSequencerFactory,
): Unit = {
def ifNotClosed(run: () => Unit): Unit =
if (_closed) responseObserver.onError(ServerAdapter.closingError())
if (_closed) responseObserver.onError(closingError(contextualizedErrorLogger))
else run()
// Double-checked locking to keep the (potentially expensive)
// by-name `source` evaluation out of the synchronized block
ifNotClosed { () =>
val sink = ServerAdapter.toSink(responseObserver)
val source = buildSource()
// Force evaluation before synchronized block
val source = createSource
// Double-checked locking to keep the (potentially expensive)
// buildSource() step out of the synchronized block
synchronized {
ifNotClosed { () =>
val (killSwitch, doneF) = source
@ -58,4 +63,7 @@ trait StreamingServiceLifecycleManagement extends AutoCloseable {
}
}
}
def closingError(errorLogger: ContextualizedErrorLogger): StatusRuntimeException =
CommonErrors.ServerIsShuttingDown.Reject()(errorLogger).asGrpcError
}

View File

@ -23,17 +23,9 @@ proto_gen(
],
)
proto_gen(
name = "sample-service-akka-sources",
srcs = [":sample-service-proto"],
plugin_exec = "//scala-protoc-plugins/scala-akka:protoc-gen-scala-akka",
plugin_name = "scala-akka",
)
scala_library(
name = "sample-service-scalapb",
srcs = [
":sample-service-akka-sources",
":sample-service-scalapb-sources",
],
unused_dependency_checker_mode = "error",
@ -42,16 +34,12 @@ scala_library(
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_protobuf",
"@maven//:io_grpc_grpc_stub",
"//ledger-api/rs-grpc-bridge",
"//ledger-api/rs-grpc-akka",
] + [
"{}_{}".format(dep, scala_major_version_suffix)
for dep in [
"@maven//:com_thesamet_scalapb_lenses",
"@maven//:com_thesamet_scalapb_scalapb_runtime",
"@maven//:com_thesamet_scalapb_scalapb_runtime_grpc",
"@maven//:com_typesafe_akka_akka_actor",
"@maven//:com_typesafe_akka_akka_stream",
]
],
)

View File

@ -1,34 +0,0 @@
# Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Build for the (fully generated) ledger-api-akka library: runs the custom
# scala-akka protoc plugin over the Ledger API proto definitions and packages
# the resulting Akka-streaming service traits as a Scala library.
load("//bazel_tools:scala.bzl", "da_scala_library")
load("//bazel_tools:proto.bzl", "proto_gen")

# Generates the *AkkaGrpc.scala sources from the Ledger API protos (plus the
# standard gRPC health-check proto) via the scala-akka protoc plugin.
proto_gen(
    name = "ledger-api-akka-srcs",
    srcs = [
        "//ledger-api/grpc-definitions:ledger_api_proto",
        "@com_github_grpc_grpc//src/proto/grpc/health/v1:health_proto_descriptor",
    ],
    plugin_exec = "//scala-protoc-plugins/scala-akka:protoc-gen-scala-akka",
    plugin_name = "scala-akka",
)

# Library containing only the generated sources above; published to Maven as
# com.daml:ledger-api-akka (see `tags`).
da_scala_library(
    name = "ledger-api-akka",
    srcs = [":ledger-api-akka-srcs"],
    scala_deps = [
        "@maven//:com_typesafe_akka_akka_actor",
        "@maven//:com_typesafe_akka_akka_stream",
    ],
    tags = ["maven_coordinates=com.daml:ledger-api-akka:__VERSION__"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//ledger-api/grpc-definitions:ledger_api_proto_scala",
        "//ledger-api/rs-grpc-akka",
        "//ledger-api/rs-grpc-bridge",
        "@maven//:io_grpc_grpc_stub",
    ],
)

View File

@ -1,3 +0,0 @@
# bindings-akka
This project is composed entirely of generated code.

View File

@ -37,7 +37,6 @@ da_scala_library(
"//ledger-api/rs-grpc-akka",
"//ledger-api/rs-grpc-bridge",
"//ledger/error",
"//ledger/ledger-api-akka",
"//ledger/ledger-api-domain",
"//ledger/ledger-api-errors",
"//ledger/ledger-api-health",
@ -125,7 +124,6 @@ da_scala_test_suite(
"//ledger-api/testing-utils",
"//ledger/error",
"//ledger/error:error-test-lib",
"//ledger/ledger-api-akka",
"//ledger/ledger-api-client",
"//ledger/ledger-api-domain",
"//ledger/ledger-api-errors",

View File

@ -7,12 +7,14 @@ import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.error.DamlContextualizedErrorLogger
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.v1.command_completion_service._
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
import com.daml.ledger.api.validation.CompletionServiceRequestValidator
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.platform.server.api.ValidationLogger
import com.daml.platform.server.api.services.domain.CommandCompletionService
import io.grpc.stub.StreamObserver
import scala.concurrent.{ExecutionContext, Future}
@ -20,19 +22,21 @@ class GrpcCommandCompletionService(
service: CommandCompletionService,
validator: CompletionServiceRequestValidator,
)(implicit
protected val mat: Materializer,
protected val esf: ExecutionSequencerFactory,
mat: Materializer,
esf: ExecutionSequencerFactory,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends CommandCompletionServiceAkkaGrpc {
) extends CommandCompletionServiceGrpc.CommandCompletionService
with StreamingServiceLifecycleManagement {
protected implicit val logger: ContextualizedLogger = ContextualizedLogger.get(getClass)
private implicit val contextualizedErrorLogger: DamlContextualizedErrorLogger =
protected implicit val contextualizedErrorLogger: DamlContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
override def completionStreamSource(
request: CompletionStreamRequest
): Source[CompletionStreamResponse, akka.NotUsed] = {
def completionStream(
request: CompletionStreamRequest,
responseObserver: StreamObserver[CompletionStreamResponse],
): Unit = registerStream(responseObserver) {
validator
.validateGrpcCompletionStreamRequest(request)
.fold(

View File

@ -3,11 +3,11 @@
package com.daml.platform.server.api.services.grpc
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.health.HealthChecks
import com.daml.ledger.api.validation.ValidationErrors.invalidArgument
import com.daml.logging.{ContextualizedLogger, LoggingContext}
@ -15,12 +15,8 @@ import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.DropRepeated
import com.daml.platform.server.api.services.grpc.GrpcHealthService._
import io.grpc.ServerServiceDefinition
import io.grpc.health.v1.health.{
HealthAkkaGrpc,
HealthCheckRequest,
HealthCheckResponse,
HealthGrpc,
}
import io.grpc.health.v1.health.{HealthCheckRequest, HealthCheckResponse, HealthGrpc}
import io.grpc.stub.StreamObserver
import scala.concurrent.duration.{DurationInt, FiniteDuration}
import scala.concurrent.{ExecutionContext, Future}
@ -30,15 +26,16 @@ class GrpcHealthService(
healthChecks: HealthChecks,
maximumWatchFrequency: FiniteDuration = 1.second,
)(implicit
protected val esf: ExecutionSequencerFactory,
protected val mat: Materializer,
esf: ExecutionSequencerFactory,
mat: Materializer,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends HealthAkkaGrpc
) extends HealthGrpc.Health
with StreamingServiceLifecycleManagement
with GrpcApiService {
private val logger = ContextualizedLogger.get(getClass)
private val errorLogger: ContextualizedErrorLogger =
protected val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
override def bindService(): ServerServiceDefinition =
@ -47,18 +44,22 @@ class GrpcHealthService(
override def check(request: HealthCheckRequest): Future[HealthCheckResponse] =
Future.fromTry(matchResponse(serviceFrom(request)))
override def watchSource(request: HealthCheckRequest): Source[HealthCheckResponse, NotUsed] =
override def watch(
request: HealthCheckRequest,
responseObserver: StreamObserver[HealthCheckResponse],
): Unit = registerStream(responseObserver) {
Source
.fromIterator(() => Iterator.continually(matchResponse(serviceFrom(request)).get))
.throttle(1, per = maximumWatchFrequency)
.via(DropRepeated())
}
private def matchResponse(componentName: Option[String]): Try[HealthCheckResponse] =
componentName
.collect {
case component if !healthChecks.hasComponent(component) =>
Failure(
invalidArgument(s"Component $component does not exist.")(errorLogger)
invalidArgument(s"Component $component does not exist.")(contextualizedErrorLogger)
)
}
.getOrElse {

View File

@ -3,11 +3,11 @@
package com.daml.platform.server.api.services.grpc
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.v1.ledger_offset.LedgerOffset
import com.daml.ledger.api.v1.transaction_service._
@ -18,6 +18,7 @@ import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.ValidationLogger
import com.daml.platform.server.api.services.domain.TransactionService
import io.grpc.ServerServiceDefinition
import io.grpc.stub.StreamObserver
import scala.concurrent.{ExecutionContext, Future}
@ -26,23 +27,25 @@ final class GrpcTransactionService(
val ledgerId: LedgerId,
partyNameChecker: PartyNameChecker,
)(implicit
protected val esf: ExecutionSequencerFactory,
protected val mat: Materializer,
esf: ExecutionSequencerFactory,
mat: Materializer,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends TransactionServiceAkkaGrpc
) extends TransactionServiceGrpc.TransactionService
with StreamingServiceLifecycleManagement
with GrpcApiService {
protected implicit val logger: ContextualizedLogger = ContextualizedLogger.get(getClass)
private implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
private implicit val logger: ContextualizedLogger = ContextualizedLogger.get(getClass)
protected implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
private val validator =
new TransactionServiceRequestValidator(ledgerId, partyNameChecker)
override protected def getTransactionsSource(
request: GetTransactionsRequest
): Source[GetTransactionsResponse, NotUsed] = {
def getTransactions(
request: GetTransactionsRequest,
responseObserver: StreamObserver[GetTransactionsResponse],
): Unit = registerStream(responseObserver) {
logger.debug(s"Received new transaction request $request")
Source.future(service.getLedgerEnd(request.ledgerId)).flatMapConcat { ledgerEnd =>
val validation = validator.validate(request, ledgerEnd)
@ -56,9 +59,10 @@ final class GrpcTransactionService(
}
}
override protected def getTransactionTreesSource(
request: GetTransactionsRequest
): Source[GetTransactionTreesResponse, NotUsed] = {
def getTransactionTrees(
request: GetTransactionsRequest,
responseObserver: StreamObserver[GetTransactionTreesResponse],
): Unit = registerStream(responseObserver) {
logger.debug(s"Received new transaction tree request $request")
Source.future(service.getLedgerEnd(request.ledgerId)).flatMapConcat { ledgerEnd =>
val validation = validator.validateTree(request, ledgerEnd)

View File

@ -45,7 +45,6 @@ compile_deps = [
"//ledger/caching",
"//ledger/error",
"//ledger/ledger-api-errors",
"//ledger/ledger-api-akka",
"//ledger/ledger-api-auth",
"//ledger/ledger-api-client",
"//ledger/ledger-api-common",
@ -284,7 +283,6 @@ da_scala_test_suite(
"//ledger/caching",
"//ledger/error",
"//ledger/error:error-test-lib",
"//ledger/ledger-api-akka",
"//ledger/ledger-api-client",
"//ledger/ledger-api-common",
"//ledger/ledger-api-common:ledger-api-common-scala-tests-lib",

View File

@ -3,11 +3,11 @@
package com.daml.platform.apiserver.services
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.v1.active_contracts_service.ActiveContractsServiceGrpc.ActiveContractsService
import com.daml.ledger.api.v1.active_contracts_service._
@ -19,6 +19,7 @@ import com.daml.metrics.Metrics
import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.ValidationLogger
import com.daml.platform.server.api.validation.ActiveContractsServiceValidation
import io.grpc.stub.StreamObserver
import io.grpc.{BindableService, ServerServiceDefinition}
import scala.concurrent.ExecutionContext
@ -27,20 +28,22 @@ private[apiserver] final class ApiActiveContractsService private (
backend: ACSBackend,
metrics: Metrics,
)(implicit
protected val mat: Materializer,
protected val esf: ExecutionSequencerFactory,
mat: Materializer,
esf: ExecutionSequencerFactory,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends ActiveContractsServiceAkkaGrpc
) extends ActiveContractsServiceGrpc.ActiveContractsService
with StreamingServiceLifecycleManagement
with GrpcApiService {
private implicit val logger: ContextualizedLogger = ContextualizedLogger.get(this.getClass)
private implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
protected implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
override protected def getActiveContractsSource(
request: GetActiveContractsRequest
): Source[GetActiveContractsResponse, NotUsed] =
def getActiveContracts(
request: GetActiveContractsRequest,
responseObserver: StreamObserver[GetActiveContractsResponse],
): Unit = registerStream(responseObserver) {
TransactionFilterValidator
.validate(request.getFilter)
.fold(
@ -53,6 +56,7 @@ private[apiserver] final class ApiActiveContractsService private (
)
.via(logger.logErrorsOnStream)
.via(StreamMetrics.countElements(metrics.daml.lapi.streams.acs))
}
override def bindService(): ServerServiceDefinition =
ActiveContractsServiceGrpc.bindService(this, executionContext)

View File

@ -3,17 +3,18 @@
package com.daml.platform.apiserver.services
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.api.util.DurationConversion._
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.domain.LedgerId
import com.daml.ledger.api.v1.ledger_configuration_service._
import com.daml.ledger.participant.state.index.v2.IndexConfigurationService
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.platform.api.grpc.GrpcApiService
import com.daml.platform.server.api.validation.LedgerConfigurationServiceValidation
import io.grpc.stub.StreamObserver
import io.grpc.{BindableService, ServerServiceDefinition}
import scala.concurrent.ExecutionContext
@ -21,18 +22,22 @@ import scala.concurrent.ExecutionContext
private[apiserver] final class ApiLedgerConfigurationService private (
configurationService: IndexConfigurationService
)(implicit
protected val esf: ExecutionSequencerFactory,
protected val mat: Materializer,
esf: ExecutionSequencerFactory,
mat: Materializer,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends LedgerConfigurationServiceAkkaGrpc
) extends LedgerConfigurationServiceGrpc.LedgerConfigurationService
with StreamingServiceLifecycleManagement
with GrpcApiService {
private val logger = ContextualizedLogger.get(this.getClass)
protected implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
override protected def getLedgerConfigurationSource(
request: GetLedgerConfigurationRequest
): Source[GetLedgerConfigurationResponse, NotUsed] = {
def getLedgerConfiguration(
request: GetLedgerConfigurationRequest,
responseObserver: StreamObserver[GetLedgerConfigurationResponse],
): Unit = registerStream(responseObserver) {
logger.info(s"Received request for configuration subscription: $request")
configurationService
.getLedgerConfiguration()

View File

@ -4,13 +4,12 @@
package com.daml.platform.apiserver.services
import java.time.Instant
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.Source
import com.daml.api.util.TimestampConversion._
import com.daml.error.{ContextualizedErrorLogger, DamlContextualizedErrorLogger}
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.ledger.api.domain.{LedgerId, optionalLedgerId}
import com.daml.ledger.api.v1.testing.time_service.TimeServiceGrpc.TimeService
import com.daml.ledger.api.v1.testing.time_service._
@ -27,6 +26,7 @@ import scalaz.syntax.tag._
import scala.concurrent.{Await, ExecutionContext, Future}
import com.daml.timer.Timeout._
import io.grpc.stub.StreamObserver
import scala.concurrent.duration.Duration
import scala.util.{Failure, Success}
@ -36,15 +36,16 @@ private[apiserver] final class ApiTimeService private (
backend: TimeServiceBackend,
apiStreamShutdownTimeout: Duration,
)(implicit
protected val mat: Materializer,
protected val esf: ExecutionSequencerFactory,
mat: Materializer,
esf: ExecutionSequencerFactory,
executionContext: ExecutionContext,
loggingContext: LoggingContext,
) extends TimeServiceAkkaGrpc
) extends TimeServiceGrpc.TimeService
with StreamingServiceLifecycleManagement
with GrpcApiService {
private implicit val logger: ContextualizedLogger = ContextualizedLogger.get(getClass)
private implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
protected implicit val contextualizedErrorLogger: ContextualizedErrorLogger =
new DamlContextualizedErrorLogger(logger, loggingContext, None)
private val dispatcher = SignalDispatcher[Instant]()
@ -55,9 +56,10 @@ private[apiserver] final class ApiTimeService private (
s"${getClass.getSimpleName} initialized with ledger ID ${ledgerId.unwrap}, start time ${backend.getCurrentTime}"
)
override protected def getTimeSource(
request: GetTimeRequest
): Source[GetTimeResponse, NotUsed] = {
def getTime(
request: GetTimeRequest,
responseObserver: StreamObserver[GetTimeResponse],
): Unit = registerStream(responseObserver) {
val validated =
matchLedgerId(ledgerId)(optionalLedgerId(request.ledgerId))
validated.fold(

View File

@ -3,7 +3,6 @@
package com.daml.platform.apiserver.error
import akka.NotUsed
import akka.stream.Materializer
import akka.stream.scaladsl.{Flow, Source}
import ch.qos.logback.classic.Level
@ -11,19 +10,20 @@ import com.daml.error._
import com.daml.error.definitions.{CommonErrors, DamlError}
import com.daml.error.utils.ErrorDetails
import com.daml.grpc.adapter.ExecutionSequencerFactory
import com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement
import com.daml.grpc.sampleservice.HelloServiceResponding
import com.daml.ledger.api.testing.utils.{AkkaBeforeAndAfterAll, TestingServerInterceptors}
import com.daml.ledger.resources.{ResourceOwner, TestResourceContext}
import com.daml.platform.hello.HelloServiceGrpc.HelloService
import com.daml.platform.hello.{HelloRequest, HelloResponse, HelloServiceAkkaGrpc, HelloServiceGrpc}
import com.daml.platform.hello.{HelloRequest, HelloResponse, HelloServiceGrpc}
import com.daml.platform.testing.LogCollector.{ThrowableCause, ThrowableEntry}
import com.daml.platform.testing.{LogCollector, LogCollectorAssertions, StreamConsumer}
import io.grpc._
import io.grpc.stub.StreamObserver
import org.scalatest._
import org.scalatest.concurrent.Eventually
import org.scalatest.freespec.AsyncFreeSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.{Assertion, Assertions, BeforeAndAfter, Checkpoints, OptionValues}
import scala.concurrent.{ExecutionContext, ExecutionContextExecutor, Future}
@ -313,15 +313,20 @@ object ErrorInterceptorSpec {
* @param errorInsideFutureOrStream - whether to signal the exception inside a Future or a Stream, or outside to them
*/
class HelloServiceFailing(useSelfService: Boolean, errorInsideFutureOrStream: Boolean)(implicit
protected val esf: ExecutionSequencerFactory,
protected val mat: Materializer,
) extends HelloServiceAkkaGrpc
esf: ExecutionSequencerFactory,
mat: Materializer,
) extends HelloService
with StreamingServiceLifecycleManagement
with HelloServiceResponding
with HelloServiceBase {
override protected def serverStreamingSource(
request: HelloRequest
): Source[HelloResponse, NotUsed] = {
override protected val contextualizedErrorLogger: ContextualizedErrorLogger =
DamlContextualizedErrorLogger.forTesting(getClass)
override def serverStreaming(
request: HelloRequest,
responseObserver: StreamObserver[HelloResponse],
): Unit = registerStream(responseObserver) {
val where = if (errorInsideFutureOrStream) "inside" else "outside"
val t: Throwable = if (useSelfService) {
FooMissingErrorCode
@ -359,8 +364,7 @@ object ErrorInterceptorSpec {
class HelloServiceFailingDirectlyObserverOnError(implicit
protected val esf: ExecutionSequencerFactory,
protected val mat: Materializer,
) extends HelloServiceAkkaGrpc
with HelloServiceResponding
) extends HelloServiceResponding
with HelloServiceBase {
override def serverStreaming(
@ -373,10 +377,6 @@ object ErrorInterceptorSpec {
)
)
override protected def serverStreamingSource(
request: HelloRequest
): Source[HelloResponse, NotUsed] = Assertions.fail("This method should have been unreachable")
override def single(request: HelloRequest): Future[HelloResponse] =
Assertions.fail("This class is not intended to test unary endpoints")
}

View File

@ -99,8 +99,6 @@
type: jar-scala
- target: //ledger/indexer-benchmark:indexer-benchmark-lib
type: jar-scala
- target: //ledger/ledger-api-akka:ledger-api-akka
type: jar-scala
- target: //ledger/ledger-api-auth:ledger-api-auth
type: jar-scala
- target: //ledger/ledger-api-auth-client:ledger-api-auth-client

View File

@ -1,78 +0,0 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.protoc.plugins.akka
import com.google.protobuf.Descriptors.{MethodDescriptor, ServiceDescriptor}
import scalapb.compiler.FunctionalPrinter.PrinterEndo
import scalapb.compiler.{DescriptorImplicits, FunctionalPrinter, StreamType}
/** Custom protoc codegen printer for the Akka-streaming server adaptation:
  * for every gRPC service that has at least one server-streaming method it
  * renders a `<Service>AkkaGrpc` trait wiring Akka `Source`s into grpc-java
  * `StreamObserver`s via `StreamingServiceLifecycleManagement`.
  *
  * @param service the protobuf service descriptor to render code for
  */
final class AkkaGrpcServicePrinter(
    service: ServiceDescriptor
)(implicit descriptorImplicits: DescriptorImplicits) {
  import descriptorImplicits._

  // Fully qualified (`_root_.`) so the generated reference can never clash
  // with names in the generated file's own package.
  private val StreamObserver = "_root_.io.grpc.stub.StreamObserver"

  /** Renders the whole trait, or returns None when the service has no
    * server-streaming endpoint (in which case nothing needs to be generated).
    */
  def printService(printer: FunctionalPrinter): Option[FunctionalPrinter] = {
    val hasStreamingEndpoint: Boolean = service.methods.exists(_.isServerStreaming)
    if (hasStreamingEndpoint) Some {
      printer
        .add(
          "package " + service.getFile.scalaPackage.fullName,
          "",
          // NOTE(review): mixes `service.name` and `service.getName` in one
          // template — presumably they yield the same string for the services
          // in use; confirm against DescriptorImplicits.
          s"trait ${service.name}AkkaGrpc extends ${service.getName}Grpc.${service.getName} with com.daml.grpc.adapter.server.akka.StreamingServiceLifecycleManagement {",
        )
        .call(traitBody)
        .add("}")
    }
    else None
  }

  // Scala type name of the method's response message.
  private def responseType(method: MethodDescriptor): String = method.outputType.scalaType

  // `StreamObserver[T]` applied to the given type name.
  private def observer(typeParam: String): String = s"$StreamObserver[$typeParam]"

  // For a server-streaming method, emits two members: the concrete
  // observer-based endpoint that delegates to `registerStream`, and the
  // abstract `<name>Source` factory the service implementation must provide.
  // Unary, client-streaming and bidirectional methods are left untouched
  // (they are handled by the plain ScalaPB-generated service trait).
  private def serviceMethodSignature(method: MethodDescriptor): PrinterEndo = { p =>
    method.streamType match {
      case StreamType.Unary => p
      case StreamType.ClientStreaming => p
      case StreamType.ServerStreaming =>
        p
          .add(s"def ${method.name}(")
          .indent
          .add(s"request: ${method.inputType.scalaType},")
          .add(s"responseObserver: ${observer(responseType(method))}")
          .outdent
          .add("): Unit =")
          .indent
          .add(
            // NOTE(review): this targets the two-argument
            // `registerStream(buildSource, responseObserver)` signature of
            // StreamingServiceLifecycleManagement that existed alongside this
            // printer; the generated code and that trait must evolve together.
            s"registerStream(() => ${method.name}Source(request), responseObserver)"
          )
          .outdent
          .newline
          .add(s"protected def ${method.name}Source(")
          .indent
          .add(s"request: ${method.inputType.scalaType}")
          .outdent
          .add(s"): akka.stream.scaladsl.Source[${responseType(method)}, akka.NotUsed]")
          .newline
      case StreamType.Bidirectional => p
    }
  }

  // Declares the abstract implicits StreamingServiceLifecycleManagement needs
  // (execution sequencer factory and materializer), then appends the
  // per-method members for every method of the service.
  private def traitBody: PrinterEndo = {
    val endos: PrinterEndo = { p =>
      p.newline
        .call(service.methods.map(m => serviceMethodSignature(m)): _*)
    }
    p =>
      p.indent
        .add("protected implicit def esf: com.daml.grpc.adapter.ExecutionSequencerFactory")
        .add("protected implicit def mat: akka.stream.Materializer")
        .call(endos)
        .outdent
  }
}

View File

@ -1,25 +0,0 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.protoc.plugins.akka
import com.google.protobuf.ExtensionRegistry
import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest
import protocgen.CodeGenRequest
import scalapb.options.Scalapb
import scala.reflect.io.Streamable
/** Standalone protoc plugin entry point: reads a serialized
  * `CodeGeneratorRequest` from stdin, delegates code generation to
  * [[AkkaStreamGenerator]], and writes the serialized response to stdout
  * (the protocol protoc expects from plugin executables).
  */
object AkkaStreamCompilerPlugin {
  def main(args: Array[String]): Unit = {
    val extensionRegistry = ExtensionRegistry.newInstance()
    Scalapb.registerAllExtensions(extensionRegistry)
    // protoc hands the request over on standard input.
    val rawRequest = Streamable.bytes(System.in)
    val parsedRequest = CodeGeneratorRequest.parseFrom(rawRequest, extensionRegistry)
    val response =
      AkkaStreamGenerator
        .handleCodeGeneratorRequest(CodeGenRequest(parsedRequest))
        .toCodeGeneratorResponse
    // The response goes back on standard output as raw bytes.
    System.out.write(response.toByteArray)
  }
}

View File

@ -1,83 +0,0 @@
// Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.protoc.plugins.akka
import com.google.protobuf.Descriptors._
import com.google.protobuf.ExtensionRegistry
import com.google.protobuf.compiler.PluginProtos.{CodeGeneratorRequest, CodeGeneratorResponse}
import protocbridge.ProtocCodeGenerator
import protocgen.{CodeGenRequest, CodeGenResponse}
import scalapb.compiler._
import scalapb.options.Scalapb
import scala.jdk.CollectionConverters._
// This file is mostly copied over from ScalaPbCodeGenerator and ProtobufGenerator
/** Protoc code generator producing the `*AkkaGrpc.scala` server-streaming
  * adapter traits (see [[AkkaGrpcServicePrinter]]). Implements ScalaPB's
  * `ProtocCodeGenerator` so it can run embedded in protoc-bridge as well as
  * through the standalone `AkkaStreamCompilerPlugin` executable.
  */
object AkkaStreamGenerator extends ProtocCodeGenerator {

  /** Raw-bytes entry point used by protoc-bridge: parse the request with
    * ScalaPB's extensions registered, generate, and serialize the response.
    */
  override def run(req: Array[Byte]): Array[Byte] = {
    val registry = ExtensionRegistry.newInstance()
    Scalapb.registerAllExtensions(registry)
    val request = CodeGeneratorRequest.parseFrom(req, registry)
    handleCodeGeneratorRequest(CodeGenRequest(request)).toCodeGeneratorResponse.toByteArray
  }

  /** Validates all protos in the request and generates one `*AkkaGrpc.scala`
    * file per streaming service; parameter or validation failures are turned
    * into a failing CodeGenResponse instead of being thrown.
    */
  def handleCodeGeneratorRequest(request: CodeGenRequest): CodeGenResponse =
    parseParameters(request.parameter) match {
      case Right(params) =>
        implicit val descriptorImplicits: DescriptorImplicits =
          DescriptorImplicits.fromCodeGenRequest(params, request)
        try {
          // Validation runs over every proto in the request (dependencies
          // included), but generation only covers filesToGenerate.
          val filesByName: Map[String, FileDescriptor] =
            request.allProtos.map(fd => fd.getName -> fd).toMap
          val validator = new ProtoValidation(descriptorImplicits)
          filesByName.values.foreach(validator.validateFile)
          val responseFiles = request.filesToGenerate.flatMap(generateServiceFiles(_))
          CodeGenResponse.succeed(responseFiles)
        } catch {
          case exception: GeneratorException =>
            CodeGenResponse.fail(exception.message)
        }
      case Left(error) =>
        CodeGenResponse.fail(error)
    }

  // Parses protoc's comma-separated plugin parameter string into ScalaPB
  // GeneratorParams; any unknown flag fails the whole run (Left).
  private def parseParameters(params: String): Either[String, GeneratorParams] = {
    params
      .split(",")
      .map(_.trim)
      .filter(_.nonEmpty)
      .foldLeft[Either[String, GeneratorParams]](Right(GeneratorParams())) {
        case (Right(params), "java_conversions") => Right(params.copy(javaConversions = true))
        case (Right(params), "flat_package") => Right(params.copy(flatPackage = true))
        case (Right(params), "grpc") => Right(params.copy(grpc = true))
        case (Right(params), "single_line_to_proto_string") =>
          Right(params.copy(singleLineToProtoString = true))
        case (Right(params), "ascii_format_to_string") =>
          Right(params.copy(asciiFormatToString = true))
        case (Right(_), p) => Left(s"Unrecognized parameter: '$p'")
        case (x, _) => x
      }
  }

  // For each service in the file, asks AkkaGrpcServicePrinter to render the
  // adapter trait; services without streaming endpoints yield no file at all
  // (printService returns None).
  private def generateServiceFiles(
      file: FileDescriptor
  )(implicit
      descriptorImplicits: DescriptorImplicits
  ): collection.Seq[CodeGeneratorResponse.File] = {
    import descriptorImplicits._
    file.getServices.asScala.flatMap { service =>
      val printer = new AkkaGrpcServicePrinter(service)
      printer.printService(FunctionalPrinter()).fold[List[CodeGeneratorResponse.File]](Nil) { p =>
        val code = p.result()
        val fileBuilder = CodeGeneratorResponse.File.newBuilder()
        fileBuilder.setName(file.scalaDirectory + "/" + service.name + "AkkaGrpc.scala")
        fileBuilder.setContent(code)
        List(fileBuilder.build())
      }
    }
  }

  // NOTE(review): not referenced inside this object — presumably used by the
  // ScalaPB printer machinery copied alongside; confirm before removing.
  val deprecatedAnnotation: String =
    """@scala.deprecated(message="Marked as deprecated in proto file", "")"""
}

View File

@ -1,61 +0,0 @@
# Copyright (c) 2022 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load("//bazel_tools:scala.bzl", "da_scala_binary", "da_scala_library")

# The protoc plugin as a runnable JVM binary; protoc talks to its main class
# (AkkaStreamCompilerPlugin) by piping a CodeGeneratorRequest through stdin
# and reading the response from stdout.
da_scala_binary(
    name = "compiler_plugin",
    srcs = glob(["*.scala"]),
    main_class = "com.daml.protoc.plugins.akka.AkkaStreamCompilerPlugin",
    scala_deps = [
        "@maven//:com_thesamet_scalapb_compilerplugin",
        "@maven//:com_thesamet_scalapb_protoc_bridge",
        "@maven//:com_thesamet_scalapb_protoc_gen",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "@maven//:com_google_protobuf_protobuf_java",
    ],
)
# From https://github.com/stackb/rules_proto/blob/3f890f5d6774bd74df28e89b20f34155dfe77732/scala/BUILD.bazel#L78-L97
# Curiously this didn't work
#
# genrule(
# name = "gen_protoc_gen_scala",
# srcs = ["compiler_plugin_deploy.jar", "@local_jdk//:bin/java"],
# outs = ["protoc-gen-scala.sh"],
# cmd = """
# echo '$(location @local_jdk//:bin/java) -jar $(location protoc_gen_deploy.jar) $$@' > $@
# """,
# executable = True,
# )
# ======================================================================
#
# Unable to get either bazel or maybe protoc to call a plugin whose
# implementation was fronted by a shell script (from a genrule). So, the only
# way this seemed to work was compile an executable that calls 'java -jar
# protoc_gen_scala_deploy.jar'. Either figure out how to do this in java
# directly or write the wrapper in C++ to remove the go dependency here.
#
load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library")

# Thin native launcher: protoc requires plugins to be real executables, so
# this Go wrapper locates compiler_plugin_deploy.jar in the sandbox and
# invokes `java -jar` on it (see protoc-gen-scala-akka.go).
go_library(
    name = "go_default_library",
    srcs = [
        "protoc-gen-scala-akka.go",
    ],
    importpath = "github.com/digital-asset/daml/scala/protoc-gen-scala-akka",
    visibility = ["//visibility:public"],
)

# The executable protoc is pointed at; carries the deploy jar and a JDK as
# runtime data so the wrapper can find them next to itself.
go_binary(
    name = "protoc-gen-scala-akka",
    data = [
        ":compiler_plugin_deploy.jar",
        "@bazel_tools//tools/jdk",
    ],
    embed = [":go_default_library"],
    visibility = ["//visibility:public"],
)

View File

@ -1,86 +0,0 @@
package main
import (
"log"
"os"
"os/exec"
"path"
"syscall"
)
// main locates the bundled compiler-plugin deploy jar next to this binary
// and delegates to `java -jar`, forwarding this process's arguments and
// propagating the child's exit code.
func main() {
	// Expected Bazel sandbox layout, e.g.:
	//
	// ./bazel-out/host/bin/external/build_stack_rules_proto/scala/compiler_plugin_deploy.jar
	// ./bazel-out/host/bin/external/build_stack_rules_proto/scala/linux_amd64_stripped
	// ./bazel-out/host/bin/external/build_stack_rules_proto/scala/linux_amd64_stripped/protoc-gen-scala
	pluginJar := mustFindInSandbox(path.Dir(os.Args[0]), "compiler_plugin_deploy.jar")
	// NOTE(review): os.Args includes argv[0] (this wrapper's own path);
	// presumably the jar ignores that extra leading argument — confirm,
	// otherwise os.Args[1:] may have been intended.
	javaArgs := append([]string{"-jar", pluginJar}, os.Args...)
	err, exitCode := run("external/local_jdk/bin/java", javaArgs, ".", nil)
	if err != nil {
		log.Printf("%v", err)
	}
	os.Exit(exitCode)
}
// mustFindInSandbox walks upward from dir through its parents looking for
// file, returning the first path at which it exists. The process is
// terminated (log.Fatalf) if the file cannot be found or the walk runs away.
func mustFindInSandbox(dir, file string) string {
	for attempts := 0; ; attempts++ {
		// Safety valve: guard against some filesystem pattern we have not
		// thought of that would make this loop forever.
		if attempts > 1000 {
			log.Fatalf("Too many attempts to find %s within %s", file, dir)
		}
		if dir == "" {
			log.Fatalf("Failed to find %s within %s", file, dir)
		}
		candidate := path.Join(dir, file)
		if exists(candidate) {
			return candidate
		}
		// Not here — retry one directory level up.
		dir = path.Dir(dir)
	}
}
// exists - return true if a file entry exists
func exists(name string) bool {
if _, err := os.Stat(name); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// run executes entrypoint with args in working directory dir, wiring the
// child to this process's stdin/stdout/stderr, and returns the error from
// Run together with the child's exit code.
//
// NOTE(review): the (error, int) return order is unconventional for Go
// (error normally comes last); kept as-is since main depends on it. Per
// os/exec semantics, env == nil means the child inherits this process's
// environment rather than getting an empty one.
func run(entrypoint string, args []string, dir string, env []string) (error, int) {
	cmd := exec.Command(entrypoint, args...)
	cmd.Stdin = os.Stdin
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Env = env
	cmd.Dir = dir
	err := cmd.Run()
	var exitCode int
	if err != nil {
		// An *exec.ExitError means the child started but exited non-zero:
		// recover its status from the POSIX-style wait status.
		// NOTE(review): since Go 1.12, exitError.ExitCode() would avoid the
		// platform-specific syscall.WaitStatus assertion, but syscall is also
		// used below and must stay imported.
		if exitError, ok := err.(*exec.ExitError); ok {
			ws := exitError.Sys().(syscall.WaitStatus)
			exitCode = ws.ExitStatus()
		} else {
			// Any other error (e.g., on macOS, `entrypoint` not available in
			// $PATH) carries no exit code and likely an empty stderr, so log
			// the failure here and fall back to the sentinel code -1.
			log.Printf("Could not get exit code for failed program: %v, %v", entrypoint, args)
			exitCode = -1
		}
	} else {
		// Success: read the (normally zero) status from the process state.
		ws := cmd.ProcessState.Sys().(syscall.WaitStatus)
		exitCode = ws.ExitStatus()
	}
	return err, exitCode
}