protocols: Introduce error protocol message (#15493)

* protocols: Introduce `error` protocol message

* Add `Optional` extra data `Error.data`

* Test `Err` entries
This commit is contained in:
dustinface 2023-07-13 20:46:20 +02:00 committed by GitHub
parent f96690bdc5
commit ede354c58f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 177 additions and 14 deletions

View File

@ -113,3 +113,5 @@ class ProtocolMessageTypes(Enum):
respond_block_headers = 88
request_fee_estimates = 89
respond_fee_estimates = 90
error = 255

View File

@ -2,12 +2,13 @@ from __future__ import annotations
from dataclasses import dataclass
from enum import IntEnum
from typing import List, Tuple
from typing import List, Optional, Tuple
from chia.util.ints import uint8, uint16
from chia.util.ints import int16, uint8, uint16
from chia.util.streamable import Streamable, streamable
protocol_version = "0.0.34"
protocol_version = "0.0.35"
"""
Handshake when establishing a connection between two servers.
@ -49,3 +50,11 @@ capabilities = [
(uint16(Capability.RATE_LIMITS_V2.value), "1"),
# (uint16(Capability.NONE_RESPONSE.value), "1"), # capability removed but functionality is still supported
]
@streamable
@dataclass(frozen=True)
class Error(Streamable):
code: int16 # Err
message: str
data: Optional[bytes] = None

View File

@ -179,6 +179,7 @@ rate_limits = {
ProtocolMessageTypes.respond_puzzle_solution: RLSettings(5000, 1024 * 1024),
ProtocolMessageTypes.reject_puzzle_solution: RLSettings(5000, 100),
ProtocolMessageTypes.none_response: RLSettings(500, 100),
ProtocolMessageTypes.error: RLSettings(50000, 100),
},
"rate_limits_other": { # These will have a lower cap since they don't scale with high TPS (NON_TX_FREQ)
ProtocolMessageTypes.request_header_blocks: RLSettings(5000, 100),

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import contextlib
import logging
import math
import time
@ -20,7 +19,7 @@ from chia.cmds.init_funcs import chia_full_version_str
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_response_ok
from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.protocols.shared_protocol import Capability, Error, Handshake
from chia.server.api_protocol import ApiProtocol
from chia.server.capabilities import known_active_capabilities
from chia.server.outbound_message import Message, NodeType, make_msg
@ -28,8 +27,8 @@ from chia.server.rate_limits import RateLimiter
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.api_decorators import get_metadata
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint8, uint16
from chia.util.errors import ApiError, Err, ProtocolError
from chia.util.ints import int16, uint8, uint16
from chia.util.log_exceptions import log_exceptions
# Each message is prepended with LENGTH_BYTES bytes specifying the length
@ -42,6 +41,8 @@ LENGTH_BYTES: int = 4
WebSocket = Union[WebSocketResponse, ClientWebSocketResponse]
ConnectionCallback = Callable[["WSChiaConnection"], Awaitable[None]]
error_response_version = Version("0.0.35")
def create_default_last_message_time_dict() -> Dict[ProtocolMessageTypes, float]:
return {message_type: -math.inf for message_type in ProtocolMessageTypes}
@ -362,6 +363,11 @@ class WSChiaConnection:
)
message_type = ProtocolMessageTypes(full_message.type).name
if full_message.type == ProtocolMessageTypes.error.value:
error = Error.from_bytes(full_message.data)
self.api.log.warning(f"ApiError: {error} from {self.peer_node_id}, {self.peer_info}")
return None
f = getattr(self.api, message_type, None)
if f is None:
@ -396,6 +402,15 @@ class WSChiaConnection:
return result
except asyncio.CancelledError:
pass
except ApiError as api_error:
self.log.warning(f"ApiError: {api_error} from {self.peer_node_id}, {self.peer_info}")
if self.protocol_version >= error_response_version:
return make_msg(
ProtocolMessageTypes.error,
Error(int16(api_error.code.value), api_error.message, api_error.data),
)
else:
return None
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Exception: {e}, {self.get_peer_logging()}. {tb}")
@ -497,6 +512,8 @@ class WSChiaConnection:
return None
sent_message_type = ProtocolMessageTypes(request.type)
recv_message_type = ProtocolMessageTypes(response.type)
if recv_message_type == ProtocolMessageTypes.error:
return Error.from_bytes(response.data)
if not message_response_ok(sent_message_type, recv_message_type):
# peer protocol violation
error_message = f"WSConnection.invoke sent message {sent_message_type.name} "
@ -532,9 +549,10 @@ class WSChiaConnection:
self.pending_requests[message.id] = event
await self.outgoing_queue.put(message)
# Either the result is available below or not, no need to detect the timeout error
with contextlib.suppress(asyncio.TimeoutError):
try:
await asyncio.wait_for(event.wait(), timeout=timeout)
except asyncio.TimeoutError:
self.log.debug(f"Request timeout: {message}")
self.pending_requests.pop(message.id)
result: Optional[Message] = None

View File

@ -2,7 +2,7 @@ from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, List
from typing import Any, List, Optional
class Err(Enum):
@ -206,6 +206,14 @@ class ProtocolError(Exception):
self.errors = errors
class ApiError(Exception):
def __init__(self, code: Err, message: str, data: Optional[bytes] = None):
super().__init__(f"{code.name}: {message}")
self.code: Err = code
self.message: str = message
self.data: Optional[bytes] = data
##
# Keychain errors
##

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Callable, Tuple, cast
import pytest
@ -8,19 +9,37 @@ from packaging.version import Version
from chia.cmds.init_funcs import chia_full_version_str
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.full_node_protocol import RequestTransaction
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import protocol_version
from chia.protocols.shared_protocol import Error, protocol_version
from chia.protocols.wallet_protocol import RejectHeaderRequest
from chia.server.outbound_message import make_msg
from chia.server.server import ChiaServer
from chia.server.ws_connection import WSChiaConnection, error_response_version
from chia.simulator.block_tools import BlockTools
from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32
from chia.util.api_decorators import api_request
from chia.util.errors import ApiError, Err
from chia.util.ints import int16, uint16, uint32
from tests.connection_utils import connect_and_get_peer
@dataclass
class TestAPI:
log: logging.Logger = logging.getLogger(__name__)
def ready(self) -> bool:
return True
# API call from FullNodeAPI
@api_request()
async def request_transaction(self, request: RequestTransaction) -> None:
raise ApiError(Err.NO_TRANSACTIONS_WHILE_SYNCING, f"Some error message: {request.transaction_id}", bytes(b"ab"))
@pytest.mark.asyncio
async def test_duplicate_client_connection(
two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools], self_hostname: str
@ -88,3 +107,68 @@ async def test_api_not_ready(
make_msg(ProtocolMessageTypes.reject_header_request, RejectHeaderRequest(uint32(0)))
)
await time_out_assert(10, request_ignored)
@pytest.mark.parametrize("version", ["0.0.34", "0.0.35", "0.0.36"])
@pytest.mark.asyncio
async def test_error_response(
one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
self_hostname: str,
caplog: pytest.LogCaptureFixture,
version: str,
) -> None:
[full_node_service], [wallet_service], _ = one_wallet_and_one_simulator_services
wallet_node = wallet_service._node
full_node = full_node_service._node
full_node.server.api = TestAPI()
await wallet_node.server.start_client(
PeerInfo(self_hostname, uint16(cast(FullNodeAPI, full_node_service._api).server._port)), None
)
wallet_connection = full_node.server.all_connections[wallet_node.server.node_id]
full_node_connection = wallet_node.server.all_connections[full_node.server.node_id]
test_version = Version(version)
wallet_connection.protocol_version = test_version
request = RequestTransaction(bytes32(32 * b"1"))
error_message = f"Some error message: {request.transaction_id}"
with caplog.at_level(logging.DEBUG):
response = await full_node_connection.call_api(TestAPI.request_transaction, request, timeout=5)
error = ApiError(Err.NO_TRANSACTIONS_WHILE_SYNCING, error_message)
assert f"ApiError: {error} from {wallet_connection.peer_node_id}, {wallet_connection.peer_info}" in caplog.text
if test_version >= error_response_version:
assert response == Error(int16(Err.NO_TRANSACTIONS_WHILE_SYNCING.value), error_message, bytes(b"ab"))
assert "Request timeout:" not in caplog.text
else:
assert response is None
assert "Request timeout:" in caplog.text
@pytest.mark.parametrize(
"error", [Error(int16(Err.UNKNOWN.value), "1", bytes([1, 2, 3])), Error(int16(Err.UNKNOWN.value), "2", None)]
)
@pytest.mark.asyncio
async def test_error_receive(
one_wallet_and_one_simulator_services: SimulatorsAndWalletsServices,
self_hostname: str,
caplog: pytest.LogCaptureFixture,
error: Error,
) -> None:
[full_node_service], [wallet_service], _ = one_wallet_and_one_simulator_services
wallet_node = wallet_service._node
full_node = full_node_service._node
await wallet_node.server.start_client(
PeerInfo(self_hostname, uint16(cast(FullNodeAPI, full_node_service._api).server._port)), None
)
wallet_connection = full_node.server.all_connections[wallet_node.server.node_id]
full_node_connection = wallet_node.server.all_connections[full_node.server.node_id]
message = make_msg(ProtocolMessageTypes.error, error)
def error_log_found(connection: WSChiaConnection) -> bool:
return f"ApiError: {error} from {connection.peer_node_id}, {connection.peer_info}" in caplog.text
with caplog.at_level(logging.WARNING):
await full_node_connection.outgoing_queue.put(message)
await wallet_connection.outgoing_queue.put(message)
await time_out_assert(10, error_log_found, True, full_node_connection)
await time_out_assert(10, error_log_found, True, wallet_connection)

View File

@ -139,6 +139,11 @@ def visit_timelord_protocol(visitor: Callable[[Any, str], None]) -> None:
visitor(respond_compact_proof_of_time, "respond_compact_proof_of_time")
def visit_shared_protocol(visitor: Callable[[Any, str], None]) -> None:
visitor(error_without_data, "error_without_data")
visitor(error_with_data, "error_with_data")
def visit_all_messages(visitor: Callable[[Any, str], None]) -> None:
visit_farmer_protocol(visitor)
visit_full_node(visitor)
@ -147,6 +152,7 @@ def visit_all_messages(visitor: Callable[[Any, str], None]) -> None:
visit_introducer_protocol(visitor)
visit_pool_protocol(visitor)
visit_timelord_protocol(visitor)
visit_shared_protocol(visitor)
def get_protocol_bytes() -> bytes:

View File

@ -13,6 +13,7 @@ from chia.protocols import (
timelord_protocol,
wallet_protocol,
)
from chia.protocols.shared_protocol import Error
from chia.types.blockchain_format.classgroup import ClassgroupElement
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.foliage import Foliage, FoliageBlockData, FoliageTransactionBlock, TransactionsInfo
@ -38,7 +39,13 @@ from chia.types.peer_info import TimestampedPeerInfo
from chia.types.spend_bundle import SpendBundle
from chia.types.unfinished_block import UnfinishedBlock
from chia.types.weight_proof import RecentChainData, SubEpochChallengeSegment, SubEpochData, SubSlotData, WeightProof
from chia.util.ints import uint8, uint16, uint32, uint64, uint128
from chia.util.errors import Err
from chia.util.ints import int16, uint8, uint16, uint32, uint64, uint128
# SHARED PROTOCOL
error_without_data = Error(int16(Err.UNKNOWN.value), "Unknown", None)
error_with_data = Error(int16(Err.UNKNOWN.value), "Unknown", bytes(b"extra data"))
### FARMER PROTOCOL
new_signage_point = farmer_protocol.NewSignagePoint(

View File

@ -2494,3 +2494,7 @@ respond_compact_proof_of_time_json: Dict[str, Any] = {
"height": 386395693,
"field_vdf": 224,
}
error_without_data_json: Dict[str, Any] = {"code": 1, "message": "Unknown", "data": None}
error_with_data_json: Dict[str, Any] = {"code": 1, "message": "Unknown", "data": "0x65787472612064617461"}

10
tests/util/test_errors.py Normal file
View File

@ -0,0 +1,10 @@
from __future__ import annotations
from chia.util.errors import Err
from chia.util.ints import int16
def test_error_codes_int16() -> None:
# Make sure all Err codes fit into int16 because its part of the ProtocolMessageTypes.error message structure
for err in Err:
assert int16(err.value) == err.value

View File

@ -490,4 +490,14 @@ def test_protocol_bytes() -> None:
assert message_92 == respond_compact_proof_of_time
assert bytes(message_92) == bytes(respond_compact_proof_of_time)
message_bytes, input_bytes = parse_blob(input_bytes)
message_93 = type(error_without_data).from_bytes(message_bytes)
assert message_93 == error_without_data
assert bytes(message_93) == bytes(error_without_data)
message_bytes, input_bytes = parse_blob(input_bytes)
message_94 = type(error_with_data).from_bytes(message_bytes)
assert message_94 == error_with_data
assert bytes(message_94) == bytes(error_with_data)
assert input_bytes == b""

View File

@ -218,3 +218,7 @@ def test_protocol_json() -> None:
type(respond_compact_proof_of_time).from_json_dict(respond_compact_proof_of_time_json)
== respond_compact_proof_of_time
)
assert str(error_without_data_json) == str(error_without_data.to_json_dict())
assert type(error_without_data).from_json_dict(error_without_data_json) == error_without_data
assert str(error_with_data_json) == str(error_with_data.to_json_dict())
assert type(error_with_data).from_json_dict(error_with_data_json) == error_with_data

View File

@ -169,7 +169,7 @@ def test_missing_messages() -> None:
"RespondCompactProofOfTime",
}
shared_msgs = {"Handshake", "Capability"}
shared_msgs = {"Handshake", "Capability", "Error"}
# if these asserts fail, make sure to add the new network protocol messages
# to the visitor in build_network_protocol_files.py and rerun it. Then