rpc: Improve RPC server start/stop (#12768)

* rpc: Improve RPC server start/stop

* `RpcServer.stop` -> `RpcServer.close`
This commit is contained in:
dustinface 2022-08-06 08:07:41 +02:00 committed by GitHub
parent 2ab0fa8b0a
commit 820889a316
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 148 additions and 111 deletions

View File

@ -7,7 +7,7 @@ import traceback
from dataclasses import dataclass
from pathlib import Path
from ssl import SSLContext
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from typing import Any, Awaitable, Callable, Dict, List, Optional
from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from typing_extensions import final
@ -30,6 +30,13 @@ EndpointResult = Dict[str, Any]
Endpoint = Callable[[Dict[str, object]], Awaitable[EndpointResult]]
@dataclass(frozen=True)
class RpcEnvironment:
runner: web.AppRunner
site: web.TCPSite
listen_port: uint16
@final
@dataclass
class RpcServer:
@ -42,6 +49,8 @@ class RpcServer:
service_name: str
ssl_context: SSLContext
ssl_client_context: SSLContext
environment: Optional[RpcEnvironment] = None
daemon_connection_task: Optional[asyncio.Task] = None # type: ignore[type-arg] # Asks for Task parameter which doesn't work # noqa: E501
shut_down: bool = False
websocket: Optional[ClientWebSocketResponse] = None
client_session: Optional[ClientSession] = None
@ -58,12 +67,41 @@ class RpcServer:
ssl_client_context = ssl_context_for_client(ca_cert_path, ca_key_path, crt_path, key_path, log=log)
return cls(rpc_api, stop_cb, service_name, ssl_context, ssl_client_context)
async def stop(self) -> None:
async def start(self, root_path: Path, self_hostname: str, rpc_port: int, max_request_body_size: int) -> None:
if self.environment is not None:
raise RuntimeError("RpcServer already started")
app = web.Application(client_max_size=max_request_body_size)
runner = web.AppRunner(app, access_log=None)
runner.app.add_routes([web.post(route, wrap_http_handler(func)) for (route, func) in self.get_routes().items()])
await runner.setup()
site = web.TCPSite(runner, self_hostname, int(rpc_port), ssl_context=self.ssl_context)
await site.start()
#
# On a dual-stack system, we want to get the (first) IPv4 port unless
# prefer_ipv6 is set in which case we use the IPv6 port
#
if rpc_port == 0:
rpc_port = select_port(root_path, runner.addresses)
self.environment = RpcEnvironment(runner, site, uint16(rpc_port))
def close(self) -> None:
self.shut_down = True
async def await_closed(self) -> None:
if self.websocket is not None:
await self.websocket.close()
if self.client_session is not None:
await self.client_session.close()
if self.environment is not None:
await self.environment.runner.cleanup()
self.environment = None
if self.daemon_connection_task is not None:
await self.daemon_connection_task
self.daemon_connection_task = None
async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
if self.websocket is None or self.websocket.closed:
@ -97,6 +135,12 @@ class RpcServer:
return None
asyncio.create_task(self._state_changed(change, change_data))
@property
def listen_port(self) -> uint16:
if self.environment is None:
raise RuntimeError("RpcServer is not started")
return self.environment.listen_port
def get_routes(self) -> Dict[str, Endpoint]:
return {
**self.rpc_api.get_routes(),
@ -276,31 +320,37 @@ class RpcServer:
break
async def connect_to_daemon(self, self_hostname: str, daemon_port: uint16) -> None:
while not self.shut_down:
try:
self.client_session = ClientSession()
self.websocket = await self.client_session.ws_connect(
f"wss://{self_hostname}:{daemon_port}",
autoclose=True,
autoping=True,
heartbeat=60,
ssl_context=self.ssl_client_context,
max_msg_size=max_message_size,
)
await self.connection(self.websocket)
except ClientConnectorError:
log.warning(f"Cannot connect to daemon at ws://{self_hostname}:{daemon_port}")
except Exception as e:
tb = traceback.format_exc()
log.warning(f"Exception: {tb} {type(e)}")
if self.websocket is not None:
await self.websocket.close()
if self.client_session is not None:
await self.client_session.close()
self.websocket = None
self.client_session = None
await asyncio.sleep(2)
def connect_to_daemon(self, self_hostname: str, daemon_port: uint16) -> None:
if self.daemon_connection_task is not None:
raise RuntimeError("Already connected to the daemon")
async def inner() -> None:
while not self.shut_down:
try:
self.client_session = ClientSession()
self.websocket = await self.client_session.ws_connect(
f"wss://{self_hostname}:{daemon_port}",
autoclose=True,
autoping=True,
heartbeat=60,
ssl_context=self.ssl_client_context,
max_msg_size=max_message_size,
)
await self.connection(self.websocket)
except ClientConnectorError:
log.warning(f"Cannot connect to daemon at ws://{self_hostname}:{daemon_port}")
except Exception as e:
tb = traceback.format_exc()
log.warning(f"Exception: {tb} {type(e)}")
if self.websocket is not None:
await self.websocket.close()
if self.client_session is not None:
await self.client_session.close()
self.websocket = None
self.client_session = None
await asyncio.sleep(2)
self.daemon_connection_task = asyncio.create_task(inner())
async def start_rpc_server(
@ -313,8 +363,7 @@ async def start_rpc_server(
net_config: Dict[str, object],
connect_to_daemon: bool = True,
max_request_body_size: Optional[int] = None,
name: str = "rpc_server",
) -> Tuple[Callable[[], Awaitable[None]], uint16]:
) -> RpcServer:
"""
Starts an HTTP server with the following RPC methods, to be used by local clients to
query the node.
@ -322,32 +371,14 @@ async def start_rpc_server(
try:
if max_request_body_size is None:
max_request_body_size = 1024 ** 2
app = web.Application(client_max_size=max_request_body_size)
rpc_server = RpcServer.create(rpc_api, rpc_api.service_name, stop_cb, root_path, net_config)
rpc_server.rpc_api.service._set_state_changed_callback(rpc_server.state_changed)
app.add_routes([web.post(route, wrap_http_handler(func)) for (route, func) in rpc_server.get_routes().items()])
await rpc_server.start(root_path, self_hostname, rpc_port, max_request_body_size)
if connect_to_daemon:
daemon_connection = asyncio.create_task(rpc_server.connect_to_daemon(self_hostname, daemon_port))
runner = web.AppRunner(app, access_log=None)
await runner.setup()
rpc_server.connect_to_daemon(self_hostname, daemon_port)
site = web.TCPSite(runner, self_hostname, int(rpc_port), ssl_context=rpc_server.ssl_context)
await site.start()
#
# On a dual-stack system, we want to get the (first) IPv4 port unless
# prefer_ipv6 is set in which case we use the IPv6 port
#
if rpc_port == 0:
rpc_port = select_port(root_path, runner.addresses)
async def cleanup() -> None:
await rpc_server.stop()
await runner.cleanup()
if connect_to_daemon:
await daemon_connection
return cleanup, rpc_port
return rpc_server
except Exception:
tb = traceback.format_exc()
log.error(f"Starting RPC server failed. Exception {tb}.")

View File

@ -18,7 +18,7 @@ except ImportError:
uvloop = None
from chia.cmds.init_funcs import chia_full_version_str
from chia.rpc.rpc_server import start_rpc_server
from chia.rpc.rpc_server import start_rpc_server, RpcServer
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
from chia.server.upnp import UPnP
@ -69,7 +69,7 @@ class Service:
self._connect_to_daemon = connect_to_daemon
self._node_type = node_type
self._service_name = service_name
self._rpc_task: Optional[asyncio.Task] = None
self.rpc_server: Optional[RpcServer] = None
self._rpc_close_task: Optional[asyncio.Task] = None
self._network_id: str = network_id
self.max_request_body_size = max_request_body_size
@ -164,19 +164,16 @@ class Service:
self._rpc_close_task = None
if self._rpc_info:
rpc_api, rpc_port = self._rpc_info
self._rpc_task = asyncio.create_task(
start_rpc_server(
rpc_api(self._node),
self.self_hostname,
self.daemon_port,
uint16(rpc_port),
self.stop,
self.root_path,
self.config,
self._connect_to_daemon,
max_request_body_size=self.max_request_body_size,
name=self._service_name + "_rpc",
)
self.rpc_server = await start_rpc_server(
rpc_api(self._node),
self.self_hostname,
self.daemon_port,
uint16(rpc_port),
self.stop,
self.root_path,
self.config,
self._connect_to_daemon,
max_request_body_size=self.max_request_body_size,
)
async def run(self) -> None:
@ -243,14 +240,9 @@ class Service:
self._log.info("Calling service stop callback")
if self._rpc_task is not None:
if self.rpc_server is not None:
self._log.info("Closing RPC server")
async def close_rpc_server() -> None:
if self._rpc_task:
await (await self._rpc_task)[0]()
self._rpc_close_task = asyncio.create_task(close_rpc_server())
self.rpc_server.close()
async def wait_closed(self) -> None:
await self._is_stopping.wait()
@ -260,9 +252,9 @@ class Service:
self._log.info("Waiting for ChiaServer to be closed")
await self._server.await_closed()
if self._rpc_close_task:
if self.rpc_server:
self._log.info("Waiting for RPC server")
await self._rpc_close_task
await self.rpc_server.await_closed()
self._log.info("Closed RPC server")
self._log.info("Waiting for service _await_closed callback")

View File

@ -68,7 +68,7 @@ async def harvester_farmer_environment(farmer_one_harvester, self_hostname):
farmer_rpc_api = FarmerRpcApi(farmer_service._api.farmer)
harvester_rpc_api = HarvesterRpcApi(harvester_service._node)
rpc_cleanup, rpc_port_farmer = await start_rpc_server(
rpc_server_farmer = await start_rpc_server(
farmer_rpc_api,
hostname,
daemon_port,
@ -78,7 +78,7 @@ async def harvester_farmer_environment(farmer_one_harvester, self_hostname):
config,
connect_to_daemon=False,
)
rpc_cleanup_2, rpc_port_harvester = await start_rpc_server(
rpc_server_harvester = await start_rpc_server(
harvester_rpc_api,
hostname,
daemon_port,
@ -89,8 +89,10 @@ async def harvester_farmer_environment(farmer_one_harvester, self_hostname):
connect_to_daemon=False,
)
farmer_rpc_cl = await FarmerRpcClient.create(self_hostname, rpc_port_farmer, bt.root_path, config)
harvester_rpc_cl = await HarvesterRpcClient.create(self_hostname, rpc_port_harvester, bt.root_path, config)
farmer_rpc_cl = await FarmerRpcClient.create(self_hostname, rpc_server_farmer.listen_port, bt.root_path, config)
harvester_rpc_cl = await HarvesterRpcClient.create(
self_hostname, rpc_server_harvester.listen_port, bt.root_path, config
)
async def have_connections():
return len(await farmer_rpc_cl.get_connections()) > 0
@ -101,10 +103,12 @@ async def harvester_farmer_environment(farmer_one_harvester, self_hostname):
farmer_rpc_cl.close()
harvester_rpc_cl.close()
rpc_server_harvester.close()
rpc_server_farmer.close()
await farmer_rpc_cl.await_closed()
await harvester_rpc_cl.await_closed()
await rpc_cleanup()
await rpc_cleanup_2()
await rpc_server_harvester.await_closed()
await rpc_server_farmer.await_closed()
@pytest.mark.asyncio

View File

@ -45,7 +45,7 @@ class TestRpc:
hostname = config["self_hostname"]
daemon_port = config["daemon_port"]
rpc_cleanup, test_rpc_port = await start_rpc_server(
rpc_server = await start_rpc_server(
full_node_rpc_api,
hostname,
daemon_port,
@ -57,7 +57,7 @@ class TestRpc:
)
try:
client = await FullNodeRpcClient.create(self_hostname, test_rpc_port, bt.root_path, config)
client = await FullNodeRpcClient.create(self_hostname, rpc_server.listen_port, bt.root_path, config)
await validate_get_routes(client, full_node_rpc_api)
state = await client.get_blockchain_state()
assert state["peak"] is None
@ -284,8 +284,9 @@ class TestRpc:
finally:
# Checks that the RPC manages to stop the node
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()
@pytest.mark.asyncio
async def test_signage_points(self, two_nodes_sim_and_wallets, empty_blockchain):
@ -306,7 +307,7 @@ class TestRpc:
full_node_rpc_api = FullNodeRpcApi(full_node_api_1.full_node)
rpc_cleanup, test_rpc_port = await start_rpc_server(
rpc_server = await start_rpc_server(
full_node_rpc_api,
self_hostname,
daemon_port,
@ -318,7 +319,7 @@ class TestRpc:
)
try:
client = await FullNodeRpcClient.create(self_hostname, test_rpc_port, bt.root_path, config)
client = await FullNodeRpcClient.create(self_hostname, rpc_server.listen_port, bt.root_path, config)
# Only provide one
res = await client.get_recent_signage_point_or_eos(None, None)
@ -428,5 +429,6 @@ class TestRpc:
finally:
# Checks that the RPC manages to stop the node
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()

View File

@ -101,7 +101,7 @@ async def one_wallet_node_and_rpc(
config = bt.config
daemon_port = config["daemon_port"]
rpc_cleanup, test_rpc_port = await start_rpc_server(
rpc_server = await start_rpc_server(
api_user,
self_hostname,
daemon_port,
@ -111,13 +111,14 @@ async def one_wallet_node_and_rpc(
config,
connect_to_daemon=False,
)
client = await WalletRpcClient.create(self_hostname, test_rpc_port, bt.root_path, config)
client = await WalletRpcClient.create(self_hostname, rpc_server.listen_port, bt.root_path, config)
yield client, wallet_node_0, full_node_api, bt
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()
@pytest_asyncio.fixture(scope="function")
@ -134,7 +135,7 @@ async def setup(two_wallet_nodes, self_hostname):
config = bt.config
daemon_port = config["daemon_port"]
rpc_cleanup, test_rpc_port = await start_rpc_server(
rpc_server = await start_rpc_server(
api_user,
self_hostname,
daemon_port,
@ -144,14 +145,14 @@ async def setup(two_wallet_nodes, self_hostname):
config,
connect_to_daemon=False,
)
client = await WalletRpcClient.create(self_hostname, test_rpc_port, bt.root_path, config)
client = await WalletRpcClient.create(self_hostname, rpc_server.listen_port, bt.root_path, config)
return (
full_nodes,
[wallet_node_0, wallet_node_1],
[our_ph, pool_ph],
client, # wallet rpc client
rpc_cleanup,
rpc_server,
)
@ -809,7 +810,7 @@ class TestPoolWalletRpc:
trusted, fee = trusted_and_fee
num_blocks = 4 # Num blocks to farm at a time
total_blocks = 0 # Total blocks farmed so far
full_nodes, wallet_nodes, receive_address, client, rpc_cleanup = setup
full_nodes, wallet_nodes, receive_address, client, rpc_server = setup
wallets = [wallet_n.wallet_state_manager.main_wallet for wallet_n in wallet_nodes]
wallet_node_0 = wallet_nodes[0]
our_ph = receive_address[0]
@ -922,15 +923,16 @@ class TestPoolWalletRpc:
finally:
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()
@pytest.mark.asyncio
@pytest.mark.parametrize("trusted_and_fee", [(True, FEE_AMOUNT), (False, 0)])
async def test_leave_pool(self, setup, trusted_and_fee, self_hostname):
"""This tests self-pooling -> pooling -> escaping -> self pooling"""
trusted, fee = trusted_and_fee
full_nodes, wallet_nodes, receive_address, client, rpc_cleanup = setup
full_nodes, wallet_nodes, receive_address, client, rpc_server = setup
our_ph = receive_address[0]
wallets = [wallet_n.wallet_state_manager.main_wallet for wallet_n in wallet_nodes]
pool_ph = receive_address[1]
@ -1057,15 +1059,16 @@ class TestPoolWalletRpc:
finally:
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()
@pytest.mark.asyncio
@pytest.mark.parametrize("trusted_and_fee", [(True, FEE_AMOUNT), (False, 0)])
async def test_change_pools(self, setup, trusted_and_fee, self_hostname):
"""This tests Pool A -> escaping -> Pool B"""
trusted, fee = trusted_and_fee
full_nodes, wallet_nodes, receive_address, client, rpc_cleanup = setup
full_nodes, wallet_nodes, receive_address, client, rpc_server = setup
our_ph = receive_address[0]
pool_a_ph = receive_address[1]
wallets = [wallet_n.wallet_state_manager.main_wallet for wallet_n in wallet_nodes]
@ -1158,15 +1161,16 @@ class TestPoolWalletRpc:
finally:
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()
@pytest.mark.asyncio
@pytest.mark.parametrize("trusted_and_fee", [(True, FEE_AMOUNT), (False, 0)])
async def test_change_pools_reorg(self, setup, trusted_and_fee, self_hostname):
"""This tests Pool A -> escaping -> reorg -> escaping -> Pool B"""
trusted, fee = trusted_and_fee
full_nodes, wallet_nodes, receive_address, client, rpc_cleanup = setup
full_nodes, wallet_nodes, receive_address, client, rpc_server = setup
our_ph = receive_address[0]
pool_a_ph = receive_address[1]
wallets = [wallet_n.wallet_state_manager.main_wallet for wallet_n in wallet_nodes]
@ -1285,5 +1289,6 @@ class TestPoolWalletRpc:
finally:
client.close()
rpc_server.close()
await client.await_closed()
await rpc_cleanup()
await rpc_server.await_closed()

View File

@ -135,7 +135,7 @@ async def wallet_rpc_environment(two_wallet_nodes, request, self_hostname):
full_node_rpc_api = FullNodeRpcApi(full_node_api.full_node)
rpc_cleanup_node, test_rpc_port_node = await start_rpc_server(
rpc_server_node = await start_rpc_server(
full_node_rpc_api,
hostname,
daemon_port,
@ -145,7 +145,7 @@ async def wallet_rpc_environment(two_wallet_nodes, request, self_hostname):
config,
connect_to_daemon=False,
)
rpc_cleanup, test_rpc_port = await start_rpc_server(
rpc_server_wallet = await start_rpc_server(
wallet_rpc_api,
hostname,
daemon_port,
@ -155,7 +155,7 @@ async def wallet_rpc_environment(two_wallet_nodes, request, self_hostname):
config,
connect_to_daemon=False,
)
rpc_cleanup_2, test_rpc_port_2 = await start_rpc_server(
rpc_server_wallet_2 = await start_rpc_server(
wallet_rpc_api_2,
hostname,
daemon_port,
@ -169,9 +169,9 @@ async def wallet_rpc_environment(two_wallet_nodes, request, self_hostname):
await server_2.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
await server_3.start_client(PeerInfo(self_hostname, uint16(full_node_server._port)), None)
client = await WalletRpcClient.create(hostname, test_rpc_port, bt.root_path, config)
client_2 = await WalletRpcClient.create(hostname, test_rpc_port_2, bt.root_path, config)
client_node = await FullNodeRpcClient.create(hostname, test_rpc_port_node, bt.root_path, config)
client = await WalletRpcClient.create(hostname, rpc_server_wallet.listen_port, bt.root_path, config)
client_2 = await WalletRpcClient.create(hostname, rpc_server_wallet_2.listen_port, bt.root_path, config)
client_node = await FullNodeRpcClient.create(hostname, rpc_server_node.listen_port, bt.root_path, config)
wallet_bundle_1: WalletBundle = WalletBundle(wallet_node, client, wallet)
wallet_bundle_2: WalletBundle = WalletBundle(wallet_node_2, client_2, wallet_2)
@ -183,12 +183,15 @@ async def wallet_rpc_environment(two_wallet_nodes, request, self_hostname):
client.close()
client_2.close()
client_node.close()
rpc_server_node.close()
rpc_server_wallet.close()
rpc_server_wallet_2.close()
await client.await_closed()
await client_2.await_closed()
await client_node.await_closed()
await rpc_cleanup()
await rpc_cleanup_2()
await rpc_cleanup_node()
await rpc_server_node.await_closed()
await rpc_server_wallet.await_closed()
await rpc_server_wallet_2.await_closed()
async def create_tx_outputs(wallet: Wallet, output_args: List[Tuple[int, Optional[List[str]]]]) -> List[Dict[str, Any]]: