diff --git a/chia/server/server.py b/chia/server/server.py index 44319bcc29e5..af720b730f98 100644 --- a/chia/server/server.py +++ b/chia/server/server.py @@ -315,7 +315,7 @@ class ChiaServer: return ws connection: Optional[WSChiaConnection] = None try: - connection = WSChiaConnection( + connection = WSChiaConnection.create( self._local_type, ws, self._port, @@ -467,7 +467,7 @@ class ChiaServer: if peer_id == self.node_id: raise RuntimeError(f"Trying to connect to a peer ({target_node}) with the same peer_id: {peer_id}") - connection = WSChiaConnection( + connection = WSChiaConnection.create( self._local_type, ws, self._port, diff --git a/chia/server/ws_connection.py b/chia/server/ws_connection.py index 589451b49c5d..5c57bf0f256f 100644 --- a/chia/server/ws_connection.py +++ b/chia/server/ws_connection.py @@ -5,12 +5,13 @@ import contextlib import logging import time import traceback +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Union from aiohttp import ClientSession, WSCloseCode, WSMessage, WSMsgType from aiohttp.client import ClientWebSocketResponse from aiohttp.web import WebSocketResponse -from typing_extensions import Protocol +from typing_extensions import Protocol, final from chia.cmds.init_funcs import chia_full_version_str from chia.protocols.protocol_message_types import ProtocolMessageTypes @@ -46,6 +47,8 @@ class ConnectionClosedCallbackProtocol(Protocol): ... +@final +@dataclass class WSChiaConnection: """ Represents a connection to another node. Local host and port are ours, while peer host and @@ -53,8 +56,54 @@ class WSChiaConnection: set after the handshake is performed in this connection. """ - def __init__( - self, + ws: WebSocket + local_type: NodeType + local_port: int + local_capabilities_for_handshake: List[Tuple[uint16, str]] + local_capabilities: List[Capability] + peer_host: str + peer_port: uint16 + peer_node_id: bytes32 + log: logging.Logger + + close_callback: Optional[ConnectionClosedCallbackProtocol] + outbound_rate_limiter: RateLimiter + inbound_rate_limiter: RateLimiter + + # connection properties + is_outbound: bool + is_feeler: bool + + # Messaging + incoming_queue: asyncio.Queue[Tuple[Message, WSChiaConnection]] + outgoing_queue: asyncio.Queue[Message] = field(default_factory=asyncio.Queue) + + # ChiaConnection metrics + creation_time: float = field(default_factory=time.time) + bytes_read: int = 0 + bytes_written: int = 0 + last_message_time: float = 0 + + peer_server_port: Optional[uint16] = None + inbound_task: Optional[asyncio.Task[None]] = None + outbound_task: Optional[asyncio.Task[None]] = None + active: bool = False # once handshake is successful this will be changed to True + close_event: Optional[asyncio.Event] = None + session: Optional[ClientSession] = None + + pending_requests: Dict[uint16, asyncio.Event] = field(default_factory=dict) + request_results: Dict[uint16, Message] = field(default_factory=dict) + closed: bool = False + connection_type: Optional[NodeType] = None + request_nonce: uint16 = uint16(0) + peer_capabilities: List[Capability] = field(default_factory=list) + # Used by the Chia Seeder. + version: str = field(default_factory=str) + protocol_version: str = field(default_factory=str) + + @classmethod + def create( + cls, local_type: NodeType, ws: WebSocket, server_port: int, @@ -70,70 +119,41 @@ class WSChiaConnection: local_capabilities_for_handshake: List[Tuple[uint16, str]], close_event: Optional[asyncio.Event] = None, session: Optional[ClientSession] = None, - ) -> None: - # Local properties - self.ws = ws - self.local_type = local_type - self.local_port = server_port - self.local_capabilities_for_handshake = local_capabilities_for_handshake - self.local_capabilities: List[Capability] = [ - Capability(x[0]) for x in local_capabilities_for_handshake if x[1] == "1" - ] + ) -> WSChiaConnection: - # Remote properties - self.peer_host = peer_host - - peername = self._get_extra_info("peername") + assert ws._writer is not None + peername = ws._writer.transport.get_extra_info("peername") if peername is None: - raise ValueError(f"Was not able to get peername from {self.peer_host}") + raise ValueError(f"Was not able to get peername from {peer_host}") - connection_port = peername[1] - self.peer_port = connection_port - self.peer_server_port: Optional[uint16] = None - self.peer_node_id = peer_id - self.log = log - - # connection properties - self.is_outbound = is_outbound - self.is_feeler = is_feeler - - # ChiaConnection metrics - self.creation_time = time.time() - self.bytes_read = 0 - self.bytes_written = 0 - self.last_message_time: float = 0 - - # Messaging - self.incoming_queue = incoming_queue - self.outgoing_queue: asyncio.Queue[Message] = asyncio.Queue() - - self.inbound_task: Optional[asyncio.Task[None]] = None - self.outbound_task: Optional[asyncio.Task[None]] = None - self.active: bool = False # once handshake is successful this will be changed to True - self.close_event = close_event - self.session = session - self.close_callback: Optional[ConnectionClosedCallbackProtocol] = close_callback - - self.pending_requests: Dict[uint16, asyncio.Event] = {} - self.request_results: Dict[uint16, Message] = {} - self.closed = False - self.connection_type: Optional[NodeType] = None if is_outbound: - self.request_nonce: uint16 = uint16(0) + request_nonce = uint16(0) else: # Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each # request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1 - self.request_nonce = uint16(2**15) + request_nonce = uint16(2**15) - # This means that even if the other peer's boundaries for each minute are not aligned, we will not - # disconnect. Also it allows a little flexibility. - self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent) - self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent) - self.peer_capabilities: List[Capability] = [] - # Used by the Chia Seeder. - self.version = "" - self.protocol_version = "" + return cls( + ws=ws, + local_type=local_type, + local_port=server_port, + local_capabilities_for_handshake=local_capabilities_for_handshake, + local_capabilities=[Capability(x[0]) for x in local_capabilities_for_handshake if x[1] == "1"], + peer_host=peer_host, + peer_port=peername[1], + peer_node_id=peer_id, + log=log, + close_callback=close_callback, + request_nonce=request_nonce, + outbound_rate_limiter=RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent), + inbound_rate_limiter=RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent), + is_outbound=is_outbound, + is_feeler=is_feeler, + incoming_queue=incoming_queue, + close_event=close_event, + session=session, + ) def _get_extra_info(self, name: str) -> Optional[Any]: assert self.ws._writer is not None, "websocket's ._writer is None, was .prepare() called?" diff --git a/tests/connection_utils.py b/tests/connection_utils.py index 34a327e2f58d..59b43731d6e9 100644 --- a/tests/connection_utils.py +++ b/tests/connection_utils.py @@ -55,7 +55,7 @@ async def add_dummy_connection( peer_id = bytes32(der_cert.fingerprint(hashes.SHA256())) url = f"wss://{self_hostname}:{server._port}/ws" ws = await session.ws_connect(url, autoclose=True, autoping=True, ssl=ssl_context) - wsc = WSChiaConnection( + wsc = WSChiaConnection.create( type, ws, server._port, diff --git a/tests/core/ssl/test_ssl.py b/tests/core/ssl/test_ssl.py index f0e888489345..6a20f00facf7 100644 --- a/tests/core/ssl/test_ssl.py +++ b/tests/core/ssl/test_ssl.py @@ -21,7 +21,7 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte incoming_queue: asyncio.Queue = asyncio.Queue() url = f"wss://{self_hostname}:{server._port}/ws" ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context) - wsc = WSChiaConnection( + wsc = WSChiaConnection.create( NodeType.FULL_NODE, ws, server._port,