server: Make ChiaServer a dataclass (#13574)

* server: Make `ChiaServer` a `dataclass`

* Use existing logger

Co-authored-by: Kyle Altendorf <sda@fstab.net>

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
dustinface 2022-11-03 17:28:23 +01:00 committed by GitHub
parent c331f1cd4a
commit cf77f58bf2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 71 additions and 55 deletions

View File

@ -6,6 +6,7 @@ import ssl
import time
import traceback
from collections import Counter
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network
from pathlib import Path
from secrets import token_bytes
@ -25,6 +26,7 @@ from aiohttp import (
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from typing_extensions import final
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_requires_reply
@ -115,9 +117,45 @@ def calculate_node_id(cert_path: Path) -> bytes32:
return bytes32(der_cert.fingerprint(hashes.SHA256()))
@final
@dataclass
class ChiaServer:
def __init__(
self,
_port: int
_local_type: NodeType
_local_capabilities_for_handshake: List[Tuple[uint16, str]]
_ping_interval: int
_network_id: str
_inbound_rate_limit_percent: int
_outbound_rate_limit_percent: int
api: Any
node: Any
root_path: Path
config: Dict[str, Any]
log: logging.Logger
ssl_context: ssl.SSLContext
ssl_client_context: ssl.SSLContext
node_id: bytes32
exempt_peer_networks: List[Union[IPv4Network, IPv6Network]]
all_connections: Dict[bytes32, WSChiaConnection] = field(default_factory=dict)
on_connect: Optional[Callable] = None
incoming_messages: asyncio.Queue = field(default_factory=asyncio.Queue)
shut_down_event: asyncio.Event = field(default_factory=asyncio.Event)
introducer_peers: Optional[IntroducerPeers] = None
incoming_task: Optional[asyncio.Task] = None
gc_task: Optional[asyncio.Task] = None
webserver: Optional[WebServer] = None
connection_close_task: Optional[asyncio.Task] = None
received_message_callback: Optional[Callable] = None
api_tasks: Dict[bytes32, asyncio.Task] = field(default_factory=dict)
execute_tasks: Set[bytes32] = field(default_factory=set)
tasks_from_peer: Dict[bytes32, Set[bytes32]] = field(default_factory=dict)
banned_peers: Dict[str, float] = field(default_factory=dict)
invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS
api_exception_ban_seconds = API_EXCEPTION_BAN_SECONDS
@classmethod
def create(
cls,
port: int,
node: Any,
api: Any,
@ -132,32 +170,10 @@ class ChiaServer:
private_ca_crt_key: Tuple[Path, Path],
chia_ca_crt_key: Tuple[Path, Path],
name: str = None,
):
# Keeps track of all connections to and from this node.
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
) -> ChiaServer:
self._port = port # TCP port to identify our node
self._local_type: NodeType = local_type
self._local_capabilities_for_handshake = capabilities
self._ping_interval = ping_interval
self._network_id = network_id
self._inbound_rate_limit_percent = inbound_rate_limit_percent
self._outbound_rate_limit_percent = outbound_rate_limit_percent
self.log = logging.getLogger(name if name else __name__)
self.log.info("Service capabilities: %s", self._local_capabilities_for_handshake)
# Our unique random node id that we will send to other peers, regenerated on launch
self.api = api
self.node = node
self.root_path = root_path
self.config = config
self.on_connect: Optional[Callable] = None
self.incoming_messages: asyncio.Queue = asyncio.Queue()
self.shut_down_event = asyncio.Event()
if self._local_type is NodeType.INTRODUCER:
self.introducer_peers = IntroducerPeers()
log = logging.getLogger(name if name else __name__)
log.info("Service capabilities: %s", capabilities)
ca_private_crt_path, ca_private_key_path = private_ca_crt_key
chia_ca_crt_path, chia_ca_key_path = chia_ca_crt_key
@ -168,55 +184,55 @@ class ChiaServer:
authenticated_client_types = {NodeType.HARVESTER}
authenticated_server_types = {NodeType.HARVESTER, NodeType.FARMER, NodeType.WALLET, NodeType.DATA_LAYER}
if self._local_type in authenticated_client_types:
if local_type in authenticated_client_types:
# Authenticated clients
private_cert_path, private_key_path = private_ssl_paths(root_path, config)
self.ssl_client_context = ssl_context_for_client(
ssl_client_context = ssl_context_for_client(
ca_private_crt_path, ca_private_key_path, private_cert_path, private_key_path
)
else:
# Public clients
public_cert_path, public_key_path = public_ssl_paths(root_path, config)
self.ssl_client_context = ssl_context_for_client(
ssl_client_context = ssl_context_for_client(
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path
)
if self._local_type in authenticated_server_types:
if local_type in authenticated_server_types:
# Authenticated servers
private_cert_path, private_key_path = private_ssl_paths(root_path, config)
self.ssl_context = ssl_context_for_server(
ssl_context = ssl_context_for_server(
ca_private_crt_path,
ca_private_key_path,
private_cert_path,
private_key_path,
log=self.log,
log=log,
)
else:
# Public servers
public_cert_path, public_key_path = public_ssl_paths(root_path, config)
self.ssl_context = ssl_context_for_server(
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path, log=self.log
ssl_context = ssl_context_for_server(
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path, log=log
)
# If node has public cert use that one for id, if not use private.
self.node_id = calculate_node_id(private_cert_path if public_cert_path is None else public_cert_path)
self.incoming_task: Optional[asyncio.Task] = None
self.gc_task: Optional[asyncio.Task] = None
self.webserver: Optional[WebServer] = None
self.connection_close_task: Optional[asyncio.Task] = None
self.received_message_callback: Optional[Callable] = None
self.api_tasks: Dict[bytes32, asyncio.Task] = {}
self.execute_tasks: Set[bytes32] = set()
self.tasks_from_peer: Dict[bytes32, Set[bytes32]] = {}
self.banned_peers: Dict[str, float] = {}
self.invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS
self.api_exception_ban_seconds = API_EXCEPTION_BAN_SECONDS
self.exempt_peer_networks: List[Union[IPv4Network, IPv6Network]] = [
ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])
]
return cls(
_port=port,
_local_type=local_type,
_local_capabilities_for_handshake=capabilities,
_ping_interval=ping_interval,
_network_id=network_id,
_inbound_rate_limit_percent=inbound_rate_limit_percent,
_outbound_rate_limit_percent=outbound_rate_limit_percent,
log=log,
api=api,
node=node,
root_path=root_path,
config=config,
ssl_context=ssl_context,
ssl_client_context=ssl_client_context,
node_id=calculate_node_id(private_cert_path if public_cert_path is None else public_cert_path),
exempt_peer_networks=[ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])],
introducer_peers=IntroducerPeers() if local_type is NodeType.INTRODUCER else None,
)
def set_received_message_callback(self, callback: Callable):
self.received_message_callback = callback

View File

@ -96,7 +96,7 @@ class Service(Generic[_T_RpcServiceProtocol]):
capabilities_to_use = override_capabilities
assert inbound_rlp and outbound_rlp
self._server = ChiaServer(
self._server = ChiaServer.create(
advertised_port,
node,
peer_api,