Allow the rate limiting interceptor to be passed to api services (#15190)

changelog_begin
changelog_end
This commit is contained in:
Simon Maxen 2022-10-10 11:24:01 +01:00 committed by GitHub
parent 3f3408f16f
commit b9abbfdcbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 106 additions and 64 deletions

View File

@ -20,6 +20,7 @@ import com.daml.logging.LoggingContext
import com.daml.logging.LoggingContext.newLoggingContextWith
import com.daml.metrics.Metrics
import com.daml.platform.apiserver._
import com.daml.platform.apiserver.ratelimiting.RateLimitingInterceptor
import com.daml.platform.config.ParticipantConfig
import com.daml.platform.configuration.{IndexServiceConfig, ServerRole}
import com.daml.platform.index.{InMemoryStateUpdater, IndexServiceOwner}
@ -49,6 +50,7 @@ class LedgerApiServer(
// Currently, we provide this flag outside the HOCON configuration objects
// in order to ensure that participants cannot be configured to accept explicitly disclosed contracts.
explicitDisclosureUnsafeEnabled: Boolean = false,
rateLimitingInterceptor: Option[RateLimitingInterceptor] = None,
)(implicit actorSystem: ActorSystem, materializer: Materializer) {
def owner: ResourceOwner[ApiService] = {
@ -147,7 +149,7 @@ class LedgerApiServer(
healthChecks = healthChecksWithIndexer + ("write" -> writeService),
metrics = metrics,
timeServiceBackend = timeServiceBackend,
otherInterceptors = List.empty,
otherInterceptors = rateLimitingInterceptor.toList,
engine = sharedEngine,
servicesExecutionContext = servicesExecutionContext,
userManagementStore = PersistentUserManagementStore.cached(

View File

@ -137,7 +137,6 @@ object ApiServiceOwner {
) :: otherInterceptors,
servicesExecutionContext,
metrics,
config.rateLimit,
)
_ <- ResourceOwner.forTry(() => writePortFile(apiService.port))
} yield {

View File

@ -5,9 +5,7 @@ package com.daml.platform.apiserver
import com.daml.ledger.resources.ResourceOwner
import com.daml.metrics.Metrics
import com.daml.platform.apiserver.configuration.RateLimitingConfig
import com.daml.platform.apiserver.error.ErrorInterceptor
import com.daml.platform.apiserver.ratelimiting.RateLimitingInterceptor
import com.daml.ports.Port
import com.google.protobuf.Message
import io.grpc._
@ -40,7 +38,6 @@ private[apiserver] object GrpcServer {
metrics: Metrics,
servicesExecutor: Executor,
services: Iterable[BindableService],
rateLimitingConfig: Option[RateLimitingConfig],
): ResourceOwner[Server] = {
val host = address.map(InetAddress.getByName).getOrElse(InetAddress.getLoopbackAddress)
val builder = NettyServerBuilder.forAddress(new InetSocketAddress(host, desiredPort.value))
@ -51,7 +48,6 @@ private[apiserver] object GrpcServer {
builder.maxInboundMessageSize(maxInboundMessageSize)
// NOTE: Interceptors run in the reverse order in which they were added.
interceptors.foreach(builder.intercept)
rateLimitingConfig.foreach(c => builder.intercept(RateLimitingInterceptor(metrics, config = c)))
builder.intercept(new MetricsInterceptor(metrics))
builder.intercept(new TruncatedStatusInterceptor(MaximumStatusDescriptionLength))
builder.intercept(new ErrorInterceptor)

View File

@ -8,7 +8,6 @@ import com.daml.ledger.api.tls.TlsConfiguration
import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner}
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.metrics.Metrics
import com.daml.platform.apiserver.configuration.RateLimitingConfig
import com.daml.ports.Port
import io.grpc.ServerInterceptor
@ -24,7 +23,6 @@ private[daml] final class LedgerApiService(
interceptors: List[ServerInterceptor] = List.empty,
servicesExecutor: Executor,
metrics: Metrics,
rateLimitingConfig: Option[RateLimitingConfig],
)(implicit loggingContext: LoggingContext)
extends ResourceOwner[ApiService] {
@ -48,7 +46,6 @@ private[daml] final class LedgerApiService(
metrics,
servicesExecutor,
apiServices.services,
rateLimitingConfig,
)
.acquire()
// Notify the caller that the services have been closed, so a reset request can complete

View File

@ -22,7 +22,7 @@ import java.util.concurrent.atomic.AtomicBoolean
import scala.jdk.CollectionConverters.ListHasAsScala
import scala.util.Try
private[apiserver] final class RateLimitingInterceptor(
final class RateLimitingInterceptor(
metrics: Metrics,
checks: List[LimitResultCheck],
) extends ServerInterceptor {
@ -74,12 +74,17 @@ private[apiserver] final class RateLimitingInterceptor(
object RateLimitingInterceptor {
def apply(metrics: Metrics, config: RateLimitingConfig): RateLimitingInterceptor = {
def apply(
metrics: Metrics,
config: RateLimitingConfig,
additionalChecks: List[LimitResultCheck] = List.empty,
): RateLimitingInterceptor = {
apply(
metrics = metrics,
config = config,
tenuredMemoryPools = ManagementFactory.getMemoryPoolMXBeans.asScala.toList,
memoryMxBean = ManagementFactory.getMemoryMXBean,
additionalChecks = additionalChecks,
)
}
@ -88,12 +93,9 @@ object RateLimitingInterceptor {
config: RateLimitingConfig,
tenuredMemoryPools: List[MemoryPoolMXBean],
memoryMxBean: MemoryMXBean,
additionalChecks: List[LimitResultCheck],
): RateLimitingInterceptor = {
val apiServices: ThreadpoolCount = new ThreadpoolCount(metrics)(
"Api Services Threadpool",
metrics.daml.lapi.threadpool.apiServices,
)
val indexDbThreadpool: ThreadpoolCount = new ThreadpoolCount(metrics)(
"Index Database Connection Threadpool",
MetricName(metrics.daml.index.db.threadpool.connection, ServerRole.ApiServer.threadPoolSuffix),
@ -106,10 +108,9 @@ object RateLimitingInterceptor {
metrics = metrics,
checks = List[LimitResultCheck](
MemoryCheck(tenuredMemoryPools, memoryMxBean, config),
ThreadpoolCheck(apiServices, config.maxApiServicesQueueSize),
ThreadpoolCheck(indexDbThreadpool, config.maxApiServicesIndexDbQueueSize),
StreamCheck(activeStreamsCounter, activeStreamsName, config.maxStreams),
),
) ::: additionalChecks,
)
}

View File

@ -27,19 +27,24 @@ object ThreadpoolCheck {
def queueSize: Long = submitted.getCount - running.getCount - completed.getCount
}
def apply(count: ThreadpoolCount, limit: Int): LimitResultCheck = (fullMethodName, _) => {
val queued = count.queueSize
if (queued > limit) {
OverLimit(
ThreadpoolOverloaded.Rejection(
name = count.name,
queued = queued,
limit = limit,
metricPrefix = count.prefix,
fullMethodName = fullMethodName,
)
)
} else UnderLimit
def apply(count: ThreadpoolCount, limit: Int): LimitResultCheck = {
apply(count.name, count.prefix, () => count.queueSize, limit)
}
def apply(name: String, prefix: String, queueSize: () => Long, limit: Int): LimitResultCheck =
(fullMethodName, _) => {
val queued = queueSize()
if (queued > limit) {
OverLimit(
ThreadpoolOverloaded.Rejection(
name = name,
queued = queued,
limit = limit,
metricPrefix = prefix,
fullMethodName = fullMethodName,
)
)
} else UnderLimit
}
}

View File

@ -81,7 +81,6 @@ case class TlsFixture(
tlsConfiguration = Some(serverTlsConfiguration),
servicesExecutor = servicesExecutor,
metrics = new Metrics(new MetricRegistry),
rateLimitingConfig = None,
)
)
}

View File

@ -10,13 +10,15 @@ import com.daml.grpc.sampleservice.implementations.HelloServiceReferenceImplemen
import com.daml.ledger.client.GrpcChannel
import com.daml.ledger.client.configuration.LedgerClientChannelConfiguration
import com.daml.ledger.resources.{ResourceOwner, TestResourceContext}
import com.daml.metrics.Metrics
import com.daml.metrics.{MetricName, Metrics}
import com.daml.platform.apiserver.GrpcServerSpec._
import com.daml.platform.apiserver.configuration.RateLimitingConfig
import com.daml.platform.apiserver.ratelimiting.RateLimitingInterceptor
import com.daml.platform.configuration.ServerRole
import com.daml.platform.hello.{HelloRequest, HelloResponse, HelloServiceGrpc}
import com.daml.ports.Port
import com.google.protobuf.ByteString
import io.grpc.{ManagedChannel, Status, StatusRuntimeException}
import io.grpc.{ManagedChannel, ServerInterceptor, Status, StatusRuntimeException}
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AsyncWordSpec
@ -85,12 +87,17 @@ final class GrpcServerSpec extends AsyncWordSpec with Matchers with TestResource
}
}
"rate limit interceptor is installed" in {
"install rate limit interceptor" in {
val metrics = new Metrics(new MetricRegistry)
resources(metrics).use { channel =>
val rateLimitingInterceptor = RateLimitingInterceptor(metrics, rateLimitingConfig)
resources(metrics, List(rateLimitingInterceptor)).use { channel =>
val metricName = MetricName(
metrics.daml.index.db.threadpool.connection,
ServerRole.ApiServer.threadPoolSuffix,
)
metrics.registry
.meter(MetricRegistry.name(metrics.daml.lapi.threadpool.apiServices, "submitted"))
.mark(rateLimitingConfig.maxApiServicesQueueSize.toLong + 1) // Over limit
.meter(MetricRegistry.name(metricName, "submitted"))
.mark(rateLimitingConfig.maxApiServicesIndexDbQueueSize.toLong + 1) // Over limit
val helloService = HelloServiceGrpc.stub(channel)
helloService.single(HelloRequest(7)).failed.map {
case s: StatusRuntimeException =>
@ -125,7 +132,8 @@ object GrpcServerSpec {
}
private def resources(
metrics: Metrics = new Metrics(new MetricRegistry)
metrics: Metrics = new Metrics(new MetricRegistry),
interceptors: List[ServerInterceptor] = List.empty,
): ResourceOwner[ManagedChannel] =
for {
executor <- ResourceOwner.forExecutorService(() => Executors.newSingleThreadExecutor())
@ -136,7 +144,7 @@ object GrpcServerSpec {
metrics = metrics,
servicesExecutor = executor,
services = Seq(new TestedHelloService),
rateLimitingConfig = Some(rateLimitingConfig),
interceptors = interceptors,
)
channel <- new GrpcChannel.Owner(
Port(server.getPort),

View File

@ -13,6 +13,8 @@ import com.daml.ledger.resources.{Resource, ResourceContext, ResourceOwner, Test
import com.daml.logging.LoggingContext
import com.daml.metrics.Metrics
import com.daml.platform.apiserver.configuration.RateLimitingConfig
import com.daml.platform.apiserver.ratelimiting.LimitResult.LimitResultCheck
import com.daml.platform.apiserver.ratelimiting.ThreadpoolCheck.ThreadpoolCount
import com.daml.platform.apiserver.services.GrpcClientResource
import com.daml.platform.configuration.ServerRole
import com.daml.platform.hello.{HelloRequest, HelloResponse, HelloServiceGrpc}
@ -61,26 +63,6 @@ final class RateLimitingInterceptorSpec
behavior of "RateLimitingInterceptor"
it should "limit calls when apiServices executor service is over limit" in {
val metrics = new Metrics(new MetricRegistry)
withChannel(metrics, new HelloServiceAkkaImplementation, config).use { channel: Channel =>
val helloService = HelloServiceGrpc.stub(channel)
val submitted = metrics.registry.meter(
MetricRegistry.name(metrics.daml.lapi.threadpool.apiServices, "submitted")
)
for {
_ <- helloService.single(HelloRequest(1))
_ = submitted.mark(config.maxApiServicesQueueSize.toLong + 1)
exception <- helloService.single(HelloRequest(2)).failed
_ = submitted.mark(-config.maxApiServicesQueueSize.toLong - 1)
_ <- helloService.single(HelloRequest(3))
} yield {
exception.getMessage should include(metrics.daml.lapi.threadpool.apiServices)
}
}
}
it should "limit calls when apiServices DB thread pool executor service is over limit" in {
val metrics = new Metrics(new MetricRegistry)
withChannel(metrics, new HelloServiceAkkaImplementation, config).use { channel: Channel =>
@ -391,6 +373,37 @@ final class RateLimitingInterceptorSpec
underTest.calculateCollectionUsageThreshold(101000) shouldBe 100000 // 101000 - 1000
}
it should "support addition checks" in {
val metrics = new Metrics(new MetricRegistry)
val apiServices: ThreadpoolCount = new ThreadpoolCount(metrics)(
"Api Services Threadpool",
metrics.daml.lapi.threadpool.apiServices,
)
val apiServicesCheck = ThreadpoolCheck(apiServices, config.maxApiServicesQueueSize)
withChannel(
metrics,
new HelloServiceAkkaImplementation,
config,
additionalChecks = List(apiServicesCheck),
).use { channel: Channel =>
val helloService = HelloServiceGrpc.stub(channel)
val submitted = metrics.registry.meter(
MetricRegistry.name(metrics.daml.lapi.threadpool.apiServices, "submitted")
)
for {
_ <- helloService.single(HelloRequest(1))
_ = submitted.mark(config.maxApiServicesQueueSize.toLong + 1)
exception <- helloService.single(HelloRequest(2)).failed
_ = submitted.mark(-config.maxApiServicesQueueSize.toLong - 1)
_ <- helloService.single(HelloRequest(3))
} yield {
exception.getMessage should include(metrics.daml.lapi.threadpool.apiServices)
}
}
}
}
object RateLimitingInterceptorSpec extends MockitoSugar {
@ -414,9 +427,13 @@ object RateLimitingInterceptorSpec extends MockitoSugar {
config: RateLimitingConfig,
pool: List[MemoryPoolMXBean] = List(underLimitMemoryPoolMXBean()),
memoryBean: MemoryMXBean = ManagementFactory.getMemoryMXBean,
additionalChecks: List[LimitResultCheck] = List.empty,
): ResourceOwner[Channel] =
for {
server <- serverOwner(RateLimitingInterceptor(metrics, config, pool, memoryBean), service)
server <- serverOwner(
RateLimitingInterceptor(metrics, config, pool, memoryBean, additionalChecks),
service,
)
channel <- GrpcClientResource.owner(Port(server.getPort))
} yield channel

View File

@ -35,11 +35,12 @@ import com.daml.logging.LoggingContext.newLoggingContext
import com.daml.logging.{ContextualizedLogger, LoggingContext}
import com.daml.metrics.{JvmMetricSet, Metrics}
import com.daml.platform.LedgerApiServer
import com.daml.platform.apiserver.LedgerFeatures
import com.daml.platform.apiserver.TimeServiceBackend
import com.daml.platform.apiserver.{LedgerFeatures, TimeServiceBackend}
import com.daml.platform.apiserver.configuration.RateLimitingConfig
import com.daml.platform.apiserver.ratelimiting.ThreadpoolCheck.ThreadpoolCount
import com.daml.platform.apiserver.ratelimiting.{RateLimitingInterceptor, ThreadpoolCheck}
import com.daml.platform.config.MetricsConfig.MetricRegistryType
import com.daml.platform.config.MetricsConfig
import com.daml.platform.config.ParticipantConfig
import com.daml.platform.config.{MetricsConfig, ParticipantConfig}
import com.daml.platform.store.DbSupport.ParticipantDataSourceConfig
import com.daml.platform.store.DbType
import com.daml.ports.Port
@ -134,6 +135,8 @@ object SandboxOnXRunner {
servicesExecutionContext = servicesExecutionContext,
metrics = metrics,
explicitDisclosureUnsafeEnabled = explicitDisclosureUnsafeEnabled,
rateLimitingInterceptor =
participantConfig.apiServer.rateLimit.map(buildRateLimitingInterceptor(metrics)),
)(actorSystem, materializer).owner
} yield {
logInitializationHeader(
@ -299,4 +302,19 @@ object SandboxOnXRunner {
ledgerDetails,
)
}
def buildRateLimitingInterceptor(
metrics: Metrics
)(config: RateLimitingConfig): RateLimitingInterceptor = {
val apiServices: ThreadpoolCount = new ThreadpoolCount(metrics)(
"Api Services Threadpool",
metrics.daml.lapi.threadpool.apiServices,
)
val apiServicesCheck = ThreadpoolCheck(apiServices, config.maxApiServicesQueueSize)
RateLimitingInterceptor(metrics, config, List(apiServicesCheck))
}
}