mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-20 08:05:33 +03:00
don't increment counters for outgoing messages blocked by the rate limit. (#3518)
This was causing a problem where outbound messages, blocked by the rate limiter, would still increment the counters as-if they had been sent. This, in turn, could cause other message types to get blocked becuase the rate limiter thought we had sent a lot of the other (blocked) message type.
This commit is contained in:
parent
5ce1bfc34c
commit
912dc84663
@ -102,7 +102,24 @@ rate_limits_other = {
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, reset_seconds=60, percentage_of_limit=100):
|
||||
incoming: bool
|
||||
reset_seconds: int
|
||||
current_minute: int
|
||||
message_counts: Counter
|
||||
message_cumulative_sizes: Counter
|
||||
percentage_of_limit: int
|
||||
non_tx_message_counts: int = 0
|
||||
non_tx_cumulative_size: int = 0
|
||||
|
||||
def __init__(self, incoming: bool, reset_seconds=60, percentage_of_limit=100):
|
||||
"""
|
||||
The incoming parameter affects whether counters are incremented
|
||||
unconditionally or not. For incoming messages, the counters are always
|
||||
incremeneted. For outgoing messages, the counters are only incremented
|
||||
if they are allowed to be sent by the rate limiter, since we won't send
|
||||
the messages otherwise.
|
||||
"""
|
||||
self.incoming = incoming
|
||||
self.reset_seconds = reset_seconds
|
||||
self.current_minute = time.time() // reset_seconds
|
||||
self.message_counts = Counter()
|
||||
@ -116,7 +133,7 @@ class RateLimiter:
|
||||
Returns True if message can be processed successfully, false if a rate limit is passed.
|
||||
"""
|
||||
|
||||
current_minute = time.time() // self.reset_seconds
|
||||
current_minute = int(time.time() // self.reset_seconds)
|
||||
if current_minute != self.current_minute:
|
||||
self.current_minute = current_minute
|
||||
self.message_counts = Counter()
|
||||
@ -129,31 +146,49 @@ class RateLimiter:
|
||||
log.warning(f"Invalid message: {message.type}, {e}")
|
||||
return True
|
||||
|
||||
self.message_counts[message_type] += 1
|
||||
self.message_cumulative_sizes[message_type] += len(message.data)
|
||||
proportion_of_limit = self.percentage_of_limit / 100
|
||||
new_message_counts: int = self.message_counts[message_type] + 1
|
||||
new_cumulative_size: int = self.message_cumulative_sizes[message_type] + len(message.data)
|
||||
new_non_tx_count: int = self.non_tx_message_counts
|
||||
new_non_tx_size: int = self.non_tx_cumulative_size
|
||||
proportion_of_limit: float = self.percentage_of_limit / 100
|
||||
|
||||
limits = DEFAULT_SETTINGS
|
||||
if message_type in rate_limits_tx:
|
||||
limits = rate_limits_tx[message_type]
|
||||
elif message_type in rate_limits_other:
|
||||
limits = rate_limits_other[message_type]
|
||||
self.non_tx_message_counts += 1
|
||||
self.non_tx_cumulative_size += len(message.data)
|
||||
if self.non_tx_message_counts > NON_TX_FREQ * proportion_of_limit:
|
||||
ret: bool = False
|
||||
try:
|
||||
|
||||
limits = DEFAULT_SETTINGS
|
||||
if message_type in rate_limits_tx:
|
||||
limits = rate_limits_tx[message_type]
|
||||
elif message_type in rate_limits_other:
|
||||
limits = rate_limits_other[message_type]
|
||||
new_non_tx_count = self.non_tx_message_counts + 1
|
||||
new_non_tx_size = self.non_tx_cumulative_size + len(message.data)
|
||||
if new_non_tx_count > NON_TX_FREQ * proportion_of_limit:
|
||||
return False
|
||||
if new_non_tx_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
|
||||
return False
|
||||
else:
|
||||
log.warning(f"Message type {message_type} not found in rate limits")
|
||||
|
||||
if limits.max_total_size is None:
|
||||
limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
|
||||
assert limits.max_total_size is not None
|
||||
|
||||
if new_message_counts > limits.frequency * proportion_of_limit:
|
||||
return False
|
||||
if self.non_tx_cumulative_size > NON_TX_MAX_TOTAL_SIZE * proportion_of_limit:
|
||||
if len(message.data) > limits.max_size:
|
||||
return False
|
||||
if new_cumulative_size > limits.max_total_size * proportion_of_limit:
|
||||
return False
|
||||
else:
|
||||
log.warning(f"Message type {message_type} not found in rate limits")
|
||||
|
||||
if limits.max_total_size is None:
|
||||
limits = dataclasses.replace(limits, max_total_size=limits.frequency * limits.max_size)
|
||||
|
||||
if self.message_counts[message_type] > limits.frequency * proportion_of_limit:
|
||||
return False
|
||||
if len(message.data) > limits.max_size:
|
||||
return False
|
||||
if self.message_cumulative_sizes[message_type] > limits.max_total_size * proportion_of_limit:
|
||||
return False
|
||||
return True
|
||||
ret = True
|
||||
return True
|
||||
finally:
|
||||
if self.incoming or ret:
|
||||
# now that we determined that it's OK to send the message, commit the
|
||||
# updates to the counters. Alternatively, if this was an
|
||||
# incoming message, we already received it and it should
|
||||
# increment the counters unconditionally
|
||||
self.message_counts[message_type] = new_message_counts
|
||||
self.message_cumulative_sizes[message_type] = new_cumulative_size
|
||||
self.non_tx_message_counts = new_non_tx_count
|
||||
self.non_tx_cumulative_size = new_non_tx_size
|
||||
|
@ -100,8 +100,8 @@ class WSChiaConnection:
|
||||
|
||||
# This means that even if the other peer's boundaries for each minute are not aligned, we will not
|
||||
# disconnect. Also it allows a little flexibility.
|
||||
self.outbound_rate_limiter = RateLimiter(percentage_of_limit=outbound_rate_limit_percent)
|
||||
self.inbound_rate_limiter = RateLimiter(percentage_of_limit=inbound_rate_limit_percent)
|
||||
self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
|
||||
self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)
|
||||
|
||||
async def perform_handshake(self, network_id: str, protocol_version: str, server_port: int, local_type: NodeType):
|
||||
if self.is_outbound:
|
||||
|
@ -177,7 +177,7 @@ class TestDos:
|
||||
assert not ws_con.closed
|
||||
|
||||
# Remove outbound rate limiter to test inbound limits
|
||||
ws_con.outbound_rate_limiter = RateLimiter(percentage_of_limit=10000)
|
||||
ws_con.outbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=10000)
|
||||
|
||||
for i in range(6000):
|
||||
await ws_con._send_message(new_tx_message)
|
||||
@ -232,7 +232,7 @@ class TestDos:
|
||||
assert not ws_con.closed
|
||||
|
||||
# Remove outbound rate limiter to test inbound limits
|
||||
ws_con.outbound_rate_limiter = RateLimiter(percentage_of_limit=10000)
|
||||
ws_con.outbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=10000)
|
||||
|
||||
for i in range(6):
|
||||
await ws_con._send_message(new_message)
|
||||
|
@ -4,7 +4,7 @@ import pytest
|
||||
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import make_msg
|
||||
from chia.server.rate_limits import RateLimiter
|
||||
from chia.server.rate_limits import RateLimiter, NON_TX_FREQ
|
||||
from tests.setup_nodes import test_constants
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ class TestRateLimits:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_many_messages(self):
|
||||
# Too many messages
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
|
||||
for i in range(3000):
|
||||
assert r.process_msg_and_check(new_tx_message)
|
||||
@ -34,7 +34,7 @@ class TestRateLimits:
|
||||
assert saw_disconnect
|
||||
|
||||
# Non-tx message
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
|
||||
for i in range(20):
|
||||
assert r.process_msg_and_check(new_peak_message)
|
||||
@ -52,14 +52,14 @@ class TestRateLimits:
|
||||
small_tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
|
||||
large_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 3 * 1024 * 1024))
|
||||
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
assert r.process_msg_and_check(small_tx_message)
|
||||
assert r.process_msg_and_check(small_tx_message)
|
||||
assert not r.process_msg_and_check(large_tx_message)
|
||||
|
||||
small_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 5 * 1024))
|
||||
large_vdf_message = make_msg(ProtocolMessageTypes.respond_signage_point, bytes([1] * 600 * 1024))
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
assert r.process_msg_and_check(small_vdf_message)
|
||||
assert r.process_msg_and_check(small_vdf_message)
|
||||
assert not r.process_msg_and_check(large_vdf_message)
|
||||
@ -67,7 +67,7 @@ class TestRateLimits:
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_much_data(self):
|
||||
# Too much data
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
|
||||
for i in range(10):
|
||||
assert r.process_msg_and_check(tx_message)
|
||||
@ -79,7 +79,7 @@ class TestRateLimits:
|
||||
saw_disconnect = True
|
||||
assert saw_disconnect
|
||||
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
|
||||
for i in range(10):
|
||||
assert r.process_msg_and_check(block_message)
|
||||
@ -94,7 +94,7 @@ class TestRateLimits:
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_tx_aggregate_limits(self):
|
||||
# Frequency limits
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
|
||||
message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
|
||||
message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))
|
||||
@ -112,7 +112,7 @@ class TestRateLimits:
|
||||
assert saw_disconnect
|
||||
|
||||
# Size limits
|
||||
r = RateLimiter()
|
||||
r = RateLimiter(incoming=True)
|
||||
message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 49 * 1024 * 1024))
|
||||
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 49 * 1024 * 1024))
|
||||
|
||||
@ -128,7 +128,7 @@ class TestRateLimits:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_periodic_reset(self):
|
||||
r = RateLimiter(5)
|
||||
r = RateLimiter(True, 5)
|
||||
tx_message = make_msg(ProtocolMessageTypes.respond_transaction, bytes([1] * 500 * 1024))
|
||||
for i in range(10):
|
||||
assert r.process_msg_and_check(tx_message)
|
||||
@ -144,7 +144,7 @@ class TestRateLimits:
|
||||
assert r.process_msg_and_check(tx_message)
|
||||
|
||||
# Counts reset also
|
||||
r = RateLimiter(5)
|
||||
r = RateLimiter(True, 5)
|
||||
new_tx_message = make_msg(ProtocolMessageTypes.new_transaction, bytes([1] * 40))
|
||||
for i in range(3000):
|
||||
assert r.process_msg_and_check(new_tx_message)
|
||||
@ -161,7 +161,7 @@ class TestRateLimits:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_percentage_limits(self):
|
||||
r = RateLimiter(60, 40)
|
||||
r = RateLimiter(True, 60, 40)
|
||||
new_peak_message = make_msg(ProtocolMessageTypes.new_peak, bytes([1] * 40))
|
||||
for i in range(50):
|
||||
assert r.process_msg_and_check(new_peak_message)
|
||||
@ -173,7 +173,7 @@ class TestRateLimits:
|
||||
saw_disconnect = True
|
||||
assert saw_disconnect
|
||||
|
||||
r = RateLimiter(60, 40)
|
||||
r = RateLimiter(True, 60, 40)
|
||||
block_message = make_msg(ProtocolMessageTypes.respond_block, bytes([1] * 1024 * 1024))
|
||||
for i in range(5):
|
||||
assert r.process_msg_and_check(block_message)
|
||||
@ -186,7 +186,7 @@ class TestRateLimits:
|
||||
assert saw_disconnect
|
||||
|
||||
# Aggregate percentage limit count
|
||||
r = RateLimiter(60, 40)
|
||||
r = RateLimiter(True, 60, 40)
|
||||
message_1 = make_msg(ProtocolMessageTypes.request_additions, bytes([1] * 5 * 1024))
|
||||
message_2 = make_msg(ProtocolMessageTypes.request_removals, bytes([1] * 1024))
|
||||
message_3 = make_msg(ProtocolMessageTypes.respond_additions, bytes([1] * 1024))
|
||||
@ -204,7 +204,7 @@ class TestRateLimits:
|
||||
assert saw_disconnect
|
||||
|
||||
# Aggregate percentage limit max total size
|
||||
r = RateLimiter(60, 40)
|
||||
r = RateLimiter(True, 60, 40)
|
||||
message_4 = make_msg(ProtocolMessageTypes.respond_proof_of_weight, bytes([1] * 18 * 1024 * 1024))
|
||||
message_5 = make_msg(ProtocolMessageTypes.respond_blocks, bytes([1] * 24 * 1024 * 1024))
|
||||
|
||||
@ -217,3 +217,47 @@ class TestRateLimits:
|
||||
if not response:
|
||||
saw_disconnect = True
|
||||
assert saw_disconnect
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_many_outgoing_messages(self):
|
||||
# Too many messages
|
||||
r = RateLimiter(incoming=False)
|
||||
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
|
||||
|
||||
passed = 0
|
||||
blocked = 0
|
||||
for i in range(NON_TX_FREQ):
|
||||
if r.process_msg_and_check(new_peers_message):
|
||||
passed += 1
|
||||
else:
|
||||
blocked += 1
|
||||
|
||||
assert passed == 10
|
||||
assert blocked == NON_TX_FREQ - passed
|
||||
|
||||
# ensure that *another* message type is not blocked because of this
|
||||
|
||||
new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
|
||||
assert r.process_msg_and_check(new_signatures_message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_too_many_incoming_messages(self):
|
||||
# Too many messages
|
||||
r = RateLimiter(incoming=True)
|
||||
new_peers_message = make_msg(ProtocolMessageTypes.respond_peers, bytes([1]))
|
||||
|
||||
passed = 0
|
||||
blocked = 0
|
||||
for i in range(NON_TX_FREQ):
|
||||
if r.process_msg_and_check(new_peers_message):
|
||||
passed += 1
|
||||
else:
|
||||
blocked += 1
|
||||
|
||||
assert passed == 10
|
||||
assert blocked == NON_TX_FREQ - passed
|
||||
|
||||
# ensure that other message types *are* blocked because of this
|
||||
|
||||
new_signatures_message = make_msg(ProtocolMessageTypes.respond_signatures, bytes([1]))
|
||||
assert not r.process_msg_and_check(new_signatures_message)
|
||||
|
Loading…
Reference in New Issue
Block a user