diff --git a/chia/rpc/rpc_server.py b/chia/rpc/rpc_server.py index 247b91699acb..42b0c12226de 100644 --- a/chia/rpc/rpc_server.py +++ b/chia/rpc/rpc_server.py @@ -8,7 +8,7 @@ import traceback from dataclasses import dataclass from pathlib import Path from ssl import SSLContext -from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, List, Optional, TypeVar from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web from typing_extensions import Protocol, final @@ -31,6 +31,7 @@ max_message_size = 50 * 1024 * 1024 # 50MB EndpointResult = Dict[str, Any] Endpoint = Callable[[Dict[str, object]], Awaitable[EndpointResult]] +_T_RpcApiProtocol = TypeVar("_T_RpcApiProtocol", bound="RpcApiProtocol") class StateChangedProtocol(Protocol): @@ -121,12 +122,12 @@ def default_get_connections(server: ChiaServer, request_node_type: Optional[Node @final @dataclass -class RpcServer: +class RpcServer(Generic[_T_RpcApiProtocol]): """ Implementation of RPC server. """ - rpc_api: RpcApiProtocol + rpc_api: _T_RpcApiProtocol stop_cb: Callable[[], None] service_name: str ssl_context: SSLContext @@ -142,13 +143,13 @@ class RpcServer: @classmethod def create( cls, - rpc_api: RpcApiProtocol, + rpc_api: _T_RpcApiProtocol, service_name: str, stop_cb: Callable[[], None], root_path: Path, net_config: Dict[str, Any], prefer_ipv6: bool, - ) -> RpcServer: + ) -> RpcServer[_T_RpcApiProtocol]: crt_path = root_path / net_config["daemon_ssl"]["private_crt"] key_path = root_path / net_config["daemon_ssl"]["private_key"] ca_cert_path = root_path / net_config["private_ssl_ca"]["crt"] @@ -396,7 +397,7 @@ class RpcServer: async def start_rpc_server( - rpc_api: RpcApiProtocol, + rpc_api: _T_RpcApiProtocol, self_hostname: str, daemon_port: uint16, rpc_port: uint16, @@ -405,7 +406,7 @@ async def start_rpc_server( net_config: Dict[str, object], connect_to_daemon: bool = True, max_request_body_size: Optional[int] = None, -) -> RpcServer: +) -> RpcServer[_T_RpcApiProtocol]: """ Starts an HTTP server with the following RPC methods, to be used by local clients to query the node. diff --git a/chia/seeder/start_crawler.py b/chia/seeder/start_crawler.py index 9d58a9f9d1cf..bfbc7008afc3 100644 --- a/chia/seeder/start_crawler.py +++ b/chia/seeder/start_crawler.py @@ -44,7 +44,7 @@ def create_full_node_crawler_service( network_id = service_config["selected_network"] - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[CrawlerRpcApi]] = None if crawler_config.get("start_rpc_server", True): rpc_info = (CrawlerRpcApi, crawler_config.get("rpc_port", 8561)) diff --git a/chia/server/start_data_layer.py b/chia/server/start_data_layer.py index 4d5a74806f9b..9901fa62efe8 100644 --- a/chia/server/start_data_layer.py +++ b/chia/server/start_data_layer.py @@ -61,7 +61,7 @@ def create_data_layer_service( api = DataLayerAPI(data_layer) network_id = service_config["selected_network"] rpc_port = service_config.get("rpc_port") - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[DataLayerRpcApi]] = None if rpc_port is not None: rpc_info = (DataLayerRpcApi, cast(int, service_config["rpc_port"])) diff --git a/chia/server/start_farmer.py b/chia/server/start_farmer.py index 1bf38556b675..a238c6117b9c 100644 --- a/chia/server/start_farmer.py +++ b/chia/server/start_farmer.py @@ -43,7 +43,7 @@ def create_farmer_service( root_path, service_config, config_pool, consensus_constants=updated_constants, local_keychain=keychain ) peer_api = FarmerAPI(farmer) - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[FarmerRpcApi]] = None if service_config["start_rpc_server"]: rpc_info = (FarmerRpcApi, service_config["rpc_port"]) return Service( diff --git a/chia/server/start_full_node.py b/chia/server/start_full_node.py index a7d307d2e377..d75cb770d416 100644 --- a/chia/server/start_full_node.py +++ b/chia/server/start_full_node.py @@ -49,7 +49,7 @@ async def create_full_node_service( if service_config["enable_upnp"]: upnp_list = [service_config["port"]] network_id = service_config["selected_network"] - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[FullNodeRpcApi]] = None if service_config["start_rpc_server"]: rpc_info = (FullNodeRpcApi, service_config["rpc_port"]) return Service( diff --git a/chia/server/start_harvester.py b/chia/server/start_harvester.py index 6e172ed80ad4..724194843809 100644 --- a/chia/server/start_harvester.py +++ b/chia/server/start_harvester.py @@ -39,7 +39,7 @@ def create_harvester_service( harvester = Harvester(root_path, service_config, updated_constants) peer_api = HarvesterAPI(harvester) network_id = service_config["selected_network"] - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[HarvesterRpcApi]] = None if service_config["start_rpc_server"]: rpc_info = (HarvesterRpcApi, service_config["rpc_port"]) return Service( diff --git a/chia/server/start_service.py b/chia/server/start_service.py index db749437bdee..f40aa6065dae 100644 --- a/chia/server/start_service.py +++ b/chia/server/start_service.py @@ -51,8 +51,9 @@ main_pid: Optional[int] = None T = TypeVar("T") _T_RpcServiceProtocol = TypeVar("_T_RpcServiceProtocol", bound=RpcServiceProtocol) _T_ApiProtocol = TypeVar("_T_ApiProtocol", bound=ApiProtocol) +_T_RpcApiProtocol = TypeVar("_T_RpcApiProtocol", bound=RpcApiProtocol) -RpcInfo = Tuple[Type[RpcApiProtocol], int] +RpcInfo = Tuple[Type[_T_RpcApiProtocol], int] log = logging.getLogger(__name__) @@ -61,7 +62,7 @@ class ServiceException(Exception): pass -class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]): +class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol, _T_RpcApiProtocol]): def __init__( self, root_path: Path, @@ -76,7 +77,7 @@ class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]): upnp_ports: Optional[List[int]] = None, connect_peers: Optional[Set[UnresolvedPeerInfo]] = None, on_connect_callback: Optional[Callable[[WSChiaConnection], Awaitable[None]]] = None, - rpc_info: Optional[RpcInfo] = None, + rpc_info: Optional[RpcInfo[_T_RpcApiProtocol]] = None, connect_to_daemon: bool = True, max_request_body_size: Optional[int] = None, override_capabilities: Optional[List[Tuple[uint16, str]]] = None, @@ -96,7 +97,7 @@ class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]): self._connect_to_daemon = connect_to_daemon self._node_type = node_type self._service_name = service_name - self.rpc_server: Optional[RpcServer] = None + self.rpc_server: Optional[RpcServer[_T_RpcApiProtocol]] = None self._network_id: str = network_id self.max_request_body_size = max_request_body_size self.reconnect_retry_seconds: int = 3 diff --git a/chia/server/start_timelord.py b/chia/server/start_timelord.py index 86ac2c56e7e5..50ed0116c7b9 100644 --- a/chia/server/start_timelord.py +++ b/chia/server/start_timelord.py @@ -41,7 +41,7 @@ def create_timelord_service( peer_api = TimelordAPI(node) network_id = service_config["selected_network"] - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[TimelordRpcApi]] = None if service_config.get("start_rpc_server", True): rpc_info = (TimelordRpcApi, service_config.get("rpc_port", 8557)) diff --git a/chia/server/start_wallet.py b/chia/server/start_wallet.py index 89615e40bc03..f18b18ba2a0d 100644 --- a/chia/server/start_wallet.py +++ b/chia/server/start_wallet.py @@ -52,7 +52,7 @@ def create_wallet_service( network_id = service_config["selected_network"] rpc_port = service_config.get("rpc_port") - rpc_info: Optional[RpcInfo] = None + rpc_info: Optional[RpcInfo[WalletRpcApi]] = None if rpc_port is not None: rpc_info = (WalletRpcApi, service_config["rpc_port"]) diff --git a/chia/types/aliases.py b/chia/types/aliases.py index dd5a980155e3..7ab6a409e2b7 100644 --- a/chia/types/aliases.py +++ b/chia/types/aliases.py @@ -10,21 +10,29 @@ from chia.harvester.harvester import Harvester from chia.harvester.harvester_api import HarvesterAPI from chia.introducer.introducer import Introducer from chia.introducer.introducer_api import IntroducerAPI +from chia.rpc.crawler_rpc_api import CrawlerRpcApi +from chia.rpc.data_layer_rpc_api import DataLayerRpcApi +from chia.rpc.farmer_rpc_api import FarmerRpcApi +from chia.rpc.full_node_rpc_api import FullNodeRpcApi +from chia.rpc.harvester_rpc_api import HarvesterRpcApi +from chia.rpc.timelord_rpc_api import TimelordRpcApi +from chia.rpc.wallet_rpc_api import WalletRpcApi from chia.seeder.crawler import Crawler from chia.seeder.crawler_api import CrawlerAPI from chia.server.start_service import Service from chia.simulator.full_node_simulator import FullNodeSimulator +from chia.simulator.simulator_full_node_rpc_api import SimulatorFullNodeRpcApi from chia.timelord.timelord import Timelord from chia.timelord.timelord_api import TimelordAPI from chia.wallet.wallet_node import WalletNode from chia.wallet.wallet_node_api import WalletNodeAPI -CrawlerService = Service[Crawler, CrawlerAPI] -DataLayerService = Service[DataLayer, DataLayerAPI] -FarmerService = Service[Farmer, FarmerAPI] -FullNodeService = Service[FullNode, FullNodeAPI] -HarvesterService = Service[Harvester, HarvesterAPI] -IntroducerService = Service[Introducer, IntroducerAPI] -SimulatorFullNodeService = Service[FullNode, FullNodeSimulator] -TimelordService = Service[Timelord, TimelordAPI] -WalletService = Service[WalletNode, WalletNodeAPI] +CrawlerService = Service[Crawler, CrawlerAPI, CrawlerRpcApi] +DataLayerService = Service[DataLayer, DataLayerAPI, DataLayerRpcApi] +FarmerService = Service[Farmer, FarmerAPI, FarmerRpcApi] +FullNodeService = Service[FullNode, FullNodeAPI, FullNodeRpcApi] +HarvesterService = Service[Harvester, HarvesterAPI, HarvesterRpcApi] +IntroducerService = Service[Introducer, IntroducerAPI, FullNodeRpcApi] +SimulatorFullNodeService = Service[FullNode, FullNodeSimulator, SimulatorFullNodeRpcApi] +TimelordService = Service[Timelord, TimelordAPI, TimelordRpcApi] +WalletService = Service[WalletNode, WalletNodeAPI, WalletRpcApi]