mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-08-16 14:20:47 +03:00
server: Enable and fix mypy
in ws_connection.py
(#13878)
* server: Enable and fix `mypy` in `ws_connection.py` * Apply suggestions from code review Co-authored-by: Kyle Altendorf <sda@fstab.net> * Tweak error message * Tweak formatting * Make `WSChiaConnection.close_callback` optional * Tweak assert message Co-authored-by: Kyle Altendorf <sda@fstab.net> * Don't provide a default for `close_callback` * Adjust assertion Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
parent
df519c8a5d
commit
2706f5995b
@ -4,7 +4,7 @@ from typing import Any
|
||||
|
||||
from chia.plot_sync.util import ErrorCodes, State
|
||||
from chia.protocols.harvester_protocol import PlotSyncIdentifier
|
||||
from chia.server.ws_connection import NodeType
|
||||
from chia.server.outbound_message import NodeType
|
||||
from chia.util.ints import uint64
|
||||
|
||||
|
||||
|
@ -27,7 +27,9 @@ from chia.protocols.harvester_protocol import (
|
||||
PlotSyncResponse,
|
||||
PlotSyncStart,
|
||||
)
|
||||
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import make_msg
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.util.ints import int16, uint32, uint64
|
||||
from chia.util.misc import get_list_or_len
|
||||
|
@ -23,7 +23,9 @@ from chia.protocols.harvester_protocol import (
|
||||
PlotSyncResponse,
|
||||
PlotSyncStart,
|
||||
)
|
||||
from chia.server.ws_connection import NodeType, ProtocolMessageTypes, WSChiaConnection, make_msg
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import NodeType, make_msg
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.util.generator_tools import list_to_batches
|
||||
from chia.util.ints import int16, uint32, uint64
|
||||
|
||||
|
@ -5,9 +5,11 @@ import contextlib
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from aiohttp import WSCloseCode, WSMessage, WSMsgType
|
||||
from aiohttp import ClientSession, WSCloseCode, WSMessage, WSMsgType
|
||||
from aiohttp.client import ClientWebSocketResponse
|
||||
from aiohttp.web import WebSocketResponse
|
||||
|
||||
from chia.cmds.init_funcs import chia_full_version_str
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
@ -16,6 +18,7 @@ from chia.protocols.protocol_timing import INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
|
||||
from chia.protocols.shared_protocol import Capability, Handshake
|
||||
from chia.server.outbound_message import Message, NodeType, make_msg
|
||||
from chia.server.rate_limits import RateLimiter
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.api_decorators import get_metadata
|
||||
from chia.util.errors import Err, ProtocolError
|
||||
@ -23,10 +26,13 @@ from chia.util.ints import uint8, uint16
|
||||
|
||||
# Each message is prepended with LENGTH_BYTES bytes specifying the length
|
||||
from chia.util.network import class_for_type, is_localhost
|
||||
from chia.util.streamable import Streamable
|
||||
|
||||
# Max size 2^(8*4) which is around 4GiB
|
||||
LENGTH_BYTES: int = 4
|
||||
|
||||
WebSocket = Union[WebSocketResponse, ClientWebSocketResponse]
|
||||
|
||||
|
||||
class WSChiaConnection:
|
||||
"""
|
||||
@ -38,23 +44,23 @@ class WSChiaConnection:
|
||||
def __init__(
|
||||
self,
|
||||
local_type: NodeType,
|
||||
ws: Any, # Websocket
|
||||
ws: WebSocket,
|
||||
server_port: int,
|
||||
log: logging.Logger,
|
||||
is_outbound: bool,
|
||||
is_feeler: bool, # Special type of connection, that disconnects after the handshake.
|
||||
peer_host,
|
||||
incoming_queue,
|
||||
close_callback: Callable,
|
||||
peer_id,
|
||||
peer_host: str,
|
||||
incoming_queue: asyncio.Queue[Tuple[Message, WSChiaConnection]],
|
||||
close_callback: Optional[Callable[[WSChiaConnection, int], None]],
|
||||
peer_id: bytes32,
|
||||
inbound_rate_limit_percent: int,
|
||||
outbound_rate_limit_percent: int,
|
||||
local_capabilities_for_handshake: List[Tuple[uint16, str]],
|
||||
close_event=None,
|
||||
session=None,
|
||||
):
|
||||
close_event: Optional[asyncio.Event] = None,
|
||||
session: Optional[ClientSession] = None,
|
||||
) -> None:
|
||||
# Local properties
|
||||
self.ws: Any = ws
|
||||
self.ws = ws
|
||||
self.local_type = local_type
|
||||
self.local_port = server_port
|
||||
self.local_capabilities_for_handshake = local_capabilities_for_handshake
|
||||
@ -87,13 +93,13 @@ class WSChiaConnection:
|
||||
self.last_message_time: float = 0
|
||||
|
||||
# Messaging
|
||||
self.incoming_queue: asyncio.Queue = incoming_queue
|
||||
self.outgoing_queue: asyncio.Queue = asyncio.Queue()
|
||||
self.incoming_queue = incoming_queue
|
||||
self.outgoing_queue: asyncio.Queue[Message] = asyncio.Queue()
|
||||
|
||||
self.inbound_task: Optional[asyncio.Task] = None
|
||||
self.outbound_task: Optional[asyncio.Task] = None
|
||||
self.inbound_task: Optional[asyncio.Task[None]] = None
|
||||
self.outbound_task: Optional[asyncio.Task[None]] = None
|
||||
self.active: bool = False # once handshake is successful this will be changed to True
|
||||
self.close_event: asyncio.Event = close_event
|
||||
self.close_event = close_event
|
||||
self.session = session
|
||||
self.close_callback = close_callback
|
||||
|
||||
@ -118,6 +124,7 @@ class WSChiaConnection:
|
||||
self.protocol_version = ""
|
||||
|
||||
def _get_extra_info(self, name: str) -> Optional[Any]:
|
||||
assert self.ws._writer is not None, "websocket's ._writer is None, was .prepare() called?"
|
||||
return self.ws._writer.transport.get_extra_info(name)
|
||||
|
||||
async def perform_handshake(
|
||||
@ -205,7 +212,12 @@ class WSChiaConnection:
|
||||
self.outbound_task = asyncio.create_task(self.outbound_handler())
|
||||
self.inbound_task = asyncio.create_task(self.inbound_handler())
|
||||
|
||||
async def close(self, ban_time: int = 0, ws_close_code: WSCloseCode = WSCloseCode.OK, error: Optional[Err] = None):
|
||||
async def close(
|
||||
self,
|
||||
ban_time: int = 0,
|
||||
ws_close_code: WSCloseCode = WSCloseCode.OK,
|
||||
error: Optional[Err] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Closes the connection, and finally calls the close_callback on the server, so the connection gets removed
|
||||
from the global list.
|
||||
@ -236,24 +248,26 @@ class WSChiaConnection:
|
||||
error_stack = traceback.format_exc()
|
||||
self.log.warning(f"Exception closing socket: {error_stack}")
|
||||
try:
|
||||
self.close_callback(self, ban_time)
|
||||
if self.close_callback is not None:
|
||||
self.close_callback(self, ban_time)
|
||||
except Exception:
|
||||
error_stack = traceback.format_exc()
|
||||
self.log.error(f"Error closing1: {error_stack}")
|
||||
raise
|
||||
try:
|
||||
self.close_callback(self, ban_time)
|
||||
if self.close_callback is not None:
|
||||
self.close_callback(self, ban_time)
|
||||
except Exception:
|
||||
error_stack = traceback.format_exc()
|
||||
self.log.error(f"Error closing2: {error_stack}")
|
||||
|
||||
async def ban_peer_bad_protocol(self, log_err_msg: str):
|
||||
async def ban_peer_bad_protocol(self, log_err_msg: str) -> None:
|
||||
"""Ban peer for protocol violation"""
|
||||
ban_seconds = INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
|
||||
self.log.error(f"Banning peer for {ban_seconds} seconds: {self.peer_host} {log_err_msg}")
|
||||
await self.close(ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.INVALID_PROTOCOL_MESSAGE)
|
||||
|
||||
def cancel_pending_requests(self):
|
||||
def cancel_pending_requests(self) -> None:
|
||||
for message_id, event in self.pending_requests.items():
|
||||
try:
|
||||
event.set()
|
||||
@ -283,10 +297,10 @@ class WSChiaConnection:
|
||||
self.log.error(f"Exception: {e} with {self.peer_host}")
|
||||
self.log.error(f"Exception Stack: {error_stack}")
|
||||
|
||||
async def inbound_handler(self):
|
||||
async def inbound_handler(self) -> None:
|
||||
try:
|
||||
while not self.closed:
|
||||
message: Message = await self._read_one_message()
|
||||
message = await self._read_one_message()
|
||||
if message is not None:
|
||||
if message.id in self.pending_requests:
|
||||
self.request_results[message.id] = message
|
||||
@ -310,17 +324,19 @@ class WSChiaConnection:
|
||||
await self.outgoing_queue.put(message)
|
||||
return True
|
||||
|
||||
def __getattr__(self, attr_name: str):
|
||||
def __getattr__(self, attr_name: str) -> Any:
|
||||
# TODO KWARGS
|
||||
async def invoke(*args, **kwargs):
|
||||
async def invoke(*args: Any, **kwargs: Any) -> Optional[Streamable]:
|
||||
timeout = 60
|
||||
if "timeout" in kwargs:
|
||||
timeout = kwargs["timeout"]
|
||||
if self.connection_type is None:
|
||||
raise ValueError("handshake not done yet")
|
||||
attribute = getattr(class_for_type(self.connection_type), attr_name, None)
|
||||
if attribute is None:
|
||||
raise AttributeError(f"Node type {self.connection_type} does not have method {attr_name}")
|
||||
|
||||
request = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, args[0])
|
||||
request = Message(uint8(getattr(ProtocolMessageTypes, attr_name).value), None, bytes(args[0]))
|
||||
request_start_t = time.time()
|
||||
response = await self.send_request(request, timeout)
|
||||
self.log.debug(
|
||||
@ -339,7 +355,9 @@ class WSChiaConnection:
|
||||
raise ProtocolError(Err.INVALID_PROTOCOL_MESSAGE, [error_message])
|
||||
|
||||
recv_method = getattr(class_for_type(self.local_type), recv_message_type.name)
|
||||
return get_metadata(recv_method).message_class.from_bytes(response.data)
|
||||
api_metadata = get_metadata(recv_method)
|
||||
assert api_metadata is not None, f"ApiMetadata unavailable for {recv_method}"
|
||||
return api_metadata.message_class.from_bytes(response.data)
|
||||
|
||||
return invoke
|
||||
|
||||
@ -380,13 +398,13 @@ class WSChiaConnection:
|
||||
|
||||
return result
|
||||
|
||||
async def send_messages(self, messages: List[Message]):
|
||||
async def send_messages(self, messages: List[Message]) -> None:
|
||||
if self.closed:
|
||||
return None
|
||||
for message in messages:
|
||||
await self.outgoing_queue.put(message)
|
||||
|
||||
async def _wait_and_retry(self, msg: Message):
|
||||
async def _wait_and_retry(self, msg: Message) -> None:
|
||||
try:
|
||||
await asyncio.sleep(1)
|
||||
await self.outgoing_queue.put(msg)
|
||||
@ -394,7 +412,7 @@ class WSChiaConnection:
|
||||
self.log.debug(f"Exception {e} while waiting to retry sending rate limited message")
|
||||
return None
|
||||
|
||||
async def _send_message(self, message: Message):
|
||||
async def _send_message(self, message: Message) -> None:
|
||||
encoded: bytes = bytes(message)
|
||||
size = len(encoded)
|
||||
assert len(encoded) < (2 ** (LENGTH_BYTES * 8))
|
||||
@ -507,7 +525,7 @@ class WSChiaConnection:
|
||||
def get_tls_version(self) -> str:
|
||||
ssl_obj = self._get_extra_info("ssl_object")
|
||||
if ssl_obj is not None:
|
||||
return ssl_obj.version()
|
||||
return str(ssl_obj.version())
|
||||
else:
|
||||
return "unknown"
|
||||
|
||||
|
@ -63,7 +63,7 @@ async def add_dummy_connection(
|
||||
False,
|
||||
self_hostname,
|
||||
incoming_queue,
|
||||
lambda x, y: x,
|
||||
None,
|
||||
peer_id,
|
||||
100,
|
||||
30,
|
||||
|
@ -9,6 +9,7 @@ from chia.server.server import ChiaServer, ssl_context_for_client
|
||||
from chia.server.ssl_context import chia_ssl_ca_paths, private_ssl_ca_paths
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.ssl.create_ssl import generate_ca_signed_cert
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.ints import uint16
|
||||
|
||||
@ -29,8 +30,8 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte
|
||||
False,
|
||||
self_hostname,
|
||||
incoming_queue,
|
||||
lambda x, y: x,
|
||||
None,
|
||||
bytes32(b"\x00" * 32),
|
||||
100,
|
||||
30,
|
||||
local_capabilities_for_handshake=capabilities,
|
||||
|
@ -20,8 +20,8 @@ from chia.plot_sync.util import Constants, State
|
||||
from chia.plotting.manager import PlotManager
|
||||
from chia.plotting.util import add_plot_directory, remove_plot_directory
|
||||
from chia.protocols.harvester_protocol import Plot
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.start_service import Service
|
||||
from chia.server.ws_connection import ProtocolMessageTypes
|
||||
from chia.simulator.block_tools import BlockTools
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
|
@ -22,7 +22,7 @@ from chia.protocols.harvester_protocol import (
|
||||
PlotSyncResponse,
|
||||
PlotSyncStart,
|
||||
)
|
||||
from chia.server.ws_connection import NodeType
|
||||
from chia.server.outbound_message import NodeType
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.util.ints import uint8, uint32, uint64
|
||||
from chia.util.misc import get_list_or_len
|
||||
|
@ -6,7 +6,8 @@ from chia.plot_sync.exceptions import AlreadyStartedError, InvalidConnectionType
|
||||
from chia.plot_sync.sender import ExpectedResponse, Sender
|
||||
from chia.plot_sync.util import Constants
|
||||
from chia.protocols.harvester_protocol import PlotSyncIdentifier, PlotSyncResponse
|
||||
from chia.server.ws_connection import NodeType, ProtocolMessageTypes
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import NodeType
|
||||
from chia.simulator.block_tools import BlockTools
|
||||
from chia.util.ints import int16, uint64
|
||||
from tests.plot_sync.util import get_dummy_connection, plot_sync_identifier
|
||||
|
@ -21,8 +21,10 @@ from chia.plot_sync.util import Constants
|
||||
from chia.plotting.manager import PlotManager
|
||||
from chia.plotting.util import PlotInfo
|
||||
from chia.protocols.harvester_protocol import PlotSyncError, PlotSyncResponse
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.server.outbound_message import make_msg
|
||||
from chia.server.start_service import Service
|
||||
from chia.server.ws_connection import ProtocolMessageTypes, WSChiaConnection, make_msg
|
||||
from chia.server.ws_connection import WSChiaConnection
|
||||
from chia.simulator.block_tools import BlockTools
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
|
@ -9,8 +9,8 @@ from chia.farmer.farmer import Farmer
|
||||
from chia.harvester.harvester import Harvester
|
||||
from chia.plot_sync.sender import Sender
|
||||
from chia.protocols.harvester_protocol import PlotSyncIdentifier
|
||||
from chia.server.outbound_message import Message, NodeType
|
||||
from chia.server.start_service import Service
|
||||
from chia.server.ws_connection import Message, NodeType
|
||||
from chia.simulator.time_out_assert import time_out_assert
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
|
Loading…
Reference in New Issue
Block a user