mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-19 23:21:46 +03:00
util|rpc: Refactor RpcEnvironment
-> WebServer
(#13515)
* Move `rpc.rpc_server.RpcEnvironment` -> `util.network.WebServer` * Improve `WebServer` * Drop attribute resets * `black` fix after rebase * Fix `test_nft_bulk_mint.py`
This commit is contained in:
parent
0f7d022108
commit
2aa117b58a
@ -21,7 +21,7 @@ from chia.util.byte_types import hexstr_to_bytes
|
||||
from chia.util.config import str2bool
|
||||
from chia.util.ints import uint16
|
||||
from chia.util.json_util import dict_to_json_str
|
||||
from chia.util.network import select_port
|
||||
from chia.util.network import WebServer
|
||||
from chia.util.ws_message import WsRpcMessage, create_payload, create_payload_dict, format_response, pong
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -32,13 +32,6 @@ EndpointResult = Dict[str, Any]
|
||||
Endpoint = Callable[[Dict[str, object]], Awaitable[EndpointResult]]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RpcEnvironment:
|
||||
runner: web.AppRunner
|
||||
site: web.TCPSite
|
||||
listen_port: uint16
|
||||
|
||||
|
||||
class StateChangedProtocol(Protocol):
|
||||
def __call__(self, change: str, change_data: Dict[str, Any]) -> None:
|
||||
...
|
||||
@ -153,7 +146,7 @@ class RpcServer:
|
||||
service_name: str
|
||||
ssl_context: SSLContext
|
||||
ssl_client_context: SSLContext
|
||||
environment: Optional[RpcEnvironment] = None
|
||||
webserver: Optional[WebServer] = None
|
||||
daemon_connection_task: Optional[asyncio.Task[None]] = None
|
||||
shut_down: bool = False
|
||||
websocket: Optional[ClientWebSocketResponse] = None
|
||||
@ -171,38 +164,30 @@ class RpcServer:
|
||||
ssl_client_context = ssl_context_for_client(ca_cert_path, ca_key_path, crt_path, key_path, log=log)
|
||||
return cls(rpc_api, stop_cb, service_name, ssl_context, ssl_client_context)
|
||||
|
||||
async def start(self, self_hostname: str, rpc_port: int, max_request_body_size: int, prefer_ipv6: bool) -> None:
|
||||
if self.environment is not None:
|
||||
async def start(self, self_hostname: str, rpc_port: uint16, max_request_body_size: int, prefer_ipv6: bool) -> None:
|
||||
if self.webserver is not None:
|
||||
raise RuntimeError("RpcServer already started")
|
||||
|
||||
app = web.Application(client_max_size=max_request_body_size)
|
||||
runner = web.AppRunner(app, access_log=None)
|
||||
|
||||
runner.app.add_routes([web.post(route, wrap_http_handler(func)) for (route, func) in self.get_routes().items()])
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, self_hostname, int(rpc_port), ssl_context=self.ssl_context)
|
||||
await site.start()
|
||||
|
||||
#
|
||||
# On a dual-stack system, we want to get the (first) IPv4 port unless
|
||||
# prefer_ipv6 is set in which case we use the IPv6 port
|
||||
#
|
||||
if rpc_port == 0:
|
||||
rpc_port = select_port(prefer_ipv6, runner.addresses)
|
||||
|
||||
self.environment = RpcEnvironment(runner, site, uint16(rpc_port))
|
||||
self.webserver = await WebServer.create(
|
||||
hostname=self_hostname,
|
||||
port=rpc_port,
|
||||
max_request_body_size=max_request_body_size,
|
||||
routes=[web.post(route, wrap_http_handler(func)) for (route, func) in self.get_routes().items()],
|
||||
ssl_context=self.ssl_context,
|
||||
prefer_ipv6=prefer_ipv6,
|
||||
)
|
||||
|
||||
def close(self) -> None:
|
||||
self.shut_down = True
|
||||
if self.webserver is not None:
|
||||
self.webserver.close()
|
||||
|
||||
async def await_closed(self) -> None:
|
||||
if self.websocket is not None:
|
||||
await self.websocket.close()
|
||||
if self.client_session is not None:
|
||||
await self.client_session.close()
|
||||
if self.environment is not None:
|
||||
await self.environment.runner.cleanup()
|
||||
self.environment = None
|
||||
if self.webserver is not None:
|
||||
await self.webserver.await_closed()
|
||||
if self.daemon_connection_task is not None:
|
||||
await self.daemon_connection_task
|
||||
self.daemon_connection_task = None
|
||||
@ -241,9 +226,9 @@ class RpcServer:
|
||||
|
||||
@property
|
||||
def listen_port(self) -> uint16:
|
||||
if self.environment is None:
|
||||
if self.webserver is None:
|
||||
raise RuntimeError("RpcServer is not started")
|
||||
return self.environment.listen_port
|
||||
return self.webserver.listen_port
|
||||
|
||||
def get_routes(self) -> Dict[str, Endpoint]:
|
||||
return {
|
||||
|
@ -1,12 +1,72 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.log import web_logger
|
||||
from dataclasses import dataclass
|
||||
from ipaddress import ip_address, IPv4Network, IPv6Network
|
||||
from typing import Iterable, List, Tuple, Union, Any, Optional, Dict
|
||||
from typing_extensions import final
|
||||
from chia.server.outbound_message import NodeType
|
||||
from chia.types.blockchain_format.sized_bytes import bytes32
|
||||
from chia.types.peer_info import PeerInfo
|
||||
from chia.util.ints import uint16
|
||||
|
||||
|
||||
@final
|
||||
@dataclass
|
||||
class WebServer:
|
||||
runner: web.AppRunner
|
||||
listen_port: uint16
|
||||
_close_task: Optional[asyncio.Task[None]] = None
|
||||
|
||||
@classmethod
|
||||
async def create(
|
||||
cls,
|
||||
hostname: str,
|
||||
port: uint16,
|
||||
routes: List[web.RouteDef],
|
||||
max_request_body_size: int = 1024**2, # Default `client_max_size` from web.Application
|
||||
ssl_context: Optional[ssl.SSLContext] = None,
|
||||
keepalive_timeout: int = 75, # Default from aiohttp.web
|
||||
shutdown_timeout: int = 60, # Default `shutdown_timeout` from web.TCPSite
|
||||
prefer_ipv6: bool = False,
|
||||
logger: logging.Logger = web_logger,
|
||||
) -> WebServer:
|
||||
app = web.Application(client_max_size=max_request_body_size, logger=logger)
|
||||
runner = web.AppRunner(app, access_log=None, keepalive_timeout=keepalive_timeout)
|
||||
|
||||
runner.app.add_routes(routes)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, hostname, int(port), ssl_context=ssl_context, shutdown_timeout=shutdown_timeout)
|
||||
await site.start()
|
||||
|
||||
#
|
||||
# On a dual-stack system, we want to get the (first) IPv4 port unless
|
||||
# prefer_ipv6 is set in which case we use the IPv6 port
|
||||
#
|
||||
if port == 0:
|
||||
port = select_port(prefer_ipv6, runner.addresses)
|
||||
|
||||
return cls(runner=runner, listen_port=uint16(port))
|
||||
|
||||
async def _close(self) -> None:
|
||||
await self.runner.shutdown()
|
||||
await self.runner.cleanup()
|
||||
|
||||
def close(self) -> None:
|
||||
self._close_task = asyncio.create_task(self._close())
|
||||
|
||||
async def await_closed(self) -> None:
|
||||
if self._close_task is None:
|
||||
raise RuntimeError("WebServer stop not triggered")
|
||||
await self._close_task
|
||||
|
||||
|
||||
def is_in_network(peer_host: str, networks: Iterable[Union[IPv4Network, IPv6Network]]) -> bool:
|
||||
try:
|
||||
peer_host_ip = ip_address(peer_host)
|
||||
|
@ -323,6 +323,8 @@ async def test_nft_mint_from_did_rpc(two_wallet_nodes: Any, trusted: Any, self_h
|
||||
finally:
|
||||
client.close()
|
||||
client_node.close()
|
||||
rpc_server.close()
|
||||
rpc_server_node.close()
|
||||
await client.await_closed()
|
||||
await client_node.await_closed()
|
||||
await rpc_server.await_closed()
|
||||
@ -505,6 +507,8 @@ async def test_nft_mint_from_did_rpc_no_royalties(two_wallet_nodes: Any, trusted
|
||||
finally:
|
||||
client.close()
|
||||
client_node.close()
|
||||
rpc_server.close()
|
||||
rpc_server_node.close()
|
||||
await client.await_closed()
|
||||
await client_node.await_closed()
|
||||
await rpc_server.await_closed()
|
||||
@ -905,6 +909,8 @@ async def test_nft_mint_from_xch_rpc(two_wallet_nodes: Any, trusted: Any, self_h
|
||||
finally:
|
||||
client.close()
|
||||
client_node.close()
|
||||
rpc_server.close()
|
||||
rpc_server_node.close()
|
||||
await client.await_closed()
|
||||
await client_node.await_closed()
|
||||
await rpc_server.await_closed()
|
||||
|
Loading…
Reference in New Issue
Block a user