server: Enable and fix mypy in ws_connection.py (#13878)

* server: Enable and fix `mypy` in `ws_connection.py`

* Apply suggestions from code review

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Tweak error message

* Tweak formatting

* Make `WSChiaConnection.close_callback` optional

* Tweak assert message

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* Don't provide a default for `close_callback`

* Adjust assertion

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
dustinface 2022-11-12 21:18:11 +01:00 committed by GitHub
parent df519c8a5d
commit 2706f5995b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 68 additions and 42 deletions

View File

@ -4,7 +4,7 @@ from typing import Any
from chia.plot_sync.util import ErrorCodes, State
from chia.protocols.harvester_protocol import PlotSyncIdentifier
from chia.server.ws_connection import NodeType
from chia.server.outbound_message import NodeType
from chia.util.ints import uint64

View File

@ -27,7 +27,9 @@ from chia.protocols.harvester_protocol import (
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int16, uint32, uint64
from chia.util.misc import get_list_or_len

View File

@ -23,7 +23,9 @@ from chia.protocols.harvester_protocol import (
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import NodeType, ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import NodeType, make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.util.generator_tools import list_to_batches
from chia.util.ints import int16, uint32, uint64

View File

@ -5,9 +5,11 @@ import contextlib
import logging
import time
import traceback
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from aiohttp import WSCloseCode, WSMessage, WSMsgType
from aiohttp import ClientSession, WSCloseCode, WSMessage, WSMsgType
from aiohttp.client import ClientWebSocketResponse
from aiohttp.web import WebSocketResponse
from chia.cmds.init_funcs import chia_full_version_str
from chia.protocols.protocol_message_types import ProtocolMessageTypes
@ -16,6 +18,7 @@ from chia.protocols.protocol_timing import INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.api_decorators import get_metadata
from chia.util.errors import Err, ProtocolError
@ -23,10 +26,13 @@ from chia.util.ints import uint8, uint16
# Each message is prepended with LENGTH_BYTES bytes specifying the length
from chia.util.network import class_for_type, is_localhost
from chia.util.streamable import Streamable
# Max size 2^(8*4) which is around 4GiB
LENGTH_BYTES: int = 4
WebSocket = Union[WebSocketResponse, ClientWebSocketResponse]
class WSChiaConnection:
"""
@ -38,23 +44,23 @@ class WSChiaConnection:
def __init__(
self,
local_type: NodeType,
ws: Any, # Websocket
ws: WebSocket,
server_port: int,
log: logging.Logger,
is_outbound: bool,
is_feeler: bool, # Special type of connection, that disconnects after the handshake.
peer_host,
incoming_queue,
close_callback: Callable,
peer_id,
peer_host: str,
incoming_queue: asyncio.Queue[Tuple[Message, WSChiaConnection]],
close_callback: Optional[Callable[[WSChiaConnection, int], None]],
peer_id: bytes32,
inbound_rate_limit_percent: int,
outbound_rate_limit_percent: int,
local_capabilities_for_handshake: List[Tuple[uint16, str]],
close_event=None,
session=None,
):
close_event: Optional[asyncio.Event] = None,
session: Optional[ClientSession] = None,
) -> None:
# Local properties
self.ws: Any = ws
self.ws = ws
self.local_type = local_type
self.local_port = server_port
self.local_capabilities_for_handshake = local_capabilities_for_handshake
@ -87,13 +93,13 @@ class WSChiaConnection:
self.last_message_time: float = 0
# Messaging
self.incoming_queue: asyncio.Queue = incoming_queue
self.outgoing_queue: asyncio.Queue = asyncio.Queue()
self.incoming_queue = incoming_queue
self.outgoing_queue: asyncio.Queue[Message] = asyncio.Queue()
self.inbound_task: Optional[asyncio.Task] = None
self.outbound_task: Optional[asyncio.Task] = None
self.inbound_task: Optional[asyncio.Task[None]] = None
self.outbound_task: Optional[asyncio.Task[None]] = None
self.active: bool = False # once handshake is successful this will be changed to True
self.close_event: asyncio.Event = close_event
self.close_event = close_event
self.session = session
self.close_callback = close_callback
@ -118,6 +124,7 @@ class WSChiaConnection:
self.protocol_version = ""
def _get_extra_info(self, name: str) -> Optional[Any]:
assert self.ws._writer is not None, "websocket's ._writer is None, was .prepare() called?"
return self.ws._writer.transport.get_extra_info(name)
async def perform_handshake(
@ -205,7 +212,12 @@ class WSChiaConnection:
self.outbound_task = asyncio.create_task(self.outbound_handler())
self.inbound_task = asyncio.create_task(self.inbound_handler())
async def close(self, ban_time: int = 0, ws_close_code: WSCloseCode = WSCloseCode.OK, error: Optional[Err] = None):
async def close(
self,
ban_time: int = 0,
ws_close_code: WSCloseCode = WSCloseCode.OK,
error: Optional[Err] = None,
) -> None:
"""
Closes the connection, and finally calls the close_callback on the server, so the connection gets removed
from the global list.
@ -236,24 +248,26 @@ class WSChiaConnection:
error_stack = traceback.format_exc()
self.log.warning(f"Exception closing socket: {error_stack}")
try:
self.close_callback(self, ban_time)
if self.close_callback is not None:
self.close_callback(self, ban_time)
except Exception:
error_stack = traceback.format_exc()
self.log.error(f"Error closing1: {error_stack}")
raise
try:
self.close_callback(self, ban_time)
if self.close_callback is not None:
self.close_callback(self, ban_time)
except Exception:
error_stack = traceback.format_exc()
self.log.error(f"Error closing2: {error_stack}")
async def ban_peer_bad_protocol(self, log_err_msg: str):
async def ban_peer_bad_protocol(self, log_err_msg: str) -> None:
"""Ban peer for protocol violation"""
ban_seconds = INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
self.log.error(f"Banning peer for {ban_seconds} seconds: {self.peer_host} {log_err_msg}")
await self.close(ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.INVALID_PROTOCOL_MESSAGE)
def cancel_pending_requests(self):
def cancel_pending_requests(self) -> None:
for message_id, event in self.pending_requests.items():
try:
event.set()
@ -283,10 +297,10 @@ class WSChiaConnection:
self.log.error(f"Exception: {e} with {self.peer_host}")
self.log.error(f"Exception Stack: {error_stack}")
async def inbound_handler(self):
async def inbound_handler(self) -> None:
try:
while not self.closed:
message: Message = await self._read_one_message()
message = await self._read_one_message()
if message is not None:
if message.id in self.pending_requests:
self.request_results[message.id] = message
@ -310,17 +324,19 @@ class WSChiaConnection:
await self.outgoing_queue.put(message)
return True
def __getattr__(self, attr_name: str):
def __getattr__(self, attr_name: str) -> Any:
# TODO KWARGS
async def invoke(*args, **kwargs):
async def invoke(*args: Any, **kwargs: Any) -> Optional[Streamable]:
timeout = 60
if "timeout" in kwargs:
timeout = kwargs["timeout"]
if self.connection_type is None:
raise ValueError("handshake not done yet")
attribute = getattr(class_for_type(self.connection_type), attr_name, None)
if attribute is None:
raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")
request = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
request = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, bytes(args[0]))
request_start_t = time.time()
response = await self.send_request(request, timeout)
self.log.debug(
@ -339,7 +355,9 @@ class WSChiaConnection:
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
recv_method = getattr(class_for_type(self.local_type), recv_message_type.name)
return get_metadata(recv_method).message_class.from_bytes(response.data)
api_metadata = get_metadata(recv_method)
assert api_metadata is not None, f"ApiMetadata unavailable for {recv_method}"
return api_metadata.message_class.from_bytes(response.data)
return invoke
@ -380,13 +398,13 @@ class WSChiaConnection:
return result
async def send_messages(self, messages: List[Message]):
async def send_messages(self, messages: List[Message]) -> None:
if self.closed:
return None
for message in messages:
await self.outgoing_queue.put(message)
async def _wait_and_retry(self, msg: Message):
async def _wait_and_retry(self, msg: Message) -> None:
try:
await asyncio.sleep(1)
await self.outgoing_queue.put(msg)
@ -394,7 +412,7 @@ class WSChiaConnection:
self.log.debug(f"Exception {e} while waiting to retry sending rate limited message")
return None
async def _send_message(self, message: Message):
async def _send_message(self, message: Message) -> None:
encoded: bytes = bytes(message)
size = len(encoded)
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
@ -507,7 +525,7 @@ class WSChiaConnection:
def get_tls_version(self) -> str:
ssl_obj = self._get_extra_info("ssl_object")
if ssl_obj is not None:
return ssl_obj.version()
return str(ssl_obj.version())
else:
return "unknown"

File diff suppressed because one or more lines are too long

View File

@ -63,7 +63,7 @@ async def add_dummy_connection(
False,
self_hostname,
incoming_queue,
lambda x, y: x,
None,
peer_id,
100,
30,

View File

@ -9,6 +9,7 @@ from chia.server.server import ChiaServer, ssl_context_for_client
from chia.server.ssl_context import chia_ssl_ca_paths, private_ssl_ca_paths
from chia.server.ws_connection import WSChiaConnection
from chia.ssl.create_ssl import generate_ca_signed_cert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16
@ -29,8 +30,8 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte
False,
self_hostname,
incoming_queue,
lambda x, y: x,
None,
bytes32(b"\x00" * 32),
100,
30,
local_capabilities_for_handshake=capabilities,

View File

@ -20,8 +20,8 @@ from chia.plot_sync.util import Constants, State
from chia.plotting.manager import PlotManager
from chia.plotting.util import add_plot_directory, remove_plot_directory
from chia.protocols.harvester_protocol import Plot
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.start_service import Service
from chia.server.ws_connection import ProtocolMessageTypes
from chia.simulator.block_tools import BlockTools
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32

View File

@ -22,7 +22,7 @@ from chia.protocols.harvester_protocol import (
PlotSyncResponse,
PlotSyncStart,
)
from chia.server.ws_connection import NodeType
from chia.server.outbound_message import NodeType
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint8, uint32, uint64
from chia.util.misc import get_list_or_len

View File

@ -6,7 +6,8 @@ from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionType
from chia.plot_sync.sender import ExpectedResponse, Sender
from chia.plot_sync.util import Constants
from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse
from chia.server.ws_connection import NodeType, ProtocolMessageTypes
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import NodeType
from chia.simulator.block_tools import BlockTools
from chia.util.ints import int16, uint64
from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier

View File

@ -21,8 +21,10 @@ from chia.plot_sync.util import Constants
from chia.plotting.manager import PlotManager
from chia.plotting.util import PlotInfo
from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.server.start_service import Service
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.simulator.block_tools import BlockTools
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32

View File

@ -9,8 +9,8 @@ from chia.farmer.farmer import Farmer
from chia.harvester.harvester import Harvester
from chia.plot_sync.sender import Sender
from chia.protocols.harvester_protocol import PlotSyncIdentifier
from chia.server.outbound_message import Message, NodeType
from chia.server.start_service import Service
from chia.server.ws_connection import Message, NodeType
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo