from __future__ import annotations

import asyncio
import logging
import ssl
import time
import traceback
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Network, ip_network
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Union, cast

from aiohttp import (
    ClientResponseError,
    ClientSession,
    ClientTimeout,
    ServerDisconnectedError,
    WSCloseCode,
    client_exceptions,
    web,
)
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from typing_extensions import final

from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_requires_reply
from chia.protocols.protocol_timing import INVALID_PROTOCOL_BAN_SECONDS
from chia.protocols.shared_protocol import protocol_version
from chia.server.api_protocol import ApiProtocol
from chia.server.introducer_peers import IntroducerPeers
from chia.server.outbound_message import Message, NodeType
from chia.server.ssl_context import private_ssl_paths, public_ssl_paths
from chia.server.ws_connection import ConnectionCallback, WSChiaConnection
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint16
from chia.util.network import WebServer, is_in_network, is_localhost, is_trusted_peer
from chia.util.ssl_check import verify_ssl_certs_and_keys
from chia.util.streamable import Streamable

# Upper bound passed as ``max_msg_size`` to both inbound and outbound websockets.
max_message_size = 50 * 1024 * 1024  # 50MB


def ssl_context_for_server(
    ca_cert: Path,
    ca_key: Path,
    cert_path: Path,
    key_path: Path,
    *,
    check_permissions: bool = True,
    log: Optional[logging.Logger] = None,
) -> ssl.SSLContext:
    """
    Build the TLS context used by the listening (server) side.

    The context loads ``ca_cert`` as its CA file, presents ``cert_path``/``key_path``
    as its own certificate, requires a peer certificate (``CERT_REQUIRED``), disables
    hostname checking, enforces TLS 1.2 as the minimum version and restricts the
    cipher list to the ECDHE suites listed below.

    When ``check_permissions`` is true, the certificate and key files are first passed
    to ``verify_ssl_certs_and_keys`` together with ``log``.
    """
    if check_permissions:
        verify_ssl_certs_and_keys([ca_cert, cert_path], [ca_key, key_path], log)

    ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.CLIENT_AUTH, cafile=str(ca_cert))
    ssl_context.check_hostname = False
    ssl_context.minimum_version = ssl.TLSVersion.TLSv1_2
    ssl_context.set_ciphers(
        "ECDHE-ECDSA-AES256-GCM-SHA384:"
        "ECDHE-RSA-AES256-GCM-SHA384:"
        "ECDHE-ECDSA-CHACHA20-POLY1305:"
        "ECDHE-RSA-CHACHA20-POLY1305:"
        "ECDHE-ECDSA-AES128-GCM-SHA256:"
        "ECDHE-RSA-AES128-GCM-SHA256:"
        "ECDHE-ECDSA-AES256-SHA384:"
        "ECDHE-RSA-AES256-SHA384:"
        "ECDHE-ECDSA-AES128-SHA256:"
        "ECDHE-RSA-AES128-SHA256"
    )
    ssl_context.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))
    # The context starts out "unverified"; turn peer certificate verification back on.
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    return ssl_context


def ssl_context_for_root(
    ca_cert_file: str, *, check_permissions: bool = True, log: Optional[logging.Logger] = None
) -> ssl.SSLContext:
    """
    Build a default ``SERVER_AUTH`` TLS context that loads ``ca_cert_file`` as its CA file.

    When ``check_permissions`` is true, the CA file is first passed to
    ``verify_ssl_certs_and_keys`` together with ``log``.
    """
    if check_permissions:
        verify_ssl_certs_and_keys([Path(ca_cert_file)], [], log)

    ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=ca_cert_file)
    return ssl_context


def ssl_context_for_client(
    ca_cert: Path,
    ca_key: Path,
    cert_path: Path,
    key_path: Path,
    *,
    check_permissions: bool = True,
    log: Optional[logging.Logger] = None,
) -> ssl.SSLContext:
    """
    Build the TLS context used for outbound (client) connections.

    The context loads ``ca_cert`` as its CA file, presents ``cert_path``/``key_path``
    as the client certificate, requires a peer certificate (``CERT_REQUIRED``) and
    disables hostname checking.

    When ``check_permissions`` is true, the certificate and key files are first passed
    to ``verify_ssl_certs_and_keys`` together with ``log``.
    """
    if check_permissions:
        verify_ssl_certs_and_keys([ca_cert, cert_path], [ca_key, key_path], log)

    ssl_context = ssl._create_unverified_context(purpose=ssl.Purpose.SERVER_AUTH, cafile=str(ca_cert))
    ssl_context.check_hostname = False
    ssl_context.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))
    # The context starts out "unverified"; turn peer certificate verification back on.
    ssl_context.verify_mode = ssl.CERT_REQUIRED
    return ssl_context


def calculate_node_id(cert_path: Path) -> bytes32:
    """
    Return the SHA-256 fingerprint of the PEM certificate at ``cert_path`` as a ``bytes32``.

    The certificate is re-encoded to DER and reloaded before fingerprinting. The
    connection code in this module computes peer ids the same way, from the DER
    certificate returned by ``getpeercert(True)``.
    """
    pem_cert = x509.load_pem_x509_certificate(cert_path.read_bytes(), default_backend())
    der_cert_bytes = pem_cert.public_bytes(encoding=serialization.Encoding.DER)
    der_cert = x509.load_der_x509_certificate(der_cert_bytes, default_backend())
    return bytes32(der_cert.fingerprint(hashes.SHA256()))


@final
@dataclass
class ChiaServer:
    """
    Websocket connection hub for a Chia service.

    Accepts inbound websocket connections on ``/ws`` (when a port is configured), opens
    outbound connections via ``start_client``, performs the handshake on each, and
    tracks live connections in ``all_connections`` keyed by the peer's node id. It also
    keeps a per-host ban list and a background task that garbage collects stale
    connections and expired bans.
    """

    # Listening port; None means this server does not listen for inbound connections.
    _port: Optional[int]
    _local_type: NodeType
    _local_capabilities_for_handshake: List[Tuple[uint16, str]]
    _ping_interval: int
    _network_id: str
    _inbound_rate_limit_percent: int
    _outbound_rate_limit_percent: int
    api: ApiProtocol
    node: Any
    root_path: Path
    config: Dict[str, Any]
    log: logging.Logger
    # TLS context for the listening side / for outbound connections respectively.
    ssl_context: ssl.SSLContext
    ssl_client_context: ssl.SSLContext
    # SHA-256 fingerprint of our own certificate (see calculate_node_id).
    node_id: bytes32
    # Peers in these networks bypass the inbound connection limit.
    exempt_peer_networks: List[Union[IPv4Network, IPv6Network]]
    # Live connections keyed by peer node id.
    all_connections: Dict[bytes32, WSChiaConnection] = field(default_factory=dict)
    on_connect: Optional[ConnectionCallback] = None
    shut_down_event: asyncio.Event = field(default_factory=asyncio.Event)
    # Only populated when the local node type is INTRODUCER.
    introducer_peers: Optional[IntroducerPeers] = None
    gc_task: Optional[asyncio.Task[None]] = None
    webserver: Optional[WebServer] = None
    connection_close_task: Optional[asyncio.Task[None]] = None
    received_message_callback: Optional[ConnectionCallback] = None
    # Maps peer host -> unix timestamp until which the host is banned.
    banned_peers: Dict[str, float] = field(default_factory=dict)
    # Unannotated on purpose: a plain class attribute, not a dataclass field.
    invalid_protocol_ban_seconds = INVALID_PROTOCOL_BAN_SECONDS

    @classmethod
    def create(
        cls,
        port: Optional[int],
        node: Any,
        api: ApiProtocol,
        local_type: NodeType,
        ping_interval: int,
        network_id: str,
        inbound_rate_limit_percent: int,
        outbound_rate_limit_percent: int,
        capabilities: List[Tuple[uint16, str]],
        root_path: Path,
        config: Dict[str, Any],
        private_ca_crt_key: Tuple[Path, Path],
        chia_ca_crt_key: Tuple[Path, Path],
        name: str = __name__,
    ) -> ChiaServer:
        """
        Construct a ``ChiaServer`` with TLS contexts chosen according to ``local_type``.

        Harvesters use the private CA and private certificates for their client
        context; every other type uses the chia CA and public certificates. Harvester,
        farmer, wallet and data layer services use the private CA for their server
        context; every other type uses the chia CA. The node id is derived from the
        public certificate when one was loaded, otherwise from the private one.
        """
        log = logging.getLogger(name)
        log.info("Service capabilities: %s", capabilities)

        ca_private_crt_path, ca_private_key_path = private_ca_crt_key
        chia_ca_crt_path, chia_ca_key_path = chia_ca_crt_key

        private_cert_path, private_key_path = None, None
        public_cert_path, public_key_path = None, None

        authenticated_client_types = {NodeType.HARVESTER}
        authenticated_server_types = {NodeType.HARVESTER, NodeType.FARMER, NodeType.WALLET, NodeType.DATA_LAYER}

        if local_type in authenticated_client_types:
            # Authenticated clients
            private_cert_path, private_key_path = private_ssl_paths(root_path, config)
            ssl_client_context = ssl_context_for_client(
                ca_cert=ca_private_crt_path,
                ca_key=ca_private_key_path,
                cert_path=private_cert_path,
                key_path=private_key_path,
            )
        else:
            # Public clients
            public_cert_path, public_key_path = public_ssl_paths(root_path, config)
            ssl_client_context = ssl_context_for_client(
                ca_cert=chia_ca_crt_path,
                ca_key=chia_ca_key_path,
                cert_path=public_cert_path,
                key_path=public_key_path,
            )

        if local_type in authenticated_server_types:
            # Authenticated servers
            private_cert_path, private_key_path = private_ssl_paths(root_path, config)
            ssl_context = ssl_context_for_server(
                ca_cert=ca_private_crt_path,
                ca_key=ca_private_key_path,
                cert_path=private_cert_path,
                key_path=private_key_path,
                log=log,
            )
        else:
            # Public servers
            public_cert_path, public_key_path = public_ssl_paths(root_path, config)
            ssl_context = ssl_context_for_server(
                ca_cert=chia_ca_crt_path,
                ca_key=chia_ca_key_path,
                cert_path=public_cert_path,
                key_path=public_key_path,
                log=log,
            )

        # Prefer the public certificate for the node id whenever one was loaded above.
        node_id_cert_path = private_cert_path if public_cert_path is None else public_cert_path
        assert node_id_cert_path is not None

        return cls(
            _port=port,
            _local_type=local_type,
            _local_capabilities_for_handshake=capabilities,
            _ping_interval=ping_interval,
            _network_id=network_id,
            _inbound_rate_limit_percent=inbound_rate_limit_percent,
            _outbound_rate_limit_percent=outbound_rate_limit_percent,
            log=log,
            api=api,
            node=node,
            root_path=root_path,
            config=config,
            ssl_context=ssl_context,
            ssl_client_context=ssl_client_context,
            node_id=calculate_node_id(node_id_cert_path),
            exempt_peer_networks=[ip_network(net, strict=False) for net in config.get("exempt_peer_networks", [])],
            introducer_peers=IntroducerPeers() if local_type is NodeType.INTRODUCER else None,
        )

    def set_received_message_callback(self, callback: ConnectionCallback) -> None:
        """Store the callback handed to every connection created after this call."""
        self.received_message_callback = callback

    async def garbage_collect_connections_task(self) -> None:
        """
        Periodically checks for connections with no activity (have not sent us any data), and removes them,
        to allow room for other peers.

        Runs forever. The period is 600 seconds, or 2 seconds when the node exposes a
        ``crawl`` attribute. Closed connections are dropped from ``all_connections``.
        Full node peers of a full node or wallet are closed after 1800 seconds without
        a message, or 5 seconds after creation when crawling. Expired entries in
        ``banned_peers`` are removed on every pass.
        """
        is_crawler = getattr(self.node, "crawl", None)
        while True:
            await asyncio.sleep(600 if is_crawler is None else 2)
            to_remove: List[WSChiaConnection] = []
            for connection in self.all_connections.values():
                if connection.closed:
                    to_remove.append(connection)
                elif (
                    self._local_type == NodeType.FULL_NODE or self._local_type == NodeType.WALLET
                ) and connection.connection_type == NodeType.FULL_NODE:
                    if is_crawler is not None:
                        if time.time() - connection.creation_time > 5:
                            to_remove.append(connection)
                    else:
                        if time.time() - connection.last_message_time > 1800:
                            to_remove.append(connection)
            for connection in to_remove:
                self.log.debug(f"Garbage collecting connection {connection.peer_info.host} due to inactivity")
                if connection.closed:
                    # Earlier iterations may have awaited close(); meanwhile the entry can have been
                    # removed, or replaced by a newer connection with the same node id
                    # (connection_added overwrites by key). Only drop it if it is still this connection.
                    if self.all_connections.get(connection.peer_node_id) is connection:
                        self.all_connections.pop(connection.peer_node_id)
                else:
                    await connection.close()

            # Also garbage collect banned_peers dict
            to_remove_ban = [
                peer_ip for peer_ip, ban_until_time in self.banned_peers.items() if time.time() > ban_until_time
            ]
            for peer_ip in to_remove_ban:
                del self.banned_peers[peer_ip]

    async def start(
        self,
        prefer_ipv6: bool,
        on_connect: Optional[ConnectionCallback] = None,
    ) -> None:
        """
        Start the garbage collection task and, if a port is configured, the listener.

        ``on_connect`` is stored and the web server is created only when ``_port`` is
        not None. After the web server is up, ``_port`` is updated to the port it
        actually listens on.

        Raises:
            RuntimeError: if the web server has already been started.
        """
        if self.webserver is not None:
            raise RuntimeError("ChiaServer already started")
        if self.gc_task is None:
            self.gc_task = asyncio.create_task(self.garbage_collect_connections_task())

        if self._port is not None:
            self.on_connect = on_connect
            self.webserver = await WebServer.create(
                hostname="",
                port=self.get_port(),
                routes=[web.get("/ws", self.incoming_connection)],
                ssl_context=self.ssl_context,
                prefer_ipv6=prefer_ipv6,
                logger=self.log,
            )
            self._port = int(self.webserver.listen_port)
            self.log.info(f"Started listening on port: {self._port}")

    async def incoming_connection(self, request: web.Request) -> web.StreamResponse:
        """
        Handle an inbound websocket request on ``/ws``.

        Rejects the request when this node is a crawler, when the remote address is
        unknown or when the remote host is currently banned. Otherwise upgrades to a
        websocket, derives the peer id from the peer's TLS certificate, performs the
        handshake and either registers the connection or closes it if the inbound
        limit is reached (peers in ``exempt_peer_networks`` bypass the limit). The
        coroutine then waits until the connection is closed before returning.

        Raises:
            web.HTTPForbidden: for crawlers and banned peers.
            web.HTTPInternalServerError: when the remote address or the TLS object is missing.
        """
        if getattr(self.node, "crawl", None) is not None:
            raise web.HTTPForbidden(reason="incoming connections not allowed for crawler")
        if request.remote is None:
            raise web.HTTPInternalServerError(reason=f"remote is None for request {request}")
        if request.remote in self.banned_peers and time.time() < self.banned_peers[request.remote]:
            reason = f"Peer {request.remote} is banned, refusing connection"
            self.log.warning(reason)
            raise web.HTTPForbidden(reason=reason)
        ws = web.WebSocketResponse(max_msg_size=max_message_size)
        await ws.prepare(request)
        ssl_object = request.get_extra_info("ssl_object")
        if ssl_object is None:
            reason = f"ssl_object is None for request {request}"
            self.log.warning(reason)
            raise web.HTTPInternalServerError(reason=reason)
        cert_bytes = ssl_object.getpeercert(True)
        der_cert = x509.load_der_x509_certificate(cert_bytes)
        peer_id = bytes32(der_cert.fingerprint(hashes.SHA256()))
        if peer_id == self.node_id:
            # The peer presented our own certificate: do not register a connection to ourselves.
            return ws
        connection: Optional[WSChiaConnection] = None
        try:
            connection = WSChiaConnection.create(
                local_type=self._local_type,
                ws=ws,
                api=self.api,
                server_port=self.get_port(),
                log=self.log,
                is_outbound=False,
                received_message_callback=self.received_message_callback,
                close_callback=self.connection_closed,
                peer_id=peer_id,
                inbound_rate_limit_percent=self._inbound_rate_limit_percent,
                outbound_rate_limit_percent=self._outbound_rate_limit_percent,
                local_capabilities_for_handshake=self._local_capabilities_for_handshake,
            )
            await connection.perform_handshake(self._network_id, protocol_version, self.get_port(), self._local_type)
            assert connection.connection_type is not None, "handshake failed to set connection type, still None"

            # Limit inbound connections to config's specifications.
            if not self.accept_inbound_connections(connection.connection_type) and not is_in_network(
                connection.peer_info.host, self.exempt_peer_networks
            ):
                self.log.info(
                    f"Not accepting inbound connection: {connection.get_peer_logging()}.Inbound limit reached."
                )
                await connection.close()
            else:
                await self.connection_added(connection, self.on_connect)
                if self.introducer_peers is not None and connection.connection_type is NodeType.FULL_NODE:
                    self.introducer_peers.add(connection.get_peer_info())
        except ProtocolError as e:
            if connection is not None:
                await connection.close(self.invalid_protocol_ban_seconds, WSCloseCode.PROTOCOL_ERROR, e.code)
            if e.code == Err.INVALID_HANDSHAKE:
                self.log.warning("Invalid handshake with peer. Maybe the peer is running old software.")
            elif e.code == Err.INCOMPATIBLE_NETWORK_ID:
                self.log.warning("Incompatible network ID. Maybe the peer is on another network")
            else:
                error_stack = traceback.format_exc()
                self.log.error(f"Exception {e}, exception Stack: {error_stack}")
        except ValueError as e:
            if connection is not None:
                await connection.close(self.invalid_protocol_ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
            self.log.warning(f"{e} - closing connection")
        except Exception as e:
            # Unexpected errors close the connection without passing a ban time.
            if connection is not None:
                await connection.close(ws_close_code=WSCloseCode.PROTOCOL_ERROR, error=Err.UNKNOWN)
            error_stack = traceback.format_exc()
            self.log.error(f"Exception {e}, exception Stack: {error_stack}")

        if connection is not None:
            await connection.wait_until_closed()

        return ws

    async def connection_added(
        self, connection: WSChiaConnection, on_connect: Optional[ConnectionCallback] = None
    ) -> None:
        """
        Register ``connection`` in ``all_connections`` and invoke ``on_connect``.

        A connection that is already closed is ignored. An existing connection with the
        same peer node id is closed and replaced. ``on_connect`` is awaited only when
        the connection has a connection type; otherwise an error is logged.
        """
        # If we already had a connection to this peer_id, close the old one. This is secure because peer_ids are based
        # on TLS public keys
        if connection.closed:
            self.log.debug(f"ignoring unexpected request to add closed connection {connection.peer_info.host} ")
            return

        if connection.peer_node_id in self.all_connections:
            con = self.all_connections[connection.peer_node_id]
            await con.close()
        self.all_connections[connection.peer_node_id] = connection
        if connection.connection_type is not None:
            if on_connect is not None:
                await on_connect(connection)
        else:
            self.log.error(f"Invalid connection type for connection {connection}")

    def is_duplicate_or_self_connection(self, target_node: PeerInfo) -> bool:
        """
        Return True when ``target_node`` is this server or a peer we are already connected to.

        "Self" means a localhost host on our own port. A duplicate is an existing
        connection with the same host and the same peer server port.
        """
        if is_localhost(target_node.host) and target_node.port == self._port:
            # Don't connect to self
            self.log.debug(f"Not connecting to {target_node}")
            return True
        for connection in self.all_connections.values():
            if connection.peer_info.host == target_node.host and connection.peer_server_port == target_node.port:
                self.log.debug(f"Not connecting to {target_node}, duplicate connection")
                return True
        return False

    async def start_client(
        self,
        target_node: PeerInfo,
        on_connect: Optional[ConnectionCallback] = None,
        is_feeler: bool = False,
    ) -> bool:
        """
        Tries to connect to the target node, adding one connection into the pipeline, if successful.
        ``on_connect``, when given, is passed to ``connection_added`` for the new connection.

        Returns True when the connection was established, handshaken and registered;
        False for duplicate/self/banned targets and for any connection or protocol
        failure. When ``is_feeler`` is true the connection is closed again right after
        it was registered. The ``ClientSession`` is closed here unless the connection
        took ownership of it.
        """
        if self.is_duplicate_or_self_connection(target_node):
            self.log.warning(f"cannot connect to {target_node.host}, duplicate/self connection")
            return False

        if target_node.host in self.banned_peers and time.time() < self.banned_peers[target_node.host]:
            self.log.warning(f"Peer {target_node.host} is still banned, not connecting to it")
            return False

        session = None
        connection: Optional[WSChiaConnection] = None
        try:
            # Crawler/DNS introducer usually uses a lower timeout than the default
            timeout_value = float(self.config.get("peer_connect_timeout", 30))
            timeout = ClientTimeout(total=timeout_value)
            session = ClientSession(timeout=timeout)
            # IPv6 literals must be bracketed inside a URL.
            ip = f"[{target_node.ip}]" if target_node.ip.is_v6 else f"{target_node.ip}"
            url = f"wss://{ip}:{target_node.port}/ws"
            self.log.debug(f"Connecting: {url}, Peer info: {target_node}")
            try:
                ws = await session.ws_connect(
                    url,
                    autoclose=True,
                    autoping=True,
                    heartbeat=60,
                    ssl=self.ssl_client_context,
                    max_msg_size=max_message_size,
                )
            except ServerDisconnectedError:
                self.log.debug(f"Server disconnected error connecting to {url}. Perhaps we are banned by the peer.")
                return False
            except ClientResponseError as e:
                self.log.warning(f"Connection failed to {url}. Error: {e}")
                return False
            except asyncio.TimeoutError:
                self.log.debug(f"Timeout error connecting to {url}")
                return False
            if ws is None:
                self.log.warning(f"Connection failed to {url}. ws was None")
                return False

            ssl_object = ws.get_extra_info("ssl_object")
            if ssl_object is None:
                raise ValueError(f"ssl_object is None for {ws}")
            cert_bytes = ssl_object.getpeercert(True)
            der_cert = x509.load_der_x509_certificate(cert_bytes, default_backend())
            peer_id = bytes32(der_cert.fingerprint(hashes.SHA256()))
            if peer_id == self.node_id:
                self.log.info(f"Connected to a node with the same peer ID, disconnecting: {target_node} {peer_id}")
                return False

            # Servers without a listening port advertise port 0 in the handshake.
            server_port: uint16
            try:
                server_port = self.get_port()
            except ValueError:
                server_port = uint16(0)

            connection = WSChiaConnection.create(
                local_type=self._local_type,
                ws=ws,
                api=self.api,
                server_port=server_port,
                log=self.log,
                is_outbound=True,
                received_message_callback=self.received_message_callback,
                close_callback=self.connection_closed,
                peer_id=peer_id,
                inbound_rate_limit_percent=self._inbound_rate_limit_percent,
                outbound_rate_limit_percent=self._outbound_rate_limit_percent,
                local_capabilities_for_handshake=self._local_capabilities_for_handshake,
                session=session,
            )
            await connection.perform_handshake(self._network_id, protocol_version, server_port, self._local_type)
            await self.connection_added(connection, on_connect)
            # the session has been adopted by the connection, don't close it at
            # the end of the function
            session = None
            connection_type_str = ""
            if connection.connection_type is not None:
                connection_type_str = connection.connection_type.name.lower()
            self.log.info(f"Connected with {connection_type_str} {target_node}")
            if is_feeler:
                # NOTE(review): the returned task is not retained. asyncio only keeps weak references
                # to tasks, so consider holding a reference until the close has completed.
                asyncio.create_task(connection.close())
            return True
        except client_exceptions.ClientConnectorError as e:
            self.log.info(f"{e}")
        except ProtocolError as e:
            if connection is not None:
                await connection.close(self.invalid_protocol_ban_seconds, WSCloseCode.PROTOCOL_ERROR, e.code)
            if e.code == Err.INVALID_HANDSHAKE:
                self.log.warning(f"Invalid handshake with peer {target_node}. Maybe the peer is running old software.")
            elif e.code == Err.INCOMPATIBLE_NETWORK_ID:
                self.log.warning("Incompatible network ID. Maybe the peer is on another network")
            elif e.code == Err.SELF_CONNECTION:
                pass
            else:
                error_stack = traceback.format_exc()
                self.log.error(f"Exception {e}, exception Stack: {error_stack}")
        except Exception as e:
            if connection is not None:
                await connection.close(self.invalid_protocol_ban_seconds, WSCloseCode.PROTOCOL_ERROR, Err.UNKNOWN)
            error_stack = traceback.format_exc()
            self.log.error(f"Exception {e}, exception Stack: {error_stack}")
        finally:
            # Still set here means the connection never adopted the session (failure or early return).
            if session is not None:
                await session.close()

        return False

    async def connection_closed(
        self, connection: WSChiaConnection, ban_time: int, closed_connection: bool = False
    ) -> None:
        """
        Close callback handed to every connection: apply bans and unregister the connection.

        A positive ``ban_time`` (seconds) bans the peer's host; an existing ban is only
        ever extended, never shortened, and localhost is never banned. The connection
        is removed from ``all_connections`` only when the registered entry is this very
        object. Unless ``closed_connection`` is true, the connection's tasks are
        cancelled and the node's optional ``on_disconnect`` hook is awaited.
        """
        # closed_connection is true if the callback is being called with a connection that was previously closed
        # in this case we still want to do the banning logic and remove the connection from the list
        # but the other cleanup should already have been done so we skip that

        if is_localhost(connection.peer_info.host) and ban_time != 0:
            self.log.warning(f"Trying to ban localhost for {ban_time}, but will not ban")
            ban_time = 0
        if ban_time > 0:
            ban_until: float = time.time() + ban_time
            self.log.warning(f"Banning {connection.peer_info.host} for {ban_time} seconds")
            if connection.peer_info.host in self.banned_peers:
                if ban_until > self.banned_peers[connection.peer_info.host]:
                    self.banned_peers[connection.peer_info.host] = ban_until
            else:
                self.banned_peers[connection.peer_info.host] = ban_until

        present_connection = self.all_connections.get(connection.peer_node_id)
        if present_connection is connection:
            self.all_connections.pop(connection.peer_node_id)

        if not closed_connection:
            self.log.info(f"Connection closed: {connection.peer_info.host}, node id: {connection.peer_node_id}")

            if connection.connection_type is None:
                # This means the handshake was never finished with this peer
                self.log.debug(
                    f"Invalid connection type for connection {connection.peer_info.host},"
                    f" while closing. Handshake never finished."
                )
            connection.cancel_tasks()
            on_disconnect = getattr(self.node, "on_disconnect", None)
            if on_disconnect is not None:
                await on_disconnect(connection)

    async def validate_broadcast_message_type(self, messages: List[Message], node_type: NodeType) -> None:
        """
        Refuse to broadcast any message whose type requires a reply.

        On the first such message, every connection of ``node_type`` is closed with a
        ban and ``ProtocolError(Err.INTERNAL_PROTOCOL_ERROR)`` is raised.

        Raises:
            ProtocolError: if any message in ``messages`` requires a protocol response.
        """
        for message in messages:
            if message_requires_reply(ProtocolMessageTypes(message.type)):
                # Internal protocol logic error - we will raise, blocking messages to all peers
                self.log.error(f"Attempt to broadcast message requiring protocol response: {message.type}")
                # Iterate a snapshot: connection_closed is registered as each connection's close
                # callback and removes entries from all_connections, which would otherwise
                # invalidate a live dict iterator.
                for connection in list(self.all_connections.values()):
                    if connection.connection_type is node_type:
                        await connection.close(
                            ban_time=self.invalid_protocol_ban_seconds,
                            ws_close_code=WSCloseCode.INTERNAL_ERROR,
                            error=Err.INTERNAL_PROTOCOL_ERROR,
                        )
                raise ProtocolError(Err.INTERNAL_PROTOCOL_ERROR, [message.type])

    async def send_to_all(
        self,
        messages: List[Message],
        node_type: NodeType,
        exclude: Optional[bytes32] = None,
    ) -> None:
        """
        Send ``messages`` to every connection of ``node_type`` except the peer ``exclude``.

        Raises:
            ProtocolError: if any message requires a reply (see validate_broadcast_message_type).
        """
        await self.validate_broadcast_message_type(messages, node_type)
        # Iterate a snapshot so connections added or removed while we await do not
        # invalidate the iteration.
        for connection in list(self.all_connections.values()):
            if connection.connection_type is node_type and connection.peer_node_id != exclude:
                for message in messages:
                    await connection.send_message(message)

    async def send_to_specific(self, messages: List[Message], node_id: bytes32) -> None:
        """Send ``messages`` to the connection registered for ``node_id``; no-op if there is none."""
        connection = self.all_connections.get(node_id)
        if connection is not None:
            for message in messages:
                await connection.send_message(message)

    async def call_api_of_specific(
        self, request_method: Callable[..., Awaitable[Optional[Message]]], message_data: Streamable, node_id: bytes32
    ) -> Optional[Any]:
        """
        Call ``request_method`` with ``message_data`` on the connection registered for ``node_id``.

        Returns the result of ``connection.call_api``, or None when no such connection exists.
        """
        connection = self.all_connections.get(node_id)
        if connection is not None:
            return await connection.call_api(request_method, message_data)

        return None

    def get_connections(
        self, node_type: Optional[NodeType] = None, *, outbound: Optional[bool] = None
    ) -> List[WSChiaConnection]:
        """
        Return the registered connections, optionally filtered.

        ``node_type`` restricts the result to that connection type and ``outbound``
        to that direction; None disables the respective filter.
        """
        return [
            connection
            for connection in self.all_connections.values()
            if (node_type is None or connection.connection_type == node_type)
            and (outbound is None or connection.is_outbound == outbound)
        ]

    async def close_all_connections(self) -> None:
        """Close every registered connection, logging (not raising) individual failures."""
        # Iterate a copy: closing a connection can remove it from all_connections.
        for connection in self.all_connections.copy().values():
            try:
                await connection.close()
            except Exception as e:
                self.log.error(f"Exception while closing connection {e}")

    def close_all(self) -> None:
        """
        Begin shutdown: schedule closing all connections, close the web server,
        set ``shut_down_event`` and cancel the garbage collection task.

        Use ``await_closed`` to wait for the shutdown to finish.
        """
        self.connection_close_task = asyncio.create_task(self.close_all_connections())
        if self.webserver is not None:
            self.webserver.close()

        self.shut_down_event.set()
        if self.gc_task is not None:
            self.gc_task.cancel()
            self.gc_task = None

    async def await_closed(self) -> None:
        """Wait until ``close_all`` was called and its connection/web server shutdown completed."""
        self.log.debug("Await Closed")
        await self.shut_down_event.wait()
        if self.connection_close_task is not None:
            await self.connection_close_task
        if self.webserver is not None:
            await self.webserver.await_closed()
            self.webserver = None

    async def get_peer_info(self) -> Optional[PeerInfo]:
        """
        Determine this server's public address by querying an external "what is my IP" service.

        Returns a ``PeerInfo`` built from the reported IP and our listening port, or
        None when no port is configured, when no service returned an address, or when
        ``PeerInfo`` rejects the value with ``ValueError``.
        """
        ip = None

        try:
            port = self.get_port()
        except ValueError:
            return None  # server doesn't have a local port, just return None here

        # Use chia's service first, then fall back to `checkip` from amazon.
        for url in ("https://ip.chia.net/", "https://checkip.amazonaws.com/"):
            try:
                timeout = ClientTimeout(total=15)
                async with ClientSession(timeout=timeout) as session:
                    async with session.get(url) as resp:
                        if resp.status == 200:
                            ip = str(await resp.text())
                            ip = ip.rstrip()
            except Exception:
                # Deliberately best-effort: any failure just moves on to the next service.
                ip = None
            if ip is not None:
                break
        if ip is None:
            return None
        try:
            return PeerInfo(ip, uint16(port))
        except ValueError:
            return None

    def get_port(self) -> uint16:
        """
        Return the configured port as ``uint16``.

        Raises:
            ValueError: if no port is set.
        """
        if self._port is None:
            raise ValueError("Port not set")
        return uint16(self._port)

    def accept_inbound_connections(self, node_type: NodeType) -> bool:
        """
        Decide whether another inbound connection of ``node_type`` may be accepted.

        Only full nodes enforce limits. The limits are read from ``config``:
        ``target_peer_count - target_outbound_peer_count`` for full node peers
        (defaults 40 and 8), ``max_inbound_wallet`` (20), ``max_inbound_farmer`` (10)
        and ``max_inbound_timelord`` (5). Any other peer type is always accepted.
        """
        if self._local_type != NodeType.FULL_NODE:
            return True
        inbound_count = len(self.get_connections(node_type, outbound=False))
        if node_type == NodeType.FULL_NODE:
            return inbound_count < cast(int, self.config.get("target_peer_count", 40)) - cast(
                int, self.config.get("target_outbound_peer_count", 8)
            )
        if node_type == NodeType.WALLET:
            return inbound_count < cast(int, self.config.get("max_inbound_wallet", 20))
        if node_type == NodeType.FARMER:
            return inbound_count < cast(int, self.config.get("max_inbound_farmer", 10))
        if node_type == NodeType.TIMELORD:
            return inbound_count < cast(int, self.config.get("max_inbound_timelord", 5))
        return True

    def is_trusted_peer(self, peer: WSChiaConnection, trusted_peers: Dict[str, Any]) -> bool:
        """
        Return whether ``peer`` is trusted, delegating to ``chia.util.network.is_trusted_peer``
        with the peer's host and node id and the ``testing`` flag from ``config``.
        """
        return is_trusted_peer(
            host=peer.peer_info.host,
            node_id=peer.peer_node_id,
            trusted_peers=trusted_peers,
            testing=self.config.get("testing", False),
        )

    def set_capabilities(self, capabilities: List[Tuple[uint16, str]]) -> None:
        """Replace the capabilities handed to connections created after this call."""
        self._local_capabilities_for_handshake = capabilities