Add a daemon heartbeat setting to config.yaml (#13886)

* make daemon heartbeat configurable and increase default

* Fix up daemon rpc test

* Fix dumb error with parameters

* Restore formatting

* Various updates from feedback

* Update tests/core/test_daemon_rpc.py

use config value for heartbeat

Co-authored-by: Kyle Altendorf <sda@fstab.net>

* black fixes

Co-authored-by: Kyle Altendorf <sda@fstab.net>
This commit is contained in:
Earle Lowe 2023-01-06 09:55:20 -08:00 committed by GitHub
parent c74689dbbe
commit 0a0c8920ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 49 additions and 23 deletions

View File

@ -18,12 +18,14 @@ class DaemonProxy:
self,
uri: str,
ssl_context: Optional[ssl.SSLContext],
heartbeat: int,
max_message_size: int = 50 * 1000 * 1000,
):
self._uri = uri
self._request_dict: Dict[str, asyncio.Event] = {}
self.response_dict: Dict[str, WsRpcMessage] = {}
self.ssl_context = ssl_context
self.heartbeat = heartbeat
self.client_session: Optional[aiohttp.ClientSession] = None
self.websocket: Optional[aiohttp.ClientWebSocketResponse] = None
self.max_message_size = max_message_size
@ -39,7 +41,7 @@ class DaemonProxy:
self._uri,
autoclose=True,
autoping=True,
heartbeat=60,
heartbeat=self.heartbeat,
ssl_context=self.ssl_context,
max_msg_size=self.max_message_size,
)
@ -145,13 +147,18 @@ class DaemonProxy:
async def connect_to_daemon(
self_hostname: str, daemon_port: int, max_message_size: int, ssl_context: ssl.SSLContext
self_hostname: str, daemon_port: int, max_message_size: int, ssl_context: ssl.SSLContext, heartbeat: int
) -> DaemonProxy:
"""
Connect to the local daemon.
"""
client = DaemonProxy(f"wss://{self_hostname}:{daemon_port}", ssl_context, max_message_size)
client = DaemonProxy(
f"wss://{self_hostname}:{daemon_port}",
ssl_context=ssl_context,
max_message_size=max_message_size,
heartbeat=heartbeat,
)
await client.start()
return client
@ -167,13 +174,18 @@ async def connect_to_daemon_and_validate(
try:
daemon_max_message_size = config.get("daemon_max_message_size", 50 * 1000 * 1000)
daemon_heartbeat = config.get("daemon_heartbeat", 300)
crt_path = root_path / config["daemon_ssl"]["private_crt"]
key_path = root_path / config["daemon_ssl"]["private_key"]
ca_crt_path = root_path / config["private_ssl_ca"]["crt"]
ca_key_path = root_path / config["private_ssl_ca"]["key"]
ssl_context = ssl_context_for_client(ca_crt_path, ca_key_path, crt_path, key_path)
connection = await connect_to_daemon(
config["self_hostname"], config["daemon_port"], daemon_max_message_size, ssl_context
config["self_hostname"],
config["daemon_port"],
max_message_size=daemon_max_message_size,
ssl_context=ssl_context,
heartbeat=daemon_heartbeat,
)
r = await connection.ping()

View File

@ -49,8 +49,9 @@ class KeychainProxy(DaemonProxy):
local_keychain: Optional[Keychain] = None,
user: Optional[str] = None,
service: Optional[str] = None,
heartbeat: int = 300,
):
super().__init__(uri, ssl_context)
super().__init__(uri, ssl_context, heartbeat=heartbeat)
self.log = log
if local_keychain:
self.keychain = local_keychain
@ -108,7 +109,7 @@ class KeychainProxy(DaemonProxy):
self._uri,
autoclose=True,
autoping=True,
heartbeat=60,
heartbeat=self.heartbeat,
ssl_context=self.ssl_context,
max_msg_size=self.max_message_size,
)
@ -391,6 +392,7 @@ def wrap_local_keychain(keychain: Keychain, log: logging.Logger) -> KeychainProx
async def connect_to_keychain(
self_hostname: str,
daemon_port: int,
daemon_heartbeat: int,
ssl_context: Optional[ssl.SSLContext],
log: logging.Logger,
user: Optional[str] = None,
@ -401,7 +403,12 @@ async def connect_to_keychain(
"""
client = KeychainProxy(
uri=f"wss://{self_hostname}:{daemon_port}", ssl_context=ssl_context, log=log, user=user, service=service
uri=f"wss://{self_hostname}:{daemon_port}",
heartbeat=daemon_heartbeat,
ssl_context=ssl_context,
log=log,
user=user,
service=service,
)
# Connect to the service if the proxy isn't using a local keychain
if not client.use_local_keychain():
@ -426,8 +433,9 @@ async def connect_to_keychain_and_validate(
ca_crt_path = root_path / net_config["private_ssl_ca"]["crt"]
ca_key_path = root_path / net_config["private_ssl_ca"]["key"]
ssl_context = ssl_context_for_client(ca_crt_path, ca_key_path, crt_path, key_path, log=log)
daemon_heartbeat = net_config.get("daemon_heartbeat", 300)
connection = await connect_to_keychain(
net_config["self_hostname"], net_config["daemon_port"], ssl_context, log, user, service
net_config["self_hostname"], net_config["daemon_port"], daemon_heartbeat, ssl_context, log, user, service
)
# If proxying to a local keychain, don't attempt to ping

View File

@ -142,6 +142,7 @@ class WebSocketServer:
self.self_hostname = self.net_config["self_hostname"]
self.daemon_port = self.net_config["daemon_port"]
self.daemon_max_message_size = self.net_config.get("daemon_max_message_size", 50 * 1000 * 1000)
self.heartbeat = self.net_config.get("daemon_heartbeat", 300)
self.webserver: Optional[WebServer] = None
self.ssl_context = ssl_context_for_server(ca_crt_path, ca_key_path, crt_path, key_path, log=self.log)
self.keychain_server = KeychainServer()
@ -217,7 +218,9 @@ class WebSocketServer:
return {"success": True, "services_stopped": service_names}
async def incoming_connection(self, request):
ws: WebSocketResponse = web.WebSocketResponse(max_msg_size=self.daemon_max_message_size, heartbeat=30)
ws: WebSocketResponse = web.WebSocketResponse(
max_msg_size=self.daemon_max_message_size, heartbeat=self.heartbeat
)
await ws.prepare(request)
while True:

View File

@ -147,6 +147,7 @@ class RpcServer:
ssl_context: SSLContext
ssl_client_context: SSLContext
webserver: Optional[WebServer] = None
daemon_heartbeat: int = 300
daemon_connection_task: Optional[asyncio.Task[None]] = None
shut_down: bool = False
websocket: Optional[ClientWebSocketResponse] = None
@ -165,9 +166,10 @@ class RpcServer:
key_path = root_path / net_config["daemon_ssl"]["private_key"]
ca_cert_path = root_path / net_config["private_ssl_ca"]["crt"]
ca_key_path = root_path / net_config["private_ssl_ca"]["key"]
daemon_heartbeat = net_config.get("daemon_heartbeat", 300)
ssl_context = ssl_context_for_server(ca_cert_path, ca_key_path, crt_path, key_path, log=log)
ssl_client_context = ssl_context_for_client(ca_cert_path, ca_key_path, crt_path, key_path, log=log)
return cls(rpc_api, stop_cb, service_name, ssl_context, ssl_client_context)
return cls(rpc_api, stop_cb, service_name, ssl_context, ssl_client_context, daemon_heartbeat=daemon_heartbeat)
async def start(self, self_hostname: str, rpc_port: uint16, max_request_body_size: int, prefer_ipv6: bool) -> None:
if self.webserver is not None:
@ -378,7 +380,7 @@ class RpcServer:
f"wss://{self_hostname}:{daemon_port}",
autoclose=True,
autoping=True,
heartbeat=60,
heartbeat=self.daemon_heartbeat,
ssl_context=self.ssl_client_context,
max_msg_size=max_message_size,
)

View File

@ -6,6 +6,7 @@ self_hostname: &self_hostname "localhost"
prefer_ipv6: False
daemon_port: 55400
daemon_max_message_size: 50000000 # maximum size of RPC message in bytes
daemon_heartbeat: 300 # sets the heartbeat for ping/ping interval and timeouts
inbound_rate_limit_percent: 100
outbound_rate_limit_percent: 30

View File

@ -626,7 +626,6 @@ async def daemon_connection_and_temp_keychain(get_b_tools):
f"wss://127.0.0.1:{get_b_tools._config['daemon_port']}",
autoclose=True,
autoping=True,
heartbeat=60,
ssl=get_b_tools.get_daemon_ssl_context(),
max_msg_size=52428800,
) as ws:

View File

@ -214,7 +214,6 @@ async def test_daemon_simulation(self_hostname, daemon_simulation):
f"wss://127.0.0.1:{daemon1.daemon_port}",
autoclose=True,
autoping=True,
heartbeat=60,
ssl_context=get_b_tools.get_daemon_ssl_context(),
max_msg_size=100 * 1024 * 1024,
)

View File

@ -51,14 +51,14 @@ class TestDos:
ssl_context = server_2.ssl_client_context
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
assert not ws.closed
await ws.close()
assert ws.closed
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
assert not ws.closed
@ -76,7 +76,7 @@ class TestDos:
assert ws.closed
try:
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
response: WSMessage = await ws.receive()
assert response.type == WSMsgType.CLOSE
@ -98,7 +98,7 @@ class TestDos:
ssl_context = server_2.ssl_client_context
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
await ws.send_bytes(bytes([1] * 1024))
@ -113,7 +113,7 @@ class TestDos:
assert ws.closed
try:
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
response: WSMessage = await ws.receive()
assert response.type == WSMsgType.CLOSE
@ -122,9 +122,7 @@ class TestDos:
await asyncio.sleep(6)
# Ban expired
await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
await session.ws_connect(url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024)
await session.close()
@ -142,7 +140,7 @@ class TestDos:
ssl_context = server_2.ssl_client_context
ws = await session.ws_connect(
url, autoclose=True, autoping=True, heartbeat=60, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
url, autoclose=True, autoping=True, ssl=ssl_context, max_msg_size=100 * 1024 * 1024
)
# Construct an otherwise valid handshake message

View File

@ -12,7 +12,11 @@ class TestDaemonRpc:
ws_server = get_daemon
config = bt.config
client = await connect_to_daemon(
config["self_hostname"], config["daemon_port"], 50 * 1000 * 1000, bt.get_daemon_ssl_context()
config["self_hostname"],
config["daemon_port"],
50 * 1000 * 1000,
bt.get_daemon_ssl_context(),
heartbeat=config["daemon_heartbeat"],
)
response = await client.get_version()