don't increment counters for outgoing messages blocked by the rate limit. (#3518)

This was causing a problem where outbound messages blocked by the rate
limiter would still increment the counters as if they had been sent. This, in
turn, could cause other message types to get blocked, because the rate limiter
thought we had already sent a lot of the (blocked) message type.
Arvid Norberg 2021-05-03 20:18:29 +02:00 committed by GitHub
parent 5ce1bfc34c
commit 912dc84663
4 changed files with 124 additions and 45 deletions
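
In other words: with the old code, every outbound message was charged against the shared non-transaction budget whether or not it was actually sent. Below is a minimal, self-contained sketch of that failure mode, with simplified counters and hypothetical limits (this is not the project's code; the real diff follows):

    # Standalone sketch of the bug. Hypothetical limits: a shared budget of
    # 1000 for all non-tx messages, and a per-type cap of 10 for "respond_peers".
    from collections import Counter

    NON_TX_FREQ = 1000
    PER_TYPE_FREQ = {"respond_peers": 10, "respond_signatures": 100}

    class BuggyOutboundLimiter:
        def __init__(self) -> None:
            self.non_tx_count = 0
            self.per_type: Counter = Counter()

        def allow(self, msg_type: str) -> bool:
            # The bug: counters are incremented before the checks, so messages
            # that end up blocked still consume the shared non-tx budget.
            self.per_type[msg_type] += 1
            self.non_tx_count += 1
            if self.non_tx_count > NON_TX_FREQ:
                return False
            return self.per_type[msg_type] <= PER_TYPE_FREQ[msg_type]

    limiter = BuggyOutboundLimiter()
    sent = sum(limiter.allow("respond_peers") for _ in range(NON_TX_FREQ))
    assert sent == 10  # only 10 were sent, but all 1000 were counted

    # The shared budget is now spent on messages that never went out, so an
    # unrelated message type gets blocked too:
    assert not limiter.allow("respond_signatures")

The fix below stages the increments in local variables and only commits them to the counters when the message is actually allowed through (or unconditionally, for incoming messages).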

View File

@@ -102,7 +102,24 @@ rate_limits_other = {
 class RateLimiter:
-    def __init__(self, reset_seconds=60, percentage_of_limit=100):
+    incoming: bool
+    reset_seconds: int
+    current_minute: int
+    message_counts: Counter
+    message_cumulative_sizes: Counter
+    percentage_of_limit: int
+    non_tx_message_counts: int = 0
+    non_tx_cumulative_size: int = 0
+
+    def __init__(self, incoming: bool, reset_seconds=60, percentage_of_limit=100):
+        """
+        The incoming parameter affects whether counters are incremented
+        unconditionally or not. For incoming messages, the counters are always
+        incremented. For outgoing messages, the counters are only incremented
+        if they are allowed to be sent by the rate limiter, since we won't send
+        the messages otherwise.
+        """
+        self.incoming = incoming
         self.reset_seconds = reset_seconds
         self.current_minute = time.time() // reset_seconds
         self.message_counts = Counter()
@@ -116,7 +133,7 @@ class RateLimiter:
         Returns True if message can be processed successfully, false if a rate limit is passed.
         """
-        current_minute = time.time() // self.reset_seconds
+        current_minute = int(time.time() // self.reset_seconds)
         if current_minute != self.current_minute:
             self.current_minute = current_minute
             self.message_counts = Counter()
@@ -129,31 +146,49 @@ class RateLimiter:
             log.warning(f"Invalid message: {message.type}, {e}")
             return True
-        self.message_counts[message_type] += 1
-        self.message_cumulative_sizes[message_type] += len(message.data)
-        proportion_of_limit = self.percentage_of_limit / 100
-        limits = DEFAULT_SETTINGS
-        if message_type in rate_limits_tx:
-            limits = rate_limits_tx[message_type]
-        elif message_type in rate_limits_other:
-            limits = rate_limits_other[message_type]
-            self.non_tx_message_counts += 1
-            self.non_tx_cumulative_size += len(message.data)
-            if self.non_tx_message_counts > NON_TX_FREQ * proportion_of_limit:
-                return False
-            if self.non_tx_cumulative_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
-                return False
-        else:
-            log.warning(f"Message type {message_type} not found in rate limits")
-        if limits.max_total_size is None:
-            limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
-        if self.message_counts[message_type] > limits.frequency * proportion_of_limit:
-            return False
-        if len(message.data) > limits.max_size:
-            return False
-        if self.message_cumulative_sizes[message_type] > limits.max_total_size * proportion_of_limit:
-            return False
-        return True
+        new_message_counts: int = self.message_counts[message_type] + 1
+        new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
+        new_non_tx_count: int = self.non_tx_message_counts
+        new_non_tx_size: int = self.non_tx_cumulative_size
+        proportion_of_limit: float = self.percentage_of_limit / 100
+
+        ret: bool = False
+        try:
+            limits = DEFAULT_SETTINGS
+            if message_type in rate_limits_tx:
+                limits = rate_limits_tx[message_type]
+            elif message_type in rate_limits_other:
+                limits = rate_limits_other[message_type]
+                new_non_tx_count = self.non_tx_message_counts + 1
+                new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
+                if new_non_tx_count > NON_TX_FREQ * proportion_of_limit:
+                    return False
+                if new_non_tx_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
+                    return False
+            else:
+                log.warning(f"Message type {message_type} not found in rate limits")
+            if limits.max_total_size is None:
+                limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
+            assert limits.max_total_size is not None
+            if new_message_counts > limits.frequency * proportion_of_limit:
+                return False
+            if len(message.data) > limits.max_size:
+                return False
+            if new_cumulative_size > limits.max_total_size * proportion_of_limit:
+                return False
+            ret = True
+            return True
+        finally:
+            if self.incoming or ret:
+                # now that we determined that it's OK to send the message, commit the
+                # updates to the counters. Alternatively, if this was an
+                # incoming message, we already received it and it should
+                # increment the counters unconditionally
+                self.message_counts[message_type] = new_message_counts
+                self.message_cumulative_sizes[message_type] = new_cumulative_size
+                self.non_tx_message_counts = new_non_tx_count
+                self.non_tx_cumulative_size = new_non_tx_size
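
Taken together, the incoming flag now decides whether process_msg_and_check charges the counters unconditionally (incoming) or only when the message is allowed through (outgoing). A short usage sketch of the new behavior, mirroring what the new tests at the bottom of this commit assert:

    from chia.protocols.protocol_message_types import ProtocolMessageTypes
    from chia.server.outbound_message import make_msg
    from chia.server.rate_limits import NON_TX_FREQ, RateLimiter

    msg = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))

    # Outgoing direction: blocked messages no longer count against the budget.
    outbound = RateLimiter(incoming=False)
    for _ in range(NON_TX_FREQ):
        outbound.process_msg_and_check(msg)

    # Only the messages actually allowed were charged, so a different
    # non-tx message type still goes through:
    other = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
    assert outbound.process_msg_and_check(other)

    # Incoming direction: every received message is charged unconditionally,
    # so the same sequence exhausts the aggregate non-tx budget:
    inbound = RateLimiter(incoming=True)
    for _ in range(NON_TX_FREQ):
        inbound.process_msg_and_check(msg)
    assert not inbound.process_msg_and_check(other)

Staging the updates in locals and committing them in a single finally block means the early "return False" paths never double-count, regardless of which limit tripped.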

View File

@@ -100,8 +100,8 @@ class WSChiaConnection:
         # This means that even if the other peer's boundaries for each minute are not aligned, we will not
         # disconnect. Also it allows a little flexibility.
-        self.outbound_rate_limiter = RateLimiter(percentage_of_limit=outbound_rate_limit_percent)
-        self.inbound_rate_limiter = RateLimiter(percentage_of_limit=inbound_rate_limit_percent)
+        self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
+        self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)

     async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
         if self.is_outbound:

View File

@@ -177,7 +177,7 @@ class TestDos:
         assert not ws_con.closed

         # Remove outbound rate limiter to test inbound limits
-        ws_con.outbound_rate_limiter = RateLimiter(percentage_of_limit=10000)
+        ws_con.outbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=10000)

         for i in range(6000):
             await ws_con._send_message(new_tx_message)
@@ -232,7 +232,7 @@ class TestDos:
         assert not ws_con.closed

         # Remove outbound rate limiter to test inbound limits
-        ws_con.outbound_rate_limiter = RateLimiter(percentage_of_limit=10000)
+        ws_con.outbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=10000)

         for i in range(6):
             await ws_con._send_message(new_message)

View File

@@ -4,7 +4,7 @@ import pytest
 from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.server.outbound_message import make_msg
-from chia.server.rate_limits import RateLimiter
+from chia.server.rate_limits import RateLimiter, NON_TX_FREQ
 from tests.setup_nodes import test_constants
@@ -21,7 +21,7 @@ class TestRateLimits:
     @pytest.mark.asyncio
     async def test_too_many_messages(self):
         # Too many messages
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
         for i in range(3000):
             assert r.process_msg_and_check(new_tx_message)
@@ -34,7 +34,7 @@ class TestRateLimits:
         assert saw_disconnect

         # Non-tx message
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
         for i in range(20):
             assert r.process_msg_and_check(new_peak_message)
@@ -52,14 +52,14 @@ class TestRateLimits:
         small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
         large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))

-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         assert r.process_msg_and_check(small_tx_message)
         assert r.process_msg_and_check(small_tx_message)
         assert not r.process_msg_and_check(large_tx_message)

         small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
         large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         assert r.process_msg_and_check(small_vdf_message)
         assert r.process_msg_and_check(small_vdf_message)
         assert not r.process_msg_and_check(large_vdf_message)
@@ -67,7 +67,7 @@ class TestRateLimits:
     @pytest.mark.asyncio
     async def test_too_much_data(self):
         # Too much data
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
         for i in range(10):
             assert r.process_msg_and_check(tx_message)
@@ -79,7 +79,7 @@ class TestRateLimits:
                 saw_disconnect = True
         assert saw_disconnect

-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
         for i in range(10):
             assert r.process_msg_and_check(block_message)
@@ -94,7 +94,7 @@ class TestRateLimits:
     @pytest.mark.asyncio
     async def test_non_tx_aggregate_limits(self):
         # Frequency limits
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
         message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
         message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))
@@ -112,7 +112,7 @@ class TestRateLimits:
         assert saw_disconnect

         # Size limits
-        r = RateLimiter()
+        r = RateLimiter(incoming=True)
         message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024))
         message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))
@@ -128,7 +128,7 @@ class TestRateLimits:
     @pytest.mark.asyncio
     async def test_periodic_reset(self):
-        r = RateLimiter(5)
+        r = RateLimiter(True, 5)
         tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
         for i in range(10):
             assert r.process_msg_and_check(tx_message)
@@ -144,7 +144,7 @@ class TestRateLimits:
         assert r.process_msg_and_check(tx_message)

         # Counts reset also
-        r = RateLimiter(5)
+        r = RateLimiter(True, 5)
         new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
         for i in range(3000):
             assert r.process_msg_and_check(new_tx_message)
@@ -161,7 +161,7 @@ class TestRateLimits:
     @pytest.mark.asyncio
     async def test_percentage_limits(self):
-        r = RateLimiter(60, 40)
+        r = RateLimiter(True, 60, 40)
         new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
         for i in range(50):
             assert r.process_msg_and_check(new_peak_message)
@@ -173,7 +173,7 @@ class TestRateLimits:
                 saw_disconnect = True
         assert saw_disconnect

-        r = RateLimiter(60, 40)
+        r = RateLimiter(True, 60, 40)
         block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
         for i in range(5):
             assert r.process_msg_and_check(block_message)
@@ -186,7 +186,7 @@ class TestRateLimits:
         assert saw_disconnect

         # Aggregate percentage limit count
-        r = RateLimiter(60, 40)
+        r = RateLimiter(True, 60, 40)
         message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
         message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
         message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))
@@ -204,7 +204,7 @@ class TestRateLimits:
         assert saw_disconnect

         # Aggregate percentage limit max total size
-        r = RateLimiter(60, 40)
+        r = RateLimiter(True, 60, 40)
         message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024))
         message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))
@@ -217,3 +217,47 @@ class TestRateLimits:
             if not response:
                 saw_disconnect = True
         assert saw_disconnect
+
+    @pytest.mark.asyncio
+    async def test_too_many_outgoing_messages(self):
+        # Too many messages
+        r = RateLimiter(incoming=False)
+        new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
+        passed = 0
+        blocked = 0
+        for i in range(NON_TX_FREQ):
+            if r.process_msg_and_check(new_peers_message):
+                passed += 1
+            else:
+                blocked += 1
+        assert passed == 10
+        assert blocked == NON_TX_FREQ - passed
+
+        # ensure that *another* message type is not blocked because of this
+        new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
+        assert r.process_msg_and_check(new_signatures_message)
+
+    @pytest.mark.asyncio
+    async def test_too_many_incoming_messages(self):
+        # Too many messages
+        r = RateLimiter(incoming=True)
+        new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
+        passed = 0
+        blocked = 0
+        for i in range(NON_TX_FREQ):
+            if r.process_msg_and_check(new_peers_message):
+                passed += 1
+            else:
+                blocked += 1
+        assert passed == 10
+        assert blocked == NON_TX_FREQ - passed
+
+        # ensure that other message types *are* blocked because of this
+        new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
+        assert not r.process_msg_and_check(new_signatures_message)