Make WSChiaConnection a dataclass (#13906)

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
dustinface 2022-11-16 20:38:28 +01:00 committed by GitHub
parent 643b29e636
commit 5c861db42f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 82 additions and 62 deletions

View File

@ -315,7 +315,7 @@ class ChiaServer:
return ws
connection: Optional[WSChiaConnection] = None
try:
connection = WSChiaConnection(
connection = WSChiaConnection.create(
self._local_type,
ws,
self._port,
@ -467,7 +467,7 @@ class ChiaServer:
if peer_id == self.node_id:
raise RuntimeError(f"Trying to connect to a peer ({target_node}) with the same peer_id: {peer_id}")
connection = WSChiaConnection(
connection = WSChiaConnection.create(
self._local_type,
ws,
self._port,

View File

@ -5,12 +5,13 @@ import contextlib
import logging
import time
import traceback
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union
from aiohttp import ClientSession, WSCloseCode, WSMessage, WSMsgType
from aiohttp.client import ClientWebSocketResponse
from aiohttp.web import WebSocketResponse
from typing_extensions import Protocol
from typing_extensions import Protocol, final
from chia.cmds.init_funcs import chia_full_version_str
from chia.protocols.protocol_message_types import ProtocolMessageTypes
@ -46,6 +47,8 @@ class ConnectionClosedCallbackProtocol(Protocol):
...
@final
@dataclass
class WSChiaConnection:
"""
Represents a connection to another node. Local host and port are ours, while peer host and
@ -53,8 +56,54 @@ class WSChiaConnection:
set after the handshake is performed in this connection.
"""
def __init__(
self,
ws: WebSocket
local_type: NodeType
local_port: int
local_capabilities_for_handshake: List[Tuple[uint16, str]]
local_capabilities: List[Capability]
peer_host: str
peer_port: uint16
peer_node_id: bytes32
log: logging.Logger
close_callback: Optional[ConnectionClosedCallbackProtocol]
outbound_rate_limiter: RateLimiter
inbound_rate_limiter: RateLimiter
# connection properties
is_outbound: bool
is_feeler: bool
# Messaging
incoming_queue: asyncio.Queue[Tuple[Message, WSChiaConnection]]
outgoing_queue: asyncio.Queue[Message] = field(default_factory=asyncio.Queue)
# ChiaConnection metrics
creation_time: float = field(default_factory=time.time)
bytes_read: int = 0
bytes_written: int = 0
last_message_time: float = 0
peer_server_port: Optional[uint16] = None
inbound_task: Optional[asyncio.Task[None]] = None
outbound_task: Optional[asyncio.Task[None]] = None
active: bool = False # once handshake is successful this will be changed to True
close_event: Optional[asyncio.Event] = None
session: Optional[ClientSession] = None
pending_requests: Dict[uint16, asyncio.Event] = field(default_factory=dict)
request_results: Dict[uint16, Message] = field(default_factory=dict)
closed: bool = False
connection_type: Optional[NodeType] = None
request_nonce: uint16 = uint16(0)
peer_capabilities: List[Capability] = field(default_factory=list)
# Used by the Chia Seeder.
version: str = field(default_factory=str)
protocol_version: str = field(default_factory=str)
@classmethod
def create(
cls,
local_type: NodeType,
ws: WebSocket,
server_port: int,
@ -70,70 +119,41 @@ class WSChiaConnection:
local_capabilities_for_handshake: List[Tuple[uint16, str]],
close_event: Optional[asyncio.Event] = None,
session: Optional[ClientSession] = None,
) -> None:
# Local properties
self.ws = ws
self.local_type = local_type
self.local_port = server_port
self.local_capabilities_for_handshake = local_capabilities_for_handshake
self.local_capabilities: List[Capability] = [
Capability(x[0]) for x in local_capabilities_for_handshake if x[1] == "1"
]
) -> WSChiaConnection:
# Remote properties
self.peer_host = peer_host
peername = self._get_extra_info("peername")
assert ws._writer is not None
peername = ws._writer.transport.get_extra_info("peername")
if peername is None:
raise ValueError(f"Was not able to get peername from {self.peer_host}")
raise ValueError(f"Was not able to get peername from {peer_host}")
connection_port = peername[1]
self.peer_port = connection_port
self.peer_server_port: Optional[uint16] = None
self.peer_node_id = peer_id
self.log = log
# connection properties
self.is_outbound = is_outbound
self.is_feeler = is_feeler
# ChiaConnection metrics
self.creation_time = time.time()
self.bytes_read = 0
self.bytes_written = 0
self.last_message_time: float = 0
# Messaging
self.incoming_queue = incoming_queue
self.outgoing_queue: asyncio.Queue[Message] = asyncio.Queue()
self.inbound_task: Optional[asyncio.Task[None]] = None
self.outbound_task: Optional[asyncio.Task[None]] = None
self.active: bool = False # once handshake is successful this will be changed to True
self.close_event = close_event
self.session = session
self.close_callback: Optional[ConnectionClosedCallbackProtocol] = close_callback
self.pending_requests: Dict[uint16, asyncio.Event] = {}
self.request_results: Dict[uint16, Message] = {}
self.closed = False
self.connection_type: Optional[NodeType] = None
if is_outbound:
self.request_nonce: uint16 = uint16(0)
request_nonce = uint16(0)
else:
# Different nonce to reduce chances of overlap. Each peer will increment the nonce by one for each
# request. The receiving peer (not is_outbound), will use 2^15 to 2^16 - 1
self.request_nonce = uint16(2**15)
request_nonce = uint16(2**15)
# This means that even if the other peer's boundaries for each minute are not aligned, we will not
# disconnect. Also it allows a little flexibility.
self.outbound_rate_limiter = RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent)
self.inbound_rate_limiter = RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent)
self.peer_capabilities: List[Capability] = []
# Used by the Chia Seeder.
self.version = ""
self.protocol_version = ""
return cls(
ws=ws,
local_type=local_type,
local_port=server_port,
local_capabilities_for_handshake=local_capabilities_for_handshake,
local_capabilities=[Capability(x[0]) for x in local_capabilities_for_handshake if x[1] == "1"],
peer_host=peer_host,
peer_port=peername[1],
peer_node_id=peer_id,
log=log,
close_callback=close_callback,
request_nonce=request_nonce,
outbound_rate_limiter=RateLimiter(incoming=False, percentage_of_limit=outbound_rate_limit_percent),
inbound_rate_limiter=RateLimiter(incoming=True, percentage_of_limit=inbound_rate_limit_percent),
is_outbound=is_outbound,
is_feeler=is_feeler,
incoming_queue=incoming_queue,
close_event=close_event,
session=session,
)
def _get_extra_info(self, name: str) -> Optional[Any]:
assert self.ws._writer is not None, "websocket's ._writer is None, was .prepare() called?"

View File

@ -55,7 +55,7 @@ async def add_dummy_connection(
peer_id = bytes32(der_cert.fingerprint(hashes.SHA256()))
url = f"wss://{self_hostname}:{server._port}/ws"
ws = await session.ws_connect(url, autoclose=True, autoping=True, ssl=ssl_context)
wsc = WSChiaConnection(
wsc = WSChiaConnection.create(
type,
ws,
server._port,

View File

@ -21,7 +21,7 @@ async def establish_connection(server: ChiaServer, self_hostname: str, ssl_conte
incoming_queue: asyncio.Queue = asyncio.Queue()
url = f"wss://{self_hostname}:{server._port}/ws"
ws = await session.ws_connect(url, autoclose=False, autoping=True, ssl=ssl_context)
wsc = WSChiaConnection(
wsc = WSChiaConnection.create(
NodeType.FULL_NODE,
ws,
server._port,