commonize signal setup (#13568)

* commonize signal setup

* add signal setup for async handlers

* future

* Update misc.py

* catch up

* logging

* missed one

* context manager and waiting for signal handlers to complete

* correct logging

* fixup

* fixup
This commit is contained in:
Kyle Altendorf 2023-09-06 14:01:11 -04:00 committed by GitHub
parent fba4a79db6
commit 353f12d5a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 272 additions and 128 deletions

View File

@ -1,7 +1,6 @@
from __future__ import annotations
import asyncio
import functools
import json
import logging
import os
@ -16,6 +15,7 @@ from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from types import FrameType
from typing import Any, AsyncIterator, Dict, List, Optional, Set, TextIO, Tuple
from blspy import G1Element
@ -39,6 +39,7 @@ from chia.util.ints import uint32
from chia.util.json_util import dict_to_json_str
from chia.util.keychain import Keychain, KeyData, passphrase_requirements, supports_os_passphrase_storage
from chia.util.lock import Lockfile, LockfileError
from chia.util.misc import SignalHandlers
from chia.util.network import WebServer
from chia.util.service_groups import validate_service
from chia.util.setproctitle import setproctitle
@ -209,21 +210,17 @@ class WebSocketServer:
await self.stop()
await self.exit()
async def setup_process_global_state(self) -> None:
try:
asyncio.get_running_loop().add_signal_handler(
signal.SIGINT,
functools.partial(self._accept_signal, signal_number=signal.SIGINT),
)
asyncio.get_running_loop().add_signal_handler(
signal.SIGTERM,
functools.partial(self._accept_signal, signal_number=signal.SIGTERM),
)
except NotImplementedError:
self.log.info("Not implemented")
async def setup_process_global_state(self, signal_handlers: SignalHandlers) -> None:
signal_handlers.setup_async_signal_handler(handler=self._accept_signal)
def _accept_signal(self, signal_number: int, stack_frame=None):
asyncio.create_task(self.stop())
async def _accept_signal(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
self.log.info("Received signal %s (%s), shutting down.", signal_.name, signal_.value)
await self.stop()
def cancel_task_safe(self, task: Optional[asyncio.Task]):
if task is not None:
@ -1522,9 +1519,10 @@ async def async_run_daemon(root_path: Path, wait_for_unlock: bool = False) -> in
key_path,
run_check_keys_on_unlock=wait_for_unlock,
)
await ws_server.setup_process_global_state()
async with ws_server.run():
await ws_server.shutdown_event.wait()
async with SignalHandlers.manage() as signal_handlers:
await ws_server.setup_process_global_state(signal_handlers=signal_handlers)
async with ws_server.run():
await ws_server.shutdown_event.wait()
if beta_metrics is not None:
await beta_metrics.stop_logging()

View File

@ -1,12 +1,12 @@
from __future__ import annotations
import asyncio
import functools
import logging
import signal
import sys
from dataclasses import dataclass, field
from pathlib import Path
from types import FrameType
from typing import Any, Dict, Optional
import click
@ -17,6 +17,7 @@ from chia.server.upnp import UPnP
from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
from chia.util.network import WebServer
from chia.util.path import path_from_root
from chia.util.setproctitle import setproctitle
@ -40,25 +41,11 @@ class DataLayerServer:
webserver: Optional[WebServer] = None
upnp: UPnP = field(default_factory=UPnP)
async def start(self) -> None:
async def start(self, signal_handlers: SignalHandlers) -> None:
if self.webserver is not None:
raise RuntimeError("DataLayerServer already started")
if sys.platform == "win32" or sys.platform == "cygwin":
# pylint: disable=E1101
signal.signal(signal.SIGBREAK, self._accept_signal)
signal.signal(signal.SIGINT, self._accept_signal)
signal.signal(signal.SIGTERM, self._accept_signal)
else:
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGINT,
functools.partial(self._accept_signal, signal_number=signal.SIGINT),
)
loop.add_signal_handler(
signal.SIGTERM,
functools.partial(self._accept_signal, signal_number=signal.SIGTERM),
)
signal_handlers.setup_sync_signal_handler(handler=self._accept_signal)
self.log.info("Starting Data Layer HTTP Server.")
@ -110,8 +97,13 @@ class DataLayerServer:
)
return response
def _accept_signal(self, signal_number: int, stack_frame: Any = None) -> None:
self.log.info("Got SIGINT or SIGTERM signal - stopping")
def _accept_signal(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
self.log.info("Received signal %s (%s), shutting down.", signal_.name, signal_.value)
self.close()
@ -133,9 +125,10 @@ async def async_start(root_path: Path) -> int:
)
data_layer_server = DataLayerServer(root_path, dl_config, log, shutdown_event)
await data_layer_server.start()
await shutdown_event.wait()
await data_layer_server.await_closed()
async with SignalHandlers.manage() as signal_handlers:
await data_layer_server.start(signal_handlers=signal_handlers)
await shutdown_event.wait()
await data_layer_server.await_closed()
return 0

View File

@ -10,6 +10,7 @@ from dataclasses import dataclass, field
from ipaddress import IPv4Address, IPv6Address, ip_address
from multiprocessing import freeze_support
from pathlib import Path
from types import FrameType
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional
import aiosqlite
@ -19,6 +20,7 @@ from chia.seeder.crawl_store import CrawlStore
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
from chia.util.path import path_from_root
SERVICE_NAME = "seeder"
@ -267,7 +269,7 @@ class DNSServer:
shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
crawl_store: Optional[CrawlStore] = field(init=False, default=None)
reliable_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
shutdown_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
shutting_down: bool = field(init=False, default=False)
udp_transport_ipv4: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
udp_protocol_ipv4: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
udp_transport_ipv6: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
@ -357,20 +359,23 @@ class DNSServer:
await self.stop()
log.info("DNS server stopped.")
async def setup_process_global_state(self) -> None:
try:
loop = asyncio.get_running_loop()
loop.add_signal_handler(signal.SIGINT, self._accept_signal)
loop.add_signal_handler(signal.SIGTERM, self._accept_signal)
except NotImplementedError:
log.warning("signal handlers unsupported on this platform")
async def setup_process_global_state(self, signal_handlers: SignalHandlers) -> None:
signal_handlers.setup_async_signal_handler(handler=self._accept_signal)
def _accept_signal(self) -> None: # pragma: no cover
if self.shutdown_task is None: # otherwise we are already shutting down, so we ignore the signal
self.shutdown_task = asyncio.create_task(self.stop())
async def _accept_signal(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None: # pragma: no cover
log.info("Received signal %s (%s), shutting down.", signal_.name, signal_.value)
await self.stop()
async def stop(self) -> None:
log.info("Stopping DNS server...")
if self.shutting_down:
return
self.shutting_down = True
if self.reliable_task is not None:
self.reliable_task.cancel() # cancel the peer update task
if self.crawl_store is not None:
@ -510,9 +515,10 @@ class DNSServer:
async def run_dns_server(dns_server: DNSServer) -> None: # pragma: no cover
await dns_server.setup_process_global_state()
async with dns_server.run():
await dns_server.shutdown_event.wait() # this is released on SIGINT or SIGTERM or any unhandled exception
async with SignalHandlers.manage() as signal_handlers:
await dns_server.setup_process_global_state(signal_handlers=signal_handlers)
async with dns_server.run():
await dns_server.shutdown_event.wait() # this is released on SIGINT or SIGTERM or any unhandled exception
def create_dns_server_service(config: Dict[str, Any], root_path: Path) -> DNSServer:

View File

@ -16,6 +16,7 @@ from chia.server.start_service import RpcInfo, Service, async_run
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -71,8 +72,9 @@ async def async_main() -> int:
updated_constants = DEFAULT_CONSTANTS.replace_str_to_bytes(**overrides)
initialize_service_logging(service_name=SERVICE_NAME, config=config)
service = create_full_node_crawler_service(DEFAULT_ROOT_PATH, config, updated_constants)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -16,6 +16,7 @@ from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
from chia.util.misc import SignalHandlers
from chia.wallet.wallet_node import WalletNode
from chia.wallet.wallet_node_api import WalletNodeAPI
@ -103,8 +104,9 @@ async def async_main() -> int:
uploaders: List[str] = config["data_layer"].get("uploaders", [])
downloaders: List[str] = config["data_layer"].get("downloaders", [])
service = create_data_layer_service(DEFAULT_ROOT_PATH, config, downloaders, uploaders)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -16,6 +16,7 @@ from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.keychain import Keychain
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -72,8 +73,9 @@ async def async_main() -> int:
config["pool"] = config_pool
initialize_service_logging(service_name=SERVICE_NAME, config=config)
service = create_farmer_service(DEFAULT_ROOT_PATH, config, config_pool, DEFAULT_CONSTANTS)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -18,6 +18,7 @@ from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
from chia.util.misc import SignalHandlers
from chia.util.task_timing import maybe_manage_task_instrumentation
# See: https://bugs.python.org/issue29288
@ -98,8 +99,9 @@ async def async_main(service_config: Dict[str, Any]) -> int:
updated_constants = DEFAULT_CONSTANTS.replace_str_to_bytes(**overrides)
initialize_service_logging(service_name=SERVICE_NAME, config=config)
service = create_full_node_service(DEFAULT_ROOT_PATH, config, updated_constants)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -15,6 +15,7 @@ from chia.types.peer_info import UnresolvedPeerInfo
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -64,8 +65,9 @@ async def async_main() -> int:
initialize_service_logging(service_name=SERVICE_NAME, config=config)
farmer_peer = UnresolvedPeerInfo(service_config["farmer_peer"]["host"], service_config["farmer_peer"]["port"])
service = create_harvester_service(DEFAULT_ROOT_PATH, config, DEFAULT_CONSTANTS, farmer_peer)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -11,6 +11,7 @@ from chia.server.start_service import Service, async_run
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -52,8 +53,9 @@ async def async_main() -> int:
config[SERVICE_NAME] = service_config
service = create_introducer_service(DEFAULT_ROOT_PATH, config)
initialize_service_logging(service_name=SERVICE_NAME, config=config)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -1,12 +1,10 @@
from __future__ import annotations
import asyncio
import functools
import logging
import logging.config
import os
import signal
import sys
from pathlib import Path
from types import FrameType
from typing import Any, Awaitable, Callable, Coroutine, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar
@ -24,6 +22,7 @@ from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import PeerInfo, UnresolvedPeerInfo
from chia.util.ints import uint16
from chia.util.lock import Lockfile, LockfileError
from chia.util.misc import SignalHandlers
from chia.util.network import resolve
from chia.util.setproctitle import setproctitle
@ -230,7 +229,7 @@ class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]):
def add_peer(self, peer: UnresolvedPeerInfo) -> None:
self._connect_peers.add(peer)
async def setup_process_global_state(self) -> None:
async def setup_process_global_state(self, signal_handlers: SignalHandlers) -> None:
# Being async forces this to be run from within an active event loop as is
# needed for the signal handler setup.
proctitle_name = f"chia_{self._service_name}"
@ -238,31 +237,31 @@ class Service(Generic[_T_RpcServiceProtocol, _T_ApiProtocol]):
global main_pid
main_pid = os.getpid()
if sys.platform == "win32" or sys.platform == "cygwin":
# pylint: disable=E1101
signal.signal(signal.SIGBREAK, self._accept_signal)
signal.signal(signal.SIGINT, self._accept_signal)
signal.signal(signal.SIGTERM, self._accept_signal)
else:
loop = asyncio.get_running_loop()
loop.add_signal_handler(
signal.SIGINT,
functools.partial(self._accept_signal, signal_number=signal.SIGINT),
)
loop.add_signal_handler(
signal.SIGTERM,
functools.partial(self._accept_signal, signal_number=signal.SIGTERM),
)
def _accept_signal(self, signal_number: int, stack_frame: Optional[FrameType] = None) -> None:
self._log.info(f"got signal {signal_number}")
signal_handlers.setup_sync_signal_handler(handler=self._accept_signal)
def _accept_signal(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
# we only handle signals in the main process. In the ProcessPoolExecutor
# processes, we have to ignore them. We'll shut them down gracefully
# from the main process
global main_pid
if os.getpid() != main_pid:
ignore = os.getpid() != main_pid
# TODO: if we remove this conditional behavior, consider moving logging to common signal handling
if ignore:
message = "ignoring in worker process"
else:
message = "shutting down"
self._log.info("Received signal %s (%s), %s.", signal_.name, signal_.value, message)
if ignore:
return
self.stop()
def stop(self) -> None:

View File

@ -16,6 +16,7 @@ from chia.types.peer_info import UnresolvedPeerInfo
from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -71,8 +72,9 @@ async def async_main() -> int:
config[SERVICE_NAME] = service_config
initialize_service_logging(service_name=SERVICE_NAME, config=config)
service = create_timelord_service(DEFAULT_ROOT_PATH, config, DEFAULT_CONSTANTS)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -16,6 +16,7 @@ from chia.util.chia_logging import initialize_service_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.keychain import Keychain
from chia.util.misc import SignalHandlers
from chia.util.task_timing import maybe_manage_task_instrumentation
from chia.wallet.wallet_node import WalletNode
@ -93,8 +94,9 @@ async def async_main() -> int:
constants = DEFAULT_CONSTANTS
initialize_service_logging(service_name=SERVICE_NAME, config=config)
service = create_wallet_service(DEFAULT_ROOT_PATH, config, constants)
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -9,6 +9,7 @@ import time
from contextlib import contextmanager
from pathlib import Path
from secrets import token_bytes
from types import FrameType
from typing import Any, AsyncGenerator, Dict, Iterator, List, Optional, Tuple
from chia.cmds.init_funcs import init
@ -48,6 +49,7 @@ from chia.util.config import config_path_for_filename, load_config, lock_and_loa
from chia.util.ints import uint16
from chia.util.keychain import bytes_to_mnemonic
from chia.util.lock import Lockfile
from chia.util.misc import SignalHandlers
from chia.wallet.wallet_node import WalletNode
from chia.wallet.wallet_node_api import WalletNodeAPI
@ -429,14 +431,17 @@ async def setup_vdf_client(bt: BlockTools, self_hostname: str, port: int) -> Asy
spawn_process(self_hostname, port, 1, lock, prefer_ipv6=bt.config.get("prefer_ipv6", False))
)
def stop() -> None:
asyncio.create_task(kill_processes(lock))
async def stop(
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
await kill_processes(lock)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, stop)
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, stop)
yield vdf_task_1
await kill_processes(lock)
async with SignalHandlers.manage() as signal_handlers:
signal_handlers.setup_async_signal_handler(handler=stop)
yield vdf_task_1
await kill_processes(lock)
async def setup_vdf_clients(
@ -453,15 +458,20 @@ async def setup_vdf_clients(
spawn_process(self_hostname, port, 3, lock, prefer_ipv6=bt.config.get("prefer_ipv6", False))
)
def stop() -> None:
asyncio.create_task(kill_processes(lock))
async def stop(
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
await kill_processes(lock)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, stop)
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, stop)
signal_handlers = SignalHandlers()
async with signal_handlers.manage():
signal_handlers.setup_async_signal_handler(handler=stop)
yield vdf_task_1, vdf_task_2, vdf_task_3
yield vdf_task_1, vdf_task_2, vdf_task_3
await kill_processes(lock)
await kill_processes(lock)
async def setup_timelord(

View File

@ -26,6 +26,7 @@ from chia.util.errors import KeychainFingerprintExists
from chia.util.ints import uint32
from chia.util.keychain import Keychain
from chia.util.lock import Lockfile
from chia.util.misc import SignalHandlers
from chia.wallet.derive_keys import master_sk_to_wallet_sk
"""
@ -157,7 +158,8 @@ async def get_full_chia_simulator(
ca_key_path = chia_root / config["private_ssl_ca"]["key"]
ws_server = WebSocketServer(chia_root, ca_crt_path, ca_key_path, crt_path, key_path)
await ws_server.setup_process_global_state()
async with ws_server.run():
async for simulator in start_simulator(chia_root, automated_testing):
yield simulator, chia_root, config, mnemonic, fingerprint, keychain
async with SignalHandlers.manage() as signal_handlers:
await ws_server.setup_process_global_state(signal_handlers=signal_handlers)
async with ws_server.run():
async for simulator in start_simulator(chia_root, automated_testing):
yield simulator, chia_root, config, mnemonic, fingerprint, keychain

View File

@ -19,6 +19,7 @@ from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config, load_config_cli, override_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
from chia.util.misc import SignalHandlers
# See: https://bugs.python.org/issue29288
"".encode("idna")
@ -107,8 +108,11 @@ async def async_main(test_mode: bool = False, automated_testing: bool = False, r
service = create_full_node_simulator_service(root_path, override_config(config, overrides), bt)
if test_mode:
return service
await service.setup_process_global_state()
await service.run()
async with SignalHandlers.manage() as signal_handlers:
await service.setup_process_global_state(signal_handlers=signal_handlers)
await service.run()
return 0

View File

@ -6,13 +6,15 @@ import os
import pathlib
import signal
import time
from typing import Dict, List
from types import FrameType
from typing import Any, Dict, List, Optional
import pkg_resources
from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.misc import SignalHandlers
from chia.util.network import resolve
from chia.util.setproctitle import setproctitle
@ -94,24 +96,23 @@ async def spawn_all_processes(config: Dict, net_config: Dict, lock: asyncio.Lock
await asyncio.gather(*awaitables)
def signal_received(lock: asyncio.Lock):
asyncio.create_task(kill_processes(lock))
async def async_main(config, net_config):
loop = asyncio.get_running_loop()
async def async_main(config: Dict[str, Any], net_config: Dict[str, Any]) -> None:
lock = asyncio.Lock()
try:
loop.add_signal_handler(signal.SIGINT, signal_received, lock)
loop.add_signal_handler(signal.SIGTERM, signal_received, lock)
except NotImplementedError:
log.info("signal handlers unsupported")
async def stop(
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
await kill_processes(lock)
try:
await spawn_all_processes(config, net_config, lock)
finally:
log.info("Launcher fully closed.")
async with SignalHandlers.manage() as signal_handlers:
signal_handlers.setup_async_signal_handler(handler=stop)
try:
await spawn_all_processes(config, net_config, lock)
finally:
log.info("Launcher fully closed.")
def main():

View File

@ -1,11 +1,30 @@
from __future__ import annotations
import asyncio
import contextlib
import dataclasses
import functools
import signal
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Collection, Dict, Generic, Iterator, List, Sequence, TypeVar, Union
from types import FrameType
from typing import (
Any,
AsyncIterator,
Collection,
Dict,
Generic,
Iterator,
List,
Optional,
Sequence,
TypeVar,
Union,
final,
)
from typing_extensions import Protocol
from chia.util.errors import InvalidPathError
from chia.util.ints import uint16, uint32, uint64
@ -165,3 +184,99 @@ def to_batches(to_split: Collection[T], batch_size: int) -> Iterator[Batch[T]]:
yield Batch(total_size - processed, entries)
else:
raise ValueError(f"to_batches: Unsupported type {type(to_split)}")
class Handler(Protocol):
def __call__(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
...
class AsyncHandler(Protocol):
async def __call__(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
) -> None:
...
@final
@dataclasses.dataclass
class SignalHandlers:
tasks: List[asyncio.Task[None]] = dataclasses.field(default_factory=list)
@classmethod
@contextlib.asynccontextmanager
async def manage(cls) -> AsyncIterator[SignalHandlers]:
self = cls()
try:
yield self
finally:
# TODO: log errors?
# TODO: return to previous signal handlers?
await asyncio.gather(*self.tasks)
def remove_done_handlers(self) -> None:
self.tasks = [task for task in self.tasks if not task.done()]
def loop_safe_sync_signal_handler_for_async(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
handler: AsyncHandler,
) -> None:
self.remove_done_handlers()
task = asyncio.create_task(
handler(signal_=signal_, stack_frame=stack_frame, loop=loop),
)
self.tasks.append(task)
def threadsafe_sync_signal_handler_for_async(
self,
signal_: signal.Signals,
stack_frame: Optional[FrameType],
loop: asyncio.AbstractEventLoop,
handler: AsyncHandler,
) -> None:
loop.call_soon_threadsafe(
functools.partial(
self.loop_safe_sync_signal_handler_for_async,
signal_=signal_,
stack_frame=stack_frame,
loop=loop,
handler=handler,
),
)
def setup_sync_signal_handler(self, handler: Handler) -> None:
loop = asyncio.get_event_loop()
if sys.platform == "win32" or sys.platform == "cygwin":
for signal_ in [signal.SIGBREAK, signal.SIGINT, signal.SIGTERM]:
signal.signal(signal_, functools.partial(handler, loop=loop))
else:
for signal_ in [signal.SIGINT, signal.SIGTERM]:
loop.add_signal_handler(
signal_,
functools.partial(handler, signal_=signal_, stack_frame=None, loop=loop),
)
def setup_async_signal_handler(self, handler: AsyncHandler) -> None:
# https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.loop.add_signal_handler
# > a callback registered with this function is allowed to interact with the event
# > loop
#
# This is a bit vague so let's just use a thread safe call for Windows
# compatibility.
self.setup_sync_signal_handler(
handler=functools.partial(self.threadsafe_sync_signal_handler_for_async, handler=handler)
)