mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-20 16:08:51 +03:00
Improves error handling for unknown protocol types (#7073)
* Improves error handling for unknown protocol types * Add test for invalid protocol
This commit is contained in:
parent
3b71210a9c
commit
46395374ae
@ -122,8 +122,16 @@ class WSChiaConnection:
|
||||
if inbound_handshake_msg is None:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data)
|
||||
if ProtocolMessageTypes(inbound_handshake_msg.type) != ProtocolMessageTypes.handshake:
|
||||
|
||||
# Handle case of invalid ProtocolMessageType
|
||||
try:
|
||||
message_type: ProtocolMessageTypes = ProtocolMessageTypes(inbound_handshake_msg.type)
|
||||
except Exception:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
|
||||
if message_type != ProtocolMessageTypes.handshake:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
|
||||
if inbound_handshake.network_id != network_id:
|
||||
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
|
||||
|
||||
@ -138,9 +146,17 @@ class WSChiaConnection:
|
||||
|
||||
if message is None:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
inbound_handshake = Handshake.from_bytes(message.data)
|
||||
if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.handshake:
|
||||
|
||||
# Handle case of invalid ProtocolMessageType
|
||||
try:
|
||||
message_type = ProtocolMessageTypes(message.type)
|
||||
except Exception:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
|
||||
if message_type != ProtocolMessageTypes.handshake:
|
||||
raise ProtocolError(Err.INVALID_HANDSHAKE)
|
||||
|
||||
inbound_handshake = Handshake.from_bytes(message.data)
|
||||
if inbound_handshake.network_id != network_id:
|
||||
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
|
||||
outbound_handshake = make_msg(
|
||||
|
@ -8,12 +8,14 @@ from aiohttp import ClientSession, ClientTimeout, ServerDisconnectedError, WSClo
|
||||
from chia.full_node.full_node_api import FullNodeAPI
|
||||
from chia.protocols import full_node_protocol
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import make_msg
|
||||
from chia.protocols.shared_protocol import Handshake
|
||||
from chia.server.outbound_message import make_msg, Message
|
||||
from chia.server.rate_limits import RateLimiter
|
||||
from chia.server.server import ssl_context_for_client
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.ints import uint16, uint64
|
||||
from chia.util.errors import Err
|
||||
from tests.setup_nodes import self_hostname, setup_simulators_and_wallets
|
||||
from tests.time_out_assert import time_out_assert
|
||||
|
||||
@ -142,6 +144,39 @@ class TestDos:
|
||||
|
||||
await session.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_protocol_handshake(self, setup_two_nodes):
|
||||
nodes, _ = setup_two_nodes
|
||||
server_1 = nodes[0].full_node.server
|
||||
server_2 = nodes[1].full_node.server
|
||||
|
||||
server_1.invalid_protocol_ban_seconds = 10
|
||||
# Use the server_2 ssl information to connect to server_1
|
||||
timeout = ClientTimeout(total=10)
|
||||
session = ClientSession(timeout=timeout)
|
||||
url = f"wss://{self_hostname}:{server_1._port}/ws"
|
||||
|
||||
ssl_context = ssl_context_for_client(
|
||||
server_2.chia_ca_crt_path, server_2.chia_ca_key_path, server_2.p2p_crt_path, server_2.p2p_key_path
|
||||
)
|
||||
ws = await session.ws_connect(
|
||||
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
|
||||
)
|
||||
|
||||
# Construct an otherwise valid handshake message
|
||||
handshake: Handshake = Handshake("test", "0.0.32", "1.0.0.0", 3456, 1, [(1, "1")])
|
||||
outbound_handshake: Message = Message(2, None, bytes(handshake)) # 2 is an invalid ProtocolType
|
||||
await ws.send_bytes(bytes(outbound_handshake))
|
||||
|
||||
response: WSMessage = await ws.receive()
|
||||
print(response)
|
||||
assert response.type == WSMsgType.CLOSE
|
||||
assert response.data == WSCloseCode.PROTOCOL_ERROR
|
||||
assert response.extra == str(int(Err.INVALID_HANDSHAKE.value)) # We want INVALID_HANDSHAKE and not UNKNOWN
|
||||
await ws.close()
|
||||
await session.close()
|
||||
await asyncio.sleep(1) # give some time for cleanup to work
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spam_tx(self, setup_two_nodes):
|
||||
nodes, _ = setup_two_nodes
|
||||
|
Loading…
Reference in New Issue
Block a user