Reply type decorator (#8276)

* Check types of messages returned from other peers

* Check message type data structures at startup, check that no peeer messages that expect replies are broadcast, and dynamically check return types of messages that do expect replies.

* Type hint protocol message type check functions

* lint

* typo

* add replay_type decorator

* add api reply decorations

* typo

* Don't check non full-node broadcast messages

* Distinguish internal vs peer protocol error
Added Changelog
Moved static check to import time
Moved protocol timeouts to chia/protocols/protocol_timing.py
Comment typos
Rename create_request -> send_request
Comment that INVALID_PROTOCOL_MESSAGE is bannable, not temporary

* Call static check at module import time

* Rename message_response_ok

* Improve protocol checking for outgoing message validation

* Type-o

* reset submodule

Co-authored-by: almog <almogdepaz@gmail.com>
This commit is contained in:
Adam Kelly 2021-09-20 11:31:15 -07:00 committed by GitHub
parent 5766a8d367
commit e9bf0ec12b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 131 additions and 9 deletions

View File

@ -32,7 +32,7 @@ from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.mempool_item import MempoolItem
from chia.types.peer_info import PeerInfo
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.api_decorators import api_request, peer_required, bytes_required, execute_task
from chia.util.api_decorators import api_request, peer_required, bytes_required, execute_task, reply_type
from chia.util.generator_tools import get_block_header
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint32, uint64, uint128
@ -62,6 +62,7 @@ class FullNodeAPI:
@peer_required
@api_request
@reply_type([ProtocolMessageTypes.respond_peers])
async def request_peers(self, _request: full_node_protocol.RequestPeers, peer: ws.WSChiaConnection):
if peer.peer_server_port is None:
return None
@ -189,6 +190,7 @@ class FullNodeAPI:
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_transaction])
async def request_transaction(self, request: full_node_protocol.RequestTransaction) -> Optional[Message]:
"""Peer has requested a full transaction from us."""
# Ignore if syncing
@ -227,6 +229,7 @@ class FullNodeAPI:
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_proof_of_weight])
async def request_proof_of_weight(self, request: full_node_protocol.RequestProofOfWeight) -> Optional[Message]:
if self.full_node.weight_proof_handler is None:
return None
@ -272,6 +275,7 @@ class FullNodeAPI:
return None
@api_request
@reply_type([ProtocolMessageTypes.respond_block, ProtocolMessageTypes.reject_block])
async def request_block(self, request: full_node_protocol.RequestBlock) -> Optional[Message]:
if not self.full_node.blockchain.contains_height(request.height):
reject = RejectBlock(request.height)
@ -288,6 +292,7 @@ class FullNodeAPI:
return msg
@api_request
@reply_type([ProtocolMessageTypes.respond_blocks, ProtocolMessageTypes.reject_blocks])
async def request_blocks(self, request: full_node_protocol.RequestBlocks) -> Optional[Message]:
if request.end_height < request.start_height or request.end_height - request.start_height > 32:
reject = RejectBlocks(request.start_height, request.end_height)
@ -399,6 +404,7 @@ class FullNodeAPI:
return msg
@api_request
@reply_type([ProtocolMessageTypes.respond_unfinished_block])
async def request_unfinished_block(
self, request_unfinished_block: full_node_protocol.RequestUnfinishedBlock
) -> Optional[Message]:
@ -509,6 +515,7 @@ class FullNodeAPI:
return make_msg(ProtocolMessageTypes.request_signage_point_or_end_of_sub_slot, full_node_request)
@api_request
@reply_type([ProtocolMessageTypes.respond_signage_point, ProtocolMessageTypes.respond_end_of_sub_slot])
async def request_signage_point_or_end_of_sub_slot(
self, request: full_node_protocol.RequestSignagePointOrEndOfSubSlot
) -> Optional[Message]:
@ -1300,6 +1307,7 @@ class FullNodeAPI:
@peer_required
@api_request
@reply_type([ProtocolMessageTypes.respond_compact_vdf])
async def request_compact_vdf(self, request: full_node_protocol.RequestCompactVDF, peer: ws.WSChiaConnection):
if self.full_node.sync_store.get_sync_mode():
return None

View File

@ -0,0 +1,64 @@
from chia.protocols.protocol_message_types import ProtocolMessageTypes as pmt, ProtocolMessageTypes
NO_REPLY_EXPECTED = [
# full_node -> full_node messages
pmt.new_peak,
pmt.new_transaction,
pmt.new_unfinished_block,
pmt.new_signage_point_or_end_of_sub_slot,
pmt.request_mempool_transactions,
pmt.new_compact_vdf,
pmt.request_mempool_transactions,
]
"""
VAILD_REPLY_MESSAGE_MAP:
key: sent message type.
value: valid reply message types, from the view of the requester.
A state machine can be built from this message map.
"""
VAILD_REPLY_MESSAGE_MAP = {
# messages for all services
# pmt.handshake is handled in WSChiaConnection.perform_handshake
# full_node -> full_node protocol messages
pmt.request_transaction: [pmt.respond_transaction],
pmt.request_proof_of_weight: [pmt.respond_proof_of_weight],
pmt.request_block: [pmt.respond_block, pmt.reject_block],
pmt.request_blocks: [pmt.respond_blocks, pmt.reject_blocks],
pmt.request_unfinished_block: [pmt.respond_unfinished_block],
pmt.request_signage_point_or_end_of_sub_slot: [pmt.respond_signage_point, pmt.respond_end_of_sub_slot],
pmt.request_compact_vdf: [pmt.respond_compact_vdf],
pmt.request_peers: [pmt.respond_peers],
}
def static_check_sent_message_response() -> None:
"""Check that allowed message data structures VALID_REPLY_MESSAGE_MAP and NO_REPLY_EXPECTED are consistent."""
# Reply and non-reply sets should not overlap: This check should be static
overlap = set(NO_REPLY_EXPECTED).intersection(set(VAILD_REPLY_MESSAGE_MAP.keys()))
if len(overlap) != 0:
raise AssertionError("Overlapping NO_REPLY_EXPECTED and VAILD_REPLY_MESSAGE_MAP values: {}")
def message_requires_reply(sent: ProtocolMessageTypes) -> bool:
"""Return True if message has an entry in the full node P2P message map"""
# If we knew the peer NodeType is FULL_NODE, we could also check `sent not in NO_REPLY_EXPECTED`
return sent in VAILD_REPLY_MESSAGE_MAP
def message_response_ok(sent: ProtocolMessageTypes, received: ProtocolMessageTypes) -> bool:
"""
Check to see that peers respect protocol message types in reply.
Call with received == None to indicate that we do not expect a specific reply message type.
"""
# Errors below are runtime protocol message mismatches from peers
if sent in VAILD_REPLY_MESSAGE_MAP:
if received not in VAILD_REPLY_MESSAGE_MAP[sent]:
return False
return True
# Run `static_check_sent_message_response` to check this static invariant at import time
static_check_sent_message_response()

View File

@ -0,0 +1,4 @@
# These settings should not be end-user configurable
INVALID_PROTOCOL_BAN_SECONDS = 10
API_EXCEPTION_BAN_SECONDS = 10
INTERNAL_PROTOCOL_ERROR_BAN_SECONDS = 10 # Don't flap if our client is at fault

View File

@ -16,6 +16,8 @@ from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_requires_reply
from chia.protocols.protocol_timing import INVALID_PROTOCOL_BAN_SECONDS, API_EXCEPTION_BAN_SECONDS
from chia.protocols.shared_protocol import protocol_version
from chia.server.introducer_peers import IntroducerPeers
from chia.server.outbound_message import Message, NodeType
@ -159,8 +161,8 @@ class ChiaServer:
self.tasks_from_peer: Dict[bytes32, Set[bytes32]] = {}
self.banned_peers: Dict[str, float] = {}
self.invalid_protocol_ban_seconds = 10
self.api_exception_ban_seconds = 10
self.invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS
self.api_exception_ban_seconds = API_EXCEPTION_BAN_SECONDS
self.exempt_peer_networks: List[Union[IPv4Network, IPv6Network]] = [
ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])
]
@ -611,13 +613,29 @@ class ChiaServer:
for message in messages:
await connection.send_message(message)
async def validate_broadcast_message_type(self, messages: List[Message], node_type: NodeType):
for message in messages:
if message_requires_reply(ProtocolMessageTypes(message.type)):
# Internal protocol logic error - we will raise, blocking messages to all peers
self.log.error(f"Attempt to broadcast message requiring protocol response: {message.type}")
for _, connection in self.all_connections.items():
if connection.connection_type is node_type:
await connection.close(
self.invalid_protocol_ban_seconds,
WSCloseCode.INTERNAL_ERROR,
Err.INTERNAL_PROTOCOL_ERROR,
)
raise ProtocolError(Err.INTERNAL_PROTOCOL_ERROR, [message.type])
async def send_to_all(self, messages: List[Message], node_type: NodeType):
await self.validate_broadcast_message_type(messages, node_type)
for _, connection in self.all_connections.items():
if connection.connection_type is node_type:
for message in messages:
await connection.send_message(message)
async def send_to_all_except(self, messages: List[Message], node_type: NodeType, exclude: bytes32):
await self.validate_broadcast_message_type(messages, node_type)
for _, connection in self.all_connections.items():
if connection.connection_type is node_type and connection.peer_node_id != exclude:
for message in messages:

View File

@ -8,6 +8,8 @@ from aiohttp import WSCloseCode, WSMessage, WSMsgType
from chia.cmds.init_funcs import chia_full_version_str
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_response_ok
from chia.protocols.protocol_timing import INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
@ -217,6 +219,12 @@ class WSChiaConnection:
raise
self.close_callback(self, ban_time)
async def ban_peer_bad_protocol(self, log_err_msg: str):
"""Ban peer for protocol violation"""
ban_seconds = INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
self.log.error(f"Banning peer for {ban_seconds} seconds: {self.peer_host} {log_err_msg}")
await self.close(ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.INVALID_PROTOCOL_MESSAGE)
def cancel_pending_timeouts(self):
for _, task in self.pending_timeouts.items():
task.cancel()
@ -274,14 +282,22 @@ class WSChiaConnection:
if attribute is None:
raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")
msg = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
msg: Message = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
request_start_t = time.time()
result = await self.create_request(msg, timeout)
result = await self.send_request(msg, timeout)
self.log.debug(
f"Time for request {attr_name}: {self.get_peer_logging()} = {time.time() - request_start_t}, "
f"None? {result is None}"
)
if result is not None:
sent_message_type = ProtocolMessageTypes(msg.type)
recv_message_type = ProtocolMessageTypes(result.type)
if not message_response_ok(sent_message_type, recv_message_type):
# peer protocol violation
error_message = f"WSConnection.invoke sent message {sent_message_type.name} "
f"but received {recv_message_type.name}"
await self.ban_peer_bad_protocol(self.error_message)
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
ret_attr = getattr(class_for_type(self.local_type), ProtocolMessageTypes(result.type).name, None)
req_annotations = ret_attr.__annotations__
@ -297,7 +313,7 @@ class WSChiaConnection:
return invoke
async def create_request(self, message_no_id: Message, timeout: int) -> Optional[Message]:
async def send_request(self, message_no_id: Message, timeout: int) -> Optional[Message]:
"""Sends a message and waits for a response."""
if self.closed:
return None

View File

@ -59,3 +59,14 @@ def execute_task(func):
return func
return inner()
def reply_type(type):
def wrap(func):
def inner():
setattr(func, "reply_type", type)
return func
return inner()
return wrap

View File

@ -7,7 +7,7 @@ class Err(Enum):
DOES_NOT_EXTEND = -1
BAD_HEADER_SIGNATURE = -2
MISSING_FROM_STORAGE = -3
INVALID_PROTOCOL_MESSAGE = -4
INVALID_PROTOCOL_MESSAGE = -4 # We WILL ban for a protocol violation.
SELF_CONNECTION = -5
INVALID_HANDSHAKE = -6
INVALID_ACK = -7
@ -129,8 +129,8 @@ class Err(Enum):
INVALID_PREFARM = 104
ASSERT_SECONDS_RELATIVE_FAILED = 105
BAD_COINBASE_SIGNATURE = 106
# removed
# INITIAL_TRANSACTION_FREEZE = 107
# INITIAL_TRANSACTION_FREEZE = 107 # removed
NO_TRANSACTIONS_WHILE_SYNCING = 108
ALREADY_INCLUDING_TRANSACTION = 109
INCOMPATIBLE_NETWORK_ID = 110
@ -151,6 +151,7 @@ class Err(Enum):
INVALID_FEE_TOO_CLOSE_TO_ZERO = 123
COIN_AMOUNT_NEGATIVE = 124
INTERNAL_PROTOCOL_ERROR = 125
class ValidationError(Exception):