Improves error handling for unknown protocol types (#7073)

* Improves error handling for unknown protocol types

* Add test for invalid protocol
This commit is contained in:
Earle Lowe 2021-07-01 21:20:21 -07:00 committed by GitHub
parent 3b71210a9c
commit 46395374ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 4 deletions

View File

@ -122,8 +122,16 @@ class WSChiaConnection:
if inbound_handshake_msg is None:
raise ProtocolError(Err.INVALID_HANDSHAKE)
inbound_handshake = Handshake.from_bytes(inbound_handshake_msg.data)
if ProtocolMessageTypes(inbound_handshake_msg.type) != ProtocolMessageTypes.handshake:
# Handle case of invalid ProtocolMessageType
try:
message_type: ProtocolMessageTypes = ProtocolMessageTypes(inbound_handshake_msg.type)
except Exception:
raise ProtocolError(Err.INVALID_HANDSHAKE)
if message_type != ProtocolMessageTypes.handshake:
raise ProtocolError(Err.INVALID_HANDSHAKE)
if inbound_handshake.network_id != network_id:
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
@ -138,9 +146,17 @@ class WSChiaConnection:
if message is None:
raise ProtocolError(Err.INVALID_HANDSHAKE)
inbound_handshake = Handshake.from_bytes(message.data)
if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.handshake:
# Handle case of invalid ProtocolMessageType
try:
message_type = ProtocolMessageTypes(message.type)
except Exception:
raise ProtocolError(Err.INVALID_HANDSHAKE)
if message_type != ProtocolMessageTypes.handshake:
raise ProtocolError(Err.INVALID_HANDSHAKE)
inbound_handshake = Handshake.from_bytes(message.data)
if inbound_handshake.network_id != network_id:
raise ProtocolError(Err.INCOMPATIBLE_NETWORK_ID)
outbound_handshake = make_msg(

View File

@ -8,12 +8,14 @@ from aiohttp import ClientSession, ClientTimeout, ServerDisconnectedError, WSClo
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols import full_node_protocol
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.protocols.shared_protocol import Handshake
from chia.server.outbound_message import make_msg, Message
from chia.server.rate_limits import RateLimiter
from chia.server.server import ssl_context_for_client
from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint64
from chia.util.errors import Err
from tests.setup_nodes import self_hostname, setup_simulators_and_wallets
from tests.time_out_assert import time_out_assert
@ -142,6 +144,39 @@ class TestDos:
await session.close()
@pytest.mark.asyncio
async def test_invalid_protocol_handshake(self, setup_two_nodes):
nodes, _ = setup_two_nodes
server_1 = nodes[0].full_node.server
server_2 = nodes[1].full_node.server
server_1.invalid_protocol_ban_seconds = 10
# Use the server_2 ssl information to connect to server_1
timeout = ClientTimeout(total=10)
session = ClientSession(timeout=timeout)
url = f"wss://{self_hostname}:{server_1._port}/ws"
ssl_context = ssl_context_for_client(
server_2.chia_ca_crt_path, server_2.chia_ca_key_path, server_2.p2p_crt_path, server_2.p2p_key_path
)
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
# Construct an otherwise valid handshake message
handshake: Handshake = Handshake("test", "0.0.32", "1.0.0.0", 3456, 1, [(1, "1")])
outbound_handshake: Message = Message(2, None, bytes(handshake)) # 2 is an invalid ProtocolType
await ws.send_bytes(bytes(outbound_handshake))
response: WSMessage = await ws.receive()
print(response)
assert response.type == WSMsgType.CLOSE
assert response.data == WSCloseCode.PROTOCOL_ERROR
assert response.extra == str(int(Err.INVALID_HANDSHAKE.value)) # We want INVALID_HANDSHAKE and not UNKNOWN
await ws.close()
await session.close()
await asyncio.sleep(1) # give some time for cleanup to work
@pytest.mark.asyncio
async def test_spam_tx(self, setup_two_nodes):
nodes, _ = setup_two_nodes