mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-21 08:31:52 +03:00
server: Make ChiaServer
a dataclass
(#13574)
* server: Make `ChiaServer` a `dataclass` * Use existing logger Co-authored-by: Kyle Altendorf <sda@fstab.net> Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
parent
c331f1cd4a
commit
cf77f58bf2
@ -6,6 +6,7 @@ import ssl
|
||||
import time
|
||||
import traceback
|
||||
from collections import Counter
|
||||
from dataclasses import dataclass, field
|
||||
from ipaddress import IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network
|
||||
from pathlib import Path
|
||||
from secrets import token_bytes
|
||||
@ -25,6 +26,7 @@ from aiohttp import (
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from typing_extensions import final
|
||||
|
||||
from chia.protocols.protocol_message_types import ProtocolMessageTypes
|
||||
from chia.protocols.protocol_state_machine import message_requires_reply
|
||||
@ -115,9 +117,45 @@ def calculate_node_id(cert_path: Path) -> bytes32:
|
||||
return bytes32(der_cert.fingerprint(hashes.SHA256()))
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class ChiaServer:
|
||||
def __init__(
|
||||
self,
|
||||
_port: int
|
||||
_local_type: NodeType
|
||||
_local_capabilities_for_handshake: List[Tuple[uint16, str]]
|
||||
_ping_interval: int
|
||||
_network_id: str
|
||||
_inbound_rate_limit_percent: int
|
||||
_outbound_rate_limit_percent: int
|
||||
api: Any
|
||||
node: Any
|
||||
root_path: Path
|
||||
config: Dict[str, Any]
|
||||
log: logging.Logger
|
||||
ssl_context: ssl.SSLContext
|
||||
ssl_client_context: ssl.SSLContext
|
||||
node_id: bytes32
|
||||
exempt_peer_networks: List[Union[IPv4Network, IPv6Network]]
|
||||
all_connections: Dict[bytes32, WSChiaConnection] = field(default_factory=dict)
|
||||
on_connect: Optional[Callable] = None
|
||||
incoming_messages: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
shut_down_event: asyncio.Event = field(default_factory=asyncio.Event)
|
||||
introducer_peers: Optional[IntroducerPeers] = None
|
||||
incoming_task: Optional[asyncio.Task] = None
|
||||
gc_task: Optional[asyncio.Task] = None
|
||||
webserver: Optional[WebServer] = None
|
||||
connection_close_task: Optional[asyncio.Task] = None
|
||||
received_message_callback: Optional[Callable] = None
|
||||
api_tasks: Dict[bytes32, asyncio.Task] = field(default_factory=dict)
|
||||
execute_tasks: Set[bytes32] = field(default_factory=set)
|
||||
tasks_from_peer: Dict[bytes32, Set[bytes32]] = field(default_factory=dict)
|
||||
banned_peers: Dict[str, float] = field(default_factory=dict)
|
||||
invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS
|
||||
api_exception_ban_seconds = API_EXCEPTION_BAN_SECONDS
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
port: int,
|
||||
node: Any,
|
||||
api: Any,
|
||||
@ -132,32 +170,10 @@ class ChiaServer:
|
||||
private_ca_crt_key: Tuple[Path, Path],
|
||||
chia_ca_crt_key: Tuple[Path, Path],
|
||||
name: str = None,
|
||||
):
|
||||
# Keeps track of all connections to and from this node.
|
||||
self.all_connections: Dict[bytes32, WSChiaConnection] = {}
|
||||
) -> ChiaServer:
|
||||
|
||||
self._port = port # TCP port to identify our node
|
||||
self._local_type: NodeType = local_type
|
||||
self._local_capabilities_for_handshake = capabilities
|
||||
self._ping_interval = ping_interval
|
||||
self._network_id = network_id
|
||||
self._inbound_rate_limit_percent = inbound_rate_limit_percent
|
||||
self._outbound_rate_limit_percent = outbound_rate_limit_percent
|
||||
|
||||
self.log = logging.getLogger(name if name else __name__)
|
||||
self.log.info("Service capabilities: %s", self._local_capabilities_for_handshake)
|
||||
|
||||
# Our unique random node id that we will send to other peers, regenerated on launch
|
||||
self.api = api
|
||||
self.node = node
|
||||
self.root_path = root_path
|
||||
self.config = config
|
||||
self.on_connect: Optional[Callable] = None
|
||||
self.incoming_messages: asyncio.Queue = asyncio.Queue()
|
||||
self.shut_down_event = asyncio.Event()
|
||||
|
||||
if self._local_type is NodeType.INTRODUCER:
|
||||
self.introducer_peers = IntroducerPeers()
|
||||
log = logging.getLogger(name if name else __name__)
|
||||
log.info("Service capabilities: %s", capabilities)
|
||||
|
||||
ca_private_crt_path, ca_private_key_path = private_ca_crt_key
|
||||
chia_ca_crt_path, chia_ca_key_path = chia_ca_crt_key
|
||||
@ -168,55 +184,55 @@ class ChiaServer:
|
||||
authenticated_client_types = {NodeType.HARVESTER}
|
||||
authenticated_server_types = {NodeType.HARVESTER, NodeType.FARMER, NodeType.WALLET, NodeType.DATA_LAYER}
|
||||
|
||||
if self._local_type in authenticated_client_types:
|
||||
if local_type in authenticated_client_types:
|
||||
# Authenticated clients
|
||||
private_cert_path, private_key_path = private_ssl_paths(root_path, config)
|
||||
self.ssl_client_context = ssl_context_for_client(
|
||||
ssl_client_context = ssl_context_for_client(
|
||||
ca_private_crt_path, ca_private_key_path, private_cert_path, private_key_path
|
||||
)
|
||||
else:
|
||||
# Public clients
|
||||
public_cert_path, public_key_path = public_ssl_paths(root_path, config)
|
||||
self.ssl_client_context = ssl_context_for_client(
|
||||
ssl_client_context = ssl_context_for_client(
|
||||
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path
|
||||
)
|
||||
|
||||
if self._local_type in authenticated_server_types:
|
||||
if local_type in authenticated_server_types:
|
||||
# Authenticated servers
|
||||
private_cert_path, private_key_path = private_ssl_paths(root_path, config)
|
||||
self.ssl_context = ssl_context_for_server(
|
||||
ssl_context = ssl_context_for_server(
|
||||
ca_private_crt_path,
|
||||
ca_private_key_path,
|
||||
private_cert_path,
|
||||
private_key_path,
|
||||
log=self.log,
|
||||
log=log,
|
||||
)
|
||||
else:
|
||||
# Public servers
|
||||
public_cert_path, public_key_path = public_ssl_paths(root_path, config)
|
||||
self.ssl_context = ssl_context_for_server(
|
||||
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path, log=self.log
|
||||
ssl_context = ssl_context_for_server(
|
||||
chia_ca_crt_path, chia_ca_key_path, public_cert_path, public_key_path, log=log
|
||||
)
|
||||
|
||||
# If node has public cert use that one for id, if not use private.
|
||||
self.node_id = calculate_node_id(private_cert_path if public_cert_path is None else public_cert_path)
|
||||
|
||||
self.incoming_task: Optional[asyncio.Task] = None
|
||||
self.gc_task: Optional[asyncio.Task] = None
|
||||
self.webserver: Optional[WebServer] = None
|
||||
|
||||
self.connection_close_task: Optional[asyncio.Task] = None
|
||||
self.received_message_callback: Optional[Callable] = None
|
||||
self.api_tasks: Dict[bytes32, asyncio.Task] = {}
|
||||
self.execute_tasks: Set[bytes32] = set()
|
||||
|
||||
self.tasks_from_peer: Dict[bytes32, Set[bytes32]] = {}
|
||||
self.banned_peers: Dict[str, float] = {}
|
||||
self.invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS
|
||||
self.api_exception_ban_seconds = API_EXCEPTION_BAN_SECONDS
|
||||
self.exempt_peer_networks: List[Union[IPv4Network, IPv6Network]] = [
|
||||
ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])
|
||||
]
|
||||
return cls(
|
||||
_port=port,
|
||||
_local_type=local_type,
|
||||
_local_capabilities_for_handshake=capabilities,
|
||||
_ping_interval=ping_interval,
|
||||
_network_id=network_id,
|
||||
_inbound_rate_limit_percent=inbound_rate_limit_percent,
|
||||
_outbound_rate_limit_percent=outbound_rate_limit_percent,
|
||||
log=log,
|
||||
api=api,
|
||||
node=node,
|
||||
root_path=root_path,
|
||||
config=config,
|
||||
ssl_context=ssl_context,
|
||||
ssl_client_context=ssl_client_context,
|
||||
node_id=calculate_node_id(private_cert_path if public_cert_path is None else public_cert_path),
|
||||
exempt_peer_networks=[ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])],
|
||||
introducer_peers=IntroducerPeers() if local_type is NodeType.INTRODUCER else None,
|
||||
)
|
||||
|
||||
def set_received_message_callback(self, callback: Callable):
|
||||
self.received_message_callback = callback
|
||||
|
@ -96,7 +96,7 @@ class Service(Generic[_T_RpcServiceProtocol]):
|
||||
capabilities_to_use = override_capabilities
|
||||
|
||||
assert inbound_rlp and outbound_rlp
|
||||
self._server = ChiaServer(
|
||||
self._server = ChiaServer.create(
|
||||
advertised_port,
|
||||
node,
|
||||
peer_api,
|
||||
|
Loading…
Reference in New Issue
Block a user