Add grpc-reverse-proxy with optional server interceptors (#8481)

* Factor 'withServices' out of ServerReflectionClientSpec

* Add grpc-reverse-proxy with optional server interceptors

changelog_begin
changelog_end

* Switch from Protobuf byte strings to native byte arrays

* Fix year in copyright headers

* Address https://github.com/digital-asset/daml/pull/8481#discussion_r555915935

* Add missing maven_coordinates tag

* Address https://github.com/digital-asset/daml/pull/8481#discussion_r555916684
This commit is contained in:
Stefano Baghino 2021-01-13 15:41:30 +01:00 committed by GitHub
parent 544b0a2caa
commit 9ed787cb3d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 414 additions and 37 deletions

View File

@ -0,0 +1,41 @@
# Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_library",
"da_scala_test_suite",
)
da_scala_library(
name = "grpc-reverse-proxy",
srcs = glob(["src/main/scala/**/*.scala"]),
tags = ["maven_coordinates=com.daml:grpc-reverse-proxy:__VERSION__"],
visibility = [
"//:__subpackages__",
],
deps = [
"//libs-scala/grpc-server-reflection-client",
"@maven//:com_google_guava_guava",
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_services",
"@maven//:io_grpc_grpc_stub",
],
)
da_scala_test_suite(
name = "test",
srcs = glob(["src/test/scala/**/*.scala"]),
versioned_scala_deps = {
"2.12": ["@maven//:org_scala_lang_modules_scala_collection_compat"],
},
deps = [
":grpc-reverse-proxy",
"//libs-scala/grpc-test-utils",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_core",
"@maven//:io_grpc_grpc_services",
"@maven//:io_grpc_grpc_stub",
],
)

View File

@ -0,0 +1,59 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.grpc
import io.grpc.MethodDescriptor.MethodType
import io.grpc.{CallOptions, Channel, ClientCall, MethodDescriptor, ServerMethodDefinition}
import io.grpc.stub.{ClientCalls, ServerCalls, StreamObserver}
private[grpc] object ForwardCall {
def apply[ReqT, RespT](
method: MethodDescriptor[ReqT, RespT],
backend: Channel,
options: CallOptions,
): ServerMethodDefinition[ReqT, RespT] = {
val forward = () => backend.newCall(method, options)
ServerMethodDefinition.create[ReqT, RespT](
method,
method.getType match {
case MethodType.UNARY =>
ServerCalls.asyncUnaryCall(new UnaryMethod(forward))
case MethodType.CLIENT_STREAMING =>
ServerCalls.asyncClientStreamingCall(new ClientStreamingMethod(forward))
case MethodType.SERVER_STREAMING =>
ServerCalls.asyncServerStreamingCall(new ServerStreamingMethod(forward))
case MethodType.BIDI_STREAMING =>
ServerCalls.asyncBidiStreamingCall(new BidiStreamMethod(forward))
case MethodType.UNKNOWN =>
sys.error(s"${method.getFullMethodName} has MethodType.UNKNOWN")
},
)
}
private final class UnaryMethod[ReqT, RespT](call: () => ClientCall[ReqT, RespT])
extends ServerCalls.UnaryMethod[ReqT, RespT] {
override def invoke(request: ReqT, responseObserver: StreamObserver[RespT]): Unit =
ClientCalls.asyncUnaryCall(call(), request, responseObserver)
}
private final class ClientStreamingMethod[ReqT, RespT](call: () => ClientCall[ReqT, RespT])
extends ServerCalls.ClientStreamingMethod[ReqT, RespT] {
override def invoke(responseObserver: StreamObserver[RespT]): StreamObserver[ReqT] =
ClientCalls.asyncClientStreamingCall(call(), responseObserver)
}
private final class ServerStreamingMethod[ReqT, RespT](call: () => ClientCall[ReqT, RespT])
extends ServerCalls.ServerStreamingMethod[ReqT, RespT] {
override def invoke(request: ReqT, responseObserver: StreamObserver[RespT]): Unit =
ClientCalls.asyncServerStreamingCall(call(), request, responseObserver)
}
private final class BidiStreamMethod[ReqT, RespT](call: () => ClientCall[ReqT, RespT])
extends ServerCalls.BidiStreamingMethod[ReqT, RespT] {
override def invoke(responseObserver: StreamObserver[RespT]): StreamObserver[ReqT] =
ClientCalls.asyncBidiStreamingCall(call(), responseObserver)
}
}

View File

@ -0,0 +1,26 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.grpc
import java.io.{ByteArrayInputStream, InputStream}
import com.daml.grpc.reflection.ServiceDescriptorInfo
import com.google.common.io.ByteStreams
import io.grpc.{CallOptions, Channel, MethodDescriptor, ServerServiceDefinition}
private[grpc] object ForwardService {
def apply(backend: Channel, service: ServiceDescriptorInfo): ServerServiceDefinition =
service.methods
.map(_.toMethodDescriptor(ByteArrayMarshaller, ByteArrayMarshaller))
.map(ForwardCall(_, backend, CallOptions.DEFAULT))
.foldLeft(ServerServiceDefinition.builder(service.fullServiceName))(_ addMethod _)
.build()
private object ByteArrayMarshaller extends MethodDescriptor.Marshaller[Array[Byte]] {
override def parse(input: InputStream): Array[Byte] = ByteStreams.toByteArray(input)
override def stream(bytes: Array[Byte]): InputStream = new ByteArrayInputStream(bytes)
}
}

View File

@ -0,0 +1,38 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.grpc
import com.daml.grpc.reflection.ServerReflectionClient
import io.grpc._
import io.grpc.reflection.v1alpha.ServerReflectionGrpc
import scala.concurrent.{ExecutionContext, Future}
object ReverseProxy {
def create(
backend: Channel,
serverBuilder: ServerBuilder[_],
interceptors: Map[String, Seq[ServerInterceptor]],
)(implicit
ec: ExecutionContext
): Future[Server] = {
val stub = ServerReflectionGrpc.newStub(backend)
val client = new ServerReflectionClient(stub)
val future = client.getAllServices()
future
.map { services =>
for (service <- services) {
serverBuilder.addService(
ServerInterceptors.interceptForward(
ForwardService(backend, service),
interceptors.getOrElse(service.fullServiceName, Seq.empty): _*
)
)
}
serverBuilder.build()
}
}
}

View File

@ -0,0 +1,155 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.grpc
import java.util.concurrent.atomic.AtomicReference
import com.daml.grpc.test.GrpcServer
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall
import io.grpc.ForwardingServerCallListener.SimpleForwardingServerCallListener
import io.grpc.health.v1.HealthCheckResponse.ServingStatus
import io.grpc.health.v1.{HealthCheckRequest, HealthGrpc}
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import io.grpc.reflection.v1alpha.{
ServerReflectionGrpc,
ServerReflectionRequest,
ServerReflectionResponse,
}
import io.grpc.stub.StreamObserver
import io.grpc.{
Channel,
Metadata,
ServerCall,
ServerCallHandler,
ServerInterceptor,
StatusRuntimeException,
}
import org.scalatest.Inside
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.DurationInt
import scala.jdk.CollectionConverters._
final class ReverseProxySpec extends AsyncFlatSpec with Matchers with GrpcServer with Inside {
import ReverseProxySpec._
import Services._
behavior of "ReverseProxy.create"
it should "fail if the backend does not support reflection" in withServices(health) { backend =>
val proxyBuilder = InProcessServerBuilder.forName(InProcessServerBuilder.generateName())
ReverseProxy
.create(backend, proxyBuilder, interceptors = Map.empty)
.failed
.map(_ shouldBe a[StatusRuntimeException])
}
it should "expose the backend's own services" in withServices(health, reflection) { backend =>
val proxyName = InProcessServerBuilder.generateName()
val proxyBuilder = InProcessServerBuilder.forName(proxyName)
ReverseProxy
.create(backend, proxyBuilder, interceptors = Map.empty)
.map { proxyServer =>
proxyServer.start()
val proxy = InProcessChannelBuilder.forName(proxyName).build()
getHealthStatus(backend) shouldEqual getHealthStatus(proxy)
listServices(backend) shouldEqual listServices(proxy)
}
}
it should "correctly set up an interceptor" in withServices(health, reflection) { backend =>
val proxyName = InProcessServerBuilder.generateName()
val proxyBuilder = InProcessServerBuilder.forName(proxyName)
val recorder = new RecordingInterceptor
ReverseProxy
.create(backend, proxyBuilder, Map(HealthGrpc.SERVICE_NAME -> Seq(recorder)))
.map { proxyServer =>
proxyServer.start()
val proxy = InProcessChannelBuilder.forName(proxyName).build()
getHealthStatus(backend)
recorder.latestRequest() shouldBe empty
getHealthStatus(proxy)
inside(recorder.latestRequest()) { case Some(request: Array[Byte]) =>
HealthCheckRequest.parseFrom(request)
succeed
}
}
}
}
object ReverseProxySpec {
private def listServices(channel: Channel): Iterable[String] = {
val response = Promise[Iterable[String]]()
lazy val serverStream: StreamObserver[ServerReflectionRequest] =
ServerReflectionGrpc
.newStub(channel)
.serverReflectionInfo(new StreamObserver[ServerReflectionResponse] {
override def onNext(value: ServerReflectionResponse): Unit = {
if (value.hasListServicesResponse) {
val services = value.getListServicesResponse.getServiceList.asScala.map(_.getName)
response.trySuccess(services)
} else {
response.tryFailure(new IllegalStateException("Received unexpected response type"))
}
serverStream.onCompleted()
}
override def onError(throwable: Throwable): Unit = {
val _ = response.tryFailure(throwable)
}
override def onCompleted(): Unit = {
val _ = response.tryFailure(new IllegalStateException("No response received"))
}
})
serverStream.onNext(ServerReflectionRequest.newBuilder().setListServices("").build())
Await.result(response.future, 5.seconds)
}
private def getHealthStatus(channel: Channel): ServingStatus =
HealthGrpc
.newBlockingStub(channel)
.check(HealthCheckRequest.newBuilder().build())
.getStatus
final class Forward[ReqT, RespT](call: ServerCall[ReqT, RespT])
extends SimpleForwardingServerCall[ReqT, RespT](call) {
override def sendMessage(message: RespT): Unit = {
super.sendMessage(message)
}
}
final class Callback[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT],
callback: ReqT => Unit,
) extends SimpleForwardingServerCallListener[ReqT](next.startCall(new Forward(call), headers)) {
override def onMessage(message: ReqT): Unit = {
callback(message)
super.onMessage(message)
}
}
final class RecordingInterceptor extends ServerInterceptor {
private val latestRequestReference = new AtomicReference[Any]()
def latestRequest(): Option[Any] = Option(latestRequestReference.get)
override def interceptCall[ReqT, RespT](
call: ServerCall[ReqT, RespT],
headers: Metadata,
next: ServerCallHandler[ReqT, RespT],
): ServerCall.Listener[ReqT] = {
new Callback(call, headers, next, latestRequestReference.set)
}
}
}

View File

@ -30,6 +30,7 @@ da_scala_test_suite(
srcs = glob(["src/test/scala/**/*.scala"]),
deps = [
":grpc-server-reflection-client",
"//libs-scala/grpc-test-utils",
"@maven//:com_google_protobuf_protobuf_java",
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_core",

View File

@ -3,35 +3,28 @@
package com.daml.grpc.reflection
import java.util.concurrent.TimeUnit
import com.daml.grpc.test.GrpcServer
import io.grpc.StatusRuntimeException
import io.grpc.health.v1.HealthGrpc
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.reflection.v1alpha.ServerReflectionGrpc
import io.grpc.services.HealthStatusManager
import io.grpc.{BindableService, Channel, StatusRuntimeException}
import org.scalatest.Assertion
import org.scalatest.concurrent.ScalaFutures
import org.scalatest.flatspec.AsyncFlatSpec
import org.scalatest.matchers.should.Matchers
import scala.concurrent.Future
final class ServerReflectionClientSpec extends AsyncFlatSpec with Matchers with ScalaFutures {
final class ServerReflectionClientSpec extends AsyncFlatSpec with Matchers with GrpcServer {
import ServerReflectionClientSpec._
import Services._
behavior of "getAllServices"
it should "fail if reflection is not supported" in withServices(health) { channel =>
val stub = ServerReflectionGrpc.newStub(channel)
val client = new ServerReflectionClient(stub)
client.getAllServices().failed.futureValue shouldBe a[StatusRuntimeException]
client.getAllServices().failed.map(_ shouldBe a[StatusRuntimeException])
}
it should "show all if reflection is supported" in withServices(health, reflection) { channel =>
val expected = Vector(healthDescriptor, reflectionDescriptor)
val expected = Set(healthDescriptor, reflectionDescriptor)
val stub = ServerReflectionGrpc.newStub(channel)
val client = new ServerReflectionClient(stub)
client.getAllServices().map(_ should contain theSameElementsAs expected)
@ -41,7 +34,6 @@ final class ServerReflectionClientSpec extends AsyncFlatSpec with Matchers with
object ServerReflectionClientSpec {
private def health: BindableService = new HealthStatusManager().getHealthService
private val healthDescriptor =
ServiceDescriptorInfo(
fullServiceName = HealthGrpc.SERVICE_NAME,
@ -51,32 +43,10 @@ object ServerReflectionClientSpec {
),
)
private def reflection: BindableService = ProtoReflectionService.newInstance()
private val reflectionDescriptor =
ServiceDescriptorInfo(
fullServiceName = ServerReflectionGrpc.SERVICE_NAME,
methods = Set(
MethodDescriptorInfo(ServerReflectionGrpc.getServerReflectionInfoMethod)
),
methods = Set(MethodDescriptorInfo(ServerReflectionGrpc.getServerReflectionInfoMethod)),
)
private def withServices(service: BindableService, services: BindableService*)(
f: Channel => Future[Assertion]
): Future[Assertion] = {
val serverName = InProcessServerBuilder.generateName()
val serverBuilder = InProcessServerBuilder.forName(serverName).addService(service)
for (additionalService <- services) {
serverBuilder.addService(additionalService)
}
val server = serverBuilder.build()
val channel = InProcessChannelBuilder.forName(serverName).build()
try {
server.start()
f(channel)
} finally {
server.shutdown()
val _ = server.awaitTermination(5, TimeUnit.SECONDS)
}
}
}

View File

@ -0,0 +1,25 @@
# Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
load(
"//bazel_tools:scala.bzl",
"da_scala_library",
)
da_scala_library(
name = "grpc-test-utils",
srcs = glob(["src/main/scala/**/*.scala"]),
scala_deps = [
"@maven//:org_scalatest_scalatest",
"@maven//:org_scalactic_scalactic",
],
tags = ["maven_coordinates=com.daml:grpc-test-utils:__VERSION__"],
visibility = [
"//:__subpackages__",
],
deps = [
"@maven//:io_grpc_grpc_api",
"@maven//:io_grpc_grpc_core",
"@maven//:io_grpc_grpc_services",
],
)

View File

@ -0,0 +1,62 @@
// Copyright (c) 2021 Digital Asset (Switzerland) GmbH and/or its affiliates. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
package com.daml.grpc.test
import io.grpc.inprocess.{InProcessChannelBuilder, InProcessServerBuilder}
import io.grpc.protobuf.services.ProtoReflectionService
import io.grpc.services.HealthStatusManager
import io.grpc.{BindableService, Channel, ManagedChannel, Server}
import org.scalatest.Assertion
import org.scalatest.flatspec.AsyncFlatSpec
import scala.concurrent.Future
import scala.concurrent.duration.{DurationInt, FiniteDuration}
trait GrpcServer { this: AsyncFlatSpec =>
object Services {
def health: BindableService = new HealthStatusManager().getHealthService
def reflection: BindableService = ProtoReflectionService.newInstance()
}
def withServices(
service: BindableService,
services: BindableService*
)(
test: Channel => Future[Assertion]
): Future[Assertion] = {
val setup = Future {
val serverName = InProcessServerBuilder.generateName()
val serverBuilder = InProcessServerBuilder.forName(serverName).addService(service)
for (additionalService <- services) {
serverBuilder.addService(additionalService)
}
GrpcServer.Setup(
server = serverBuilder.build().start(),
channel = InProcessChannelBuilder.forName(serverName).build(),
)
}
val result = setup.map(_.channel).flatMap(test)
result.onComplete(_ => setup.map(_.shutdownAndAwaitTerminationFor(5.seconds)))
result
}
}
object GrpcServer {
private final case class Setup(server: Server, channel: ManagedChannel) {
def shutdownAndAwaitTerminationFor(timeout: FiniteDuration): Unit = {
server.shutdown()
channel.shutdown()
server.awaitTermination()
channel.awaitTermination(timeout.length, timeout.unit)
()
}
}
}