sapling/eden/fs/cli/systemd.py
Chad Austin 4c0996afa5 make systemd management functions explicitly async
Summary:
Rather than dynamically allocating an event loop in the systemd async
code, make all the corresponding functions async, so the caller is
responsible for threading an event loop down.

Reviewed By: genevievehelsel

Differential Revision: D21894106

fbshipit-source-id: 398c769c30c85a3bb210dbc209f34f9f7336996c
2020-07-07 11:31:33 -07:00

770 lines
26 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2.
import asyncio
import contextlib
import logging
import os
import pathlib
import re
import subprocess
import types
import typing
# pystemd is an optional dependency. When it cannot be imported, the error is
# remembered here instead of being raised at import time;
# SystemdUserBus.__init__ re-raises it.
pystemd_import_error = None
try:
    import pystemd
    import pystemd.dbusexc  # pyre-ignore[21]: T32805591
    import pystemd.dbuslib
    import pystemd.systemd1.manager
    import pystemd.systemd1.unit
except ModuleNotFoundError as e:
    pystemd_import_error = e

logger = logging.getLogger(__name__)

# Return type of the callable passed to SystemdUserBus._run_in_executor_async.
_T = typing.TypeVar("_T")
def edenfs_systemd_service_name(eden_dir: pathlib.Path) -> str:
    """Return the systemd service name used for the given eden_dir.

    The name has the form 'fb-edenfs@<escaped eden_dir>.service', where the
    escaping is done by systemd_escape_path. eden_dir must be a
    pathlib.PosixPath.
    """
    assert isinstance(eden_dir, pathlib.PosixPath)
    escaped_eden_dir = systemd_escape_path(eden_dir)
    return f"fb-edenfs@{escaped_eden_dir}.service"
def systemd_escape_path(path: pathlib.PurePosixPath) -> str:
    """Escape an absolute path so it can be embedded in a systemd unit name.

    The escaping itself is delegated to the 'systemd-escape --path' command.

    Raises ValueError if the path is relative or has a '..' component.
    subprocess.check_output raises if systemd-escape cannot be run or exits
    with a non-zero status.
    """
    if not path.is_absolute():
        raise ValueError("systemd_escape_path can only escape absolute paths")
    if any(part == ".." for part in path.parts):
        raise ValueError(
            "systemd_escape_path can only escape paths without '..' components"
        )
    command = ["systemd-escape", "--path", "--", str(path)]
    escaped: bytes = subprocess.check_output(command)
    # systemd-escape terminates its output with a newline.
    return escaped.decode("utf-8").rstrip("\n")
class EdenFSSystemdServiceConfig:
    """Settings for one EdenFS systemd service instance.

    The settings are the edenfs executable path and any extra edenfs
    arguments. write_config_file stores them in <eden_dir>/systemd.conf using
    the format produced by SystemdEnvironmentFile.dumps.
    """

    __eden_dir: pathlib.Path
    __edenfs_executable_path: pathlib.Path
    __extra_edenfs_arguments: typing.List[str]

    def __init__(
        self,
        eden_dir: pathlib.Path,
        edenfs_executable_path: pathlib.Path,
        extra_edenfs_arguments: typing.Sequence[str],
    ) -> None:
        super().__init__()
        self.__eden_dir = eden_dir
        self.__edenfs_executable_path = edenfs_executable_path
        # Copy so later mutation of the caller's sequence has no effect here.
        self.__extra_edenfs_arguments = list(extra_edenfs_arguments)

    @property
    def config_file_path(self) -> pathlib.Path:
        """Path of the file written by write_config_file."""
        return self.__eden_dir / "systemd.conf"

    @property
    def startup_log_file_path(self) -> pathlib.Path:
        """Path of startup.log inside eden_dir."""
        # TODO(T33122320): Move this into <eden_dir>/logs/.
        return self.__eden_dir / "startup.log"

    def write_config_file(self) -> None:
        """Write the settings to config_file_path, creating parent directories.

        Raises ValueError if an extra argument contains a newline, or if
        SystemdEnvironmentFile.dumps rejects a value.
        """
        variables = {
            b"EDENFS_EXECUTABLE_PATH": bytes(self.__edenfs_executable_path),
            b"EDENFS_EXTRA_ARGUMENTS": self.__escape_argument_list(
                self.__extra_edenfs_arguments
            ),
        }
        self.config_file_path.parent.mkdir(parents=True, exist_ok=True)
        self.config_file_path.write_bytes(SystemdEnvironmentFile.dumps(variables))

    @staticmethod
    def __escape_argument_list(arguments: typing.Sequence[str]) -> bytes:
        """Encode arguments as UTF-8, separated by newlines.

        Raises ValueError if any argument contains a newline, because the
        newline is the separator and such an argument would be split in two.
        """
        for argument in arguments:
            # Inspect the text of each argument. Testing the sequence itself
            # would only match an argument that is exactly "\n".
            if "\n" in argument:
                raise ValueError(
                    f"Newlines in arguments are not supported\nArgument: {argument!r}"
                )
        return b"\n".join(arg.encode("utf-8") for arg in arguments)
class SystemdEnvironmentFile:
    """Reads and writes name/value entries in systemd's environment file format.

    loads() parses file content into a list of (name, value) entries; dumps()
    serializes a mapping of variables, quoting every value.
    """

    _comment_characters = b"#;"
    _escape_characters = b"\\"
    _name_characters = (
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_"
    )
    _newline_characters = b"\n\r"
    _quote_characters = b"'\""
    _whitespace_characters = b" \t"

    def __init__(self, entries: typing.Sequence[typing.Tuple[bytes, bytes]]) -> None:
        super().__init__()
        self.__entries = list(entries)

    @classmethod
    def loads(cls, content: bytes) -> "SystemdEnvironmentFile":
        """Parse content, ignoring everything from the first NUL byte onward."""
        parser = _EnvironmentFileParser(_truncated_at_null_byte(content))
        return cls(entries=parser.parse_entries())

    @classmethod
    def dumps(cls, variables: typing.Mapping[bytes, bytes]) -> bytes:
        """Serialize variables as one NAME="value" line per entry.

        Raises VariableNameError or VariableValueError for an entry that fails
        _validate_entry.
        """
        lines = []
        for name, value in variables.items():
            cls._validate_entry(name, value)
            lines.append(name + b"=" + cls.__escape_value(value) + b"\n")
        return b"".join(lines)

    @staticmethod
    def __escape_value(value: bytes) -> bytes:
        """Wrap value in double quotes, backslash-escaping '\\' and '"'."""

        def backslash_escape(match: typing.Match[bytes]) -> bytes:
            return b"\\" + match.group(0)

        escaped = re.sub(b'[\\\\"]', backslash_escape, value)
        return b'"' + escaped + b'"'

    @classmethod
    def _is_valid_entry(cls, name: bytes, value: bytes) -> bool:
        """Return whether _validate_entry accepts the given name and value."""
        try:
            cls._validate_entry(name, value)
        except (VariableNameError, VariableValueError):
            return False
        else:
            return True

    @classmethod
    def _validate_entry(cls, name: bytes, value: bytes) -> None:
        """Raise VariableNameError or VariableValueError for an invalid entry.

        Names must be non-empty, must not start with a digit, and every byte
        below 0x80 must be in _name_characters. Values must not contain a
        carriage return or any control byte other than newline and tab.
        """
        if not name:
            raise VariableNameError("Variables must have a non-empty name")
        if name[:1].isdigit():
            raise VariableNameError("Variable names must not begin with a digit")
        # The order of the checks decides which message a byte produces.
        for byte in name:
            if byte in cls._whitespace_characters:
                raise VariableNameError("Variable names must not contain whitespace")
            if byte in cls._newline_characters:
                raise VariableNameError(
                    "Variable names must not contain any newline characters"
                )
            if byte < 0x20:
                raise VariableNameError(
                    "Variable names must not contain any control characters"
                )
            if byte < 0x80 and byte not in cls._name_characters:
                offending_character = bytes([byte]).decode("utf-8")
                raise VariableNameError(
                    f"Variable names must not contain '{offending_character}'"
                )
        for byte in value:
            if byte == 0x0D:
                raise VariableValueError(
                    "Variable values must not contain carriage returns"
                )
            if byte < 0x20 and byte not in b"\n\t":
                raise VariableValueError(
                    "Variable values must not contain any control characters"
                )

    @property
    def entries(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """The (name, value) entries, in the order they were given."""
        return self.__entries
class VariableNameError(ValueError):
    """Raised by SystemdEnvironmentFile when a variable name is invalid."""
class VariableValueError(ValueError):
    """Raised by SystemdEnvironmentFile when a variable value is invalid."""
class _Scanner:
    """Forward-only cursor over a bytes object.

    'scan' methods return the bytes they consume; 'skip' methods consume
    without returning anything; peek_one_byte looks at the next byte without
    consuming it.
    """

    def __init__(self, input: bytes) -> None:
        super().__init__()
        self.__input = input
        self.__index = 0

    @property
    def at_eof(self) -> bool:
        """True once every byte of the input has been consumed."""
        return self.__index == len(self.__input)

    def scan_one_byte(self) -> int:
        """Consume and return the next byte. Raises ValueError at end of input."""
        if self.at_eof:
            raise ValueError("Cannot scan past end of file")
        byte = self.__input[self.__index]
        self.__index += 1
        return byte

    def peek_one_byte(self) -> int:
        """Return the next byte without consuming it. Raises ValueError at end of input."""
        if self.at_eof:
            raise ValueError("Cannot peek past end of file")
        return self.__input[self.__index]

    def skip_one_byte(self) -> None:
        """Consume the next byte. Raises ValueError at end of input."""
        if self.at_eof:
            raise ValueError("Cannot skip past end of file")
        self.__index += 1

    def scan_while_any(self, scan_bytes: typing.Sequence[int]) -> bytes:
        """Consume and return the longest run of bytes that are in scan_bytes."""
        return self.__scan_while(lambda byte: byte in scan_bytes)

    def scan_until_any(self, stop_bytes: typing.Sequence[int]) -> bytes:
        """Consume and return bytes up to, not including, the first one in stop_bytes."""
        return self.__scan_while(lambda byte: byte not in stop_bytes)

    def skip_while_any(self, skip_bytes: typing.Sequence[int]) -> None:
        """Consume the longest run of bytes that are in skip_bytes."""
        self.__skip_while(lambda byte: byte in skip_bytes)

    def skip_until_any(self, stop_bytes: typing.Sequence[int]) -> None:
        """Consume bytes up to, not including, the first one in stop_bytes."""
        self.__skip_while(lambda byte: byte not in stop_bytes)

    def __scan_while(self, scan_predicate: typing.Callable[[int], bool]) -> bytes:
        # Scanning is skipping plus remembering where the run started.
        start = self.__index
        self.__skip_while(scan_predicate)
        return self.__input[start : self.__index]

    def __skip_while(self, skip_predicate: typing.Callable[[int], bool]) -> None:
        while not self.at_eof and skip_predicate(self.__input[self.__index]):
            self.__index += 1
class _EnvironmentFileParser(_Scanner):
    """Parser for the content of a systemd environment file.

    The parser is a _Scanner over the file content. parse_entries() is the
    entry point; the other parse_* methods each consume one syntactic element
    starting at the current scan position.
    """

    # Character classes, shared with SystemdEnvironmentFile.
    comment_characters = SystemdEnvironmentFile._comment_characters
    escape_characters = SystemdEnvironmentFile._escape_characters
    newline_characters = SystemdEnvironmentFile._newline_characters
    quote_characters = SystemdEnvironmentFile._quote_characters
    whitespace_characters = SystemdEnvironmentFile._whitespace_characters

    def parse_entries(self) -> typing.List[typing.Tuple[bytes, bytes]]:
        """Parse the rest of the input and return its entries in file order.

        Comments, blank lines, and malformed or invalid entries are left out.
        """
        entries = []
        while not self.at_eof:
            entry = self.parse_entry()
            if entry is not None:
                entries.append(entry)
        return entries

    def parse_entry(self) -> typing.Optional[typing.Tuple[bytes, bytes]]:
        """Consume one comment, blank line, or NAME=value entry.

        Returns (name, value) for an entry accepted by
        SystemdEnvironmentFile._is_valid_entry, otherwise None.
        """
        self.skip_whitespace()
        if self.at_eof:
            return None
        c = self.peek_one_byte()
        if c in self.comment_characters:
            self.parse_comment()
            return None
        elif c in self.newline_characters:
            # Blank line (or the newline left behind by a previous entry).
            self.skip_one_byte()
            return None
        name = self.parse_entry_name_and_equal_sign()
        if name is None:
            return None
        self.skip_whitespace()
        value = self.parse_entry_value()
        if not SystemdEnvironmentFile._is_valid_entry(name, value):
            return None
        return (name, value)

    def parse_entry_name_and_equal_sign(self) -> typing.Optional[bytes]:
        """Consume a variable name and the '=' following it.

        Returns the name with trailing whitespace removed, or None if the line
        or the input ends before an '=' is found.
        """
        # The first byte is taken unconditionally; parse_entry has already
        # established that it is not whitespace, a comment, or a newline.
        name = bytearray([self.scan_one_byte()])
        name.extend(self.scan_until_any(b"=" + self.newline_characters))
        if self.at_eof:
            return None
        c = self.scan_one_byte()
        if c in self.newline_characters:
            return None
        assert c == b"="[0]
        return bytes(name.rstrip(self.whitespace_characters))

    def parse_entry_value(self) -> bytes:
        """Consume and return a value: quoted parts first, then an unquoted rest."""
        value = bytearray()
        self.parse_quoted_entry_value(out_value=value)
        self.parse_unquoted_entry_value(out_value=value)
        return bytes(value)

    def parse_quoted_entry_value(self, out_value: bytearray) -> None:
        """Consume zero or more quoted strings, appending their content to out_value.

        Consecutive quoted strings, optionally separated by whitespace, are
        concatenated. Returns at the first byte that is not an opening quote.
        """
        while not self.at_eof:
            c = self.peek_one_byte()
            if c not in self.quote_characters:
                return
            # A string is closed only by the same kind of quote that opened it.
            terminating_quote_characters = bytes([c])
            self.skip_one_byte()
            while not self.at_eof:
                scanned = self.scan_until_any(
                    self.escape_characters + terminating_quote_characters
                )
                out_value.extend(scanned)
                if self.at_eof:
                    # Unterminated quote: keep what was read so far.
                    return
                c = self.scan_one_byte()
                if c in self.escape_characters:
                    if self.at_eof:
                        return
                    # The escaped byte is taken literally, except that an
                    # escaped newline is dropped.
                    c = self.scan_one_byte()
                    if c not in self.newline_characters:
                        out_value.append(c)
                elif c in terminating_quote_characters:
                    break
                else:
                    raise AssertionError()
            self.skip_whitespace()

    def parse_unquoted_entry_value(self, out_value: bytearray) -> None:
        """Consume the unquoted remainder of a value, appending it to out_value.

        Stops after an unescaped newline, before the newline that follows
        trailing whitespace, or at the end of input. Whitespace inside the
        value is kept; trailing whitespace is dropped.
        """
        while not self.at_eof:
            scanned = self.scan_until_any(
                self.escape_characters
                + self.newline_characters
                + self.whitespace_characters
            )
            out_value.extend(scanned)
            if self.at_eof:
                return
            c = self.scan_one_byte()
            if c in self.escape_characters:
                if self.at_eof:
                    return
                # The escaped byte is taken literally, except that an escaped
                # newline is dropped.
                c = self.scan_one_byte()
                if c not in self.newline_characters:
                    out_value.append(c)
            elif c in self.newline_characters:
                return
            elif c in self.whitespace_characters:
                # Whitespace is kept only if more value bytes follow on this
                # line.
                scanned = self.scan_while_any(self.whitespace_characters)
                is_trailing_whitespace = (
                    self.at_eof or self.peek_one_byte() in self.newline_characters
                )
                if is_trailing_whitespace:
                    return
                out_value.append(c)
                out_value.extend(scanned)
            else:
                raise AssertionError()

    def parse_comment(self) -> None:
        """Consume a comment through its terminating newline.

        A byte that follows an escape character is skipped without being
        examined, so an escaped newline does not end the comment.
        """
        c = self.scan_one_byte()
        assert c in self.comment_characters
        while not self.at_eof:
            self.skip_until_any(self.escape_characters + self.newline_characters)
            if self.at_eof:
                break
            c = self.scan_one_byte()
            if c in self.escape_characters:
                if self.at_eof:
                    break
                self.skip_one_byte()
            elif c in self.newline_characters:
                break
            else:
                raise AssertionError()

    def skip_whitespace(self) -> None:
        """Consume any run of whitespace_characters (newlines are not included)."""
        self.skip_while_any(self.whitespace_characters)
def _truncated_at_null_byte(data: bytes) -> bytes:
    """Return the part of data before its first NUL byte, or all of data if it has none."""
    before_null, _null, _after_null = data.partition(b"\x00")
    return before_null
async def print_service_status_using_systemctl_for_diagnostics_async(
    service_name: str, xdg_runtime_dir: str
) -> None:
    """Run 'systemctl --user status' for service_name and wait for it to exit.

    systemctl runs with the current environment plus XDG_RUNTIME_DIR set to
    xdg_runtime_dir. No output redirection is requested for the child process.
    """
    command = ["systemctl", "--no-pager", "--user", "status", "--", service_name]
    systemctl_environment = {**os.environ, "XDG_RUNTIME_DIR": xdg_runtime_dir}
    status_process = await asyncio.create_subprocess_exec(
        *command, env=systemctl_environment
    )
    # The exit code is deliberately discarded. We only run systemctl for
    # improved diagnostics.
    await status_process.wait()
# Types of parameters for D-Bus methods and signals. For details, see D-Bus'
# documentation:
# https://dbus.freedesktop.org/doc/dbus-specification.html#type-system
#
# In this module, D-Bus object paths and strings are represented as Python
# bytes and D-Bus uint32 values as Python int. The aliases exist so signatures
# can say which D-Bus type a value stands for.
DBusObjectPath = bytes
DBusString = bytes
DBusUint32 = int
class SystemdUserBus:
    """A communication channel with systemd.

    See systemd's D-Bus documentation:
    https://www.freedesktop.org/wiki/Software/systemd/dbus/

    Instances are async context managers: __aenter__ opens the connection and
    registers it with the running event loop; __aexit__ undoes both.
    """

    # Cleanup actions registered by __aenter__ and run by __aexit__.
    _cleanups: contextlib.ExitStack
    # Connection created by _get_dbus.
    _dbus: "pystemd.dbuslib.DBus"
    # pystemd Manager object bound to _dbus; used to call
    # org.freedesktop.systemd1.Manager methods.
    _manager: "pystemd.SDManager"

    def __init__(self, xdg_runtime_dir: str) -> None:
        """Prepare a connection to the systemd socket under xdg_runtime_dir.

        Raises the ModuleNotFoundError recorded at import time if pystemd is
        not available.
        """
        if pystemd_import_error is not None:
            raise pystemd_import_error
        super().__init__()
        self._cleanups = contextlib.ExitStack()
        self._dbus = self._get_dbus(xdg_runtime_dir)
        self._manager = pystemd.systemd1.manager.Manager(bus=self._dbus)

    @staticmethod
    def _get_dbus(xdg_runtime_dir: str) -> "pystemd.dbuslib.DBus":
        """Create a peer-to-peer D-Bus object for <xdg_runtime_dir>/systemd/private."""
        # HACK(strager): pystemd.dbuslib.DBus(user_mode=True) fails with a
        # connection timeout. 'SYSTEMCTL_FORCE_BUS=1 systemctl --user ...' also
        # fails, and it seems to use the same C APIs as
        # pystemd.dbuslib.DBus(user_mode=True). Work around the issue by doing
        # what systemctl's internal bus_connect_user_systemd() function does
        # [1].
        #
        # [1] https://github.com/systemd/systemd/blob/78a562ee4bcbc7b0e8b58b475ff656f646e95e40/src/shared/bus-util.c#L594
        socket_path = pathlib.Path(xdg_runtime_dir) / "systemd" / "private"
        return pystemd.dbuslib.DBusAddress(
            b"unix:path=" + escape_dbus_address(bytes(socket_path)), peer_to_peer=True
        )

    def _process_queued_messages(self) -> None:
        """Process pending D-Bus messages until an empty message is returned.

        Registered by __aenter__ as the event loop's reader callback for the
        D-Bus file descriptor.
        """
        while True:
            message = self._dbus.process()
            if message.is_empty():
                break

    async def get_unit_active_state_async(self, unit_name: bytes) -> DBusString:
        """Query org.freedesktop.systemd1.Unit.ActiveState.

        The pystemd calls run in the event loop's default executor.
        """

        def go() -> DBusString:
            unit = pystemd.systemd1.unit.Unit(unit_name, bus=self._dbus)
            unit.load()
            active_state = _pystemd_dynamic(unit).Unit.ActiveState
            assert isinstance(active_state, DBusString)
            return active_state

        return await self._run_in_executor_async(go)

    async def get_service_result_async(self, service_name: bytes) -> DBusString:
        """Query org.freedesktop.systemd1.Service.Result.

        The pystemd calls run in the event loop's default executor.
        """

        def go() -> DBusString:
            unit = pystemd.systemd1.unit.Unit(service_name, bus=self._dbus)
            unit.load()
            result = _pystemd_dynamic(unit).Service.Result
            assert isinstance(result, DBusString)
            return result

        return await self._run_in_executor_async(go)

    async def start_service_and_wait_async(self, service_name: DBusString) -> None:
        """Start a service, waiting for it to successfully start.

        If the service or the job fails, this method raises an exception
        (SystemdServiceFailedToStartError).
        """
        start_job = await self.start_unit_job_and_wait_until_job_completes_async(
            service_name
        )
        logger.debug(f"Querying status of service {service_name!r}...")
        # The two property queries are independent, so run them concurrently.
        (service_active_state, service_result) = await asyncio.gather(
            self.get_unit_active_state_async(unit_name=service_name),
            self.get_service_result_async(service_name=service_name),
        )
        logger.debug(
            f"Service {service_name!r} has active state "
            f"{service_active_state!r} and result {service_result!r}"
        )
        # All three indicators must report success.
        if not (
            start_job.result == b"done"
            and service_active_state == b"active"
            and service_result == b"success"
        ):
            raise SystemdServiceFailedToStartError(
                service_name=service_name.decode(errors="replace"),
                start_job_result=start_job.result.decode(errors="replace"),
                service_active_state=service_active_state.decode(errors="replace"),
                service_result=service_result.decode(errors="replace"),
            )

    async def start_unit_job_and_wait_until_job_completes_async(
        self, unit_name: DBusString
    ) -> "JobRemovedSignal":
        """Call org.freedesktop.systemd1.Manager.StartUnit and wait for the
        returned job to complete.

        If the job fails, this method does *not* raise an exception.

        Returns the JobRemoved signal whose job matches the started job.
        """
        # Subscribe before starting the unit so that a JobRemoved signal sent
        # right after StartUnit is queued rather than missed.
        with await self.subscribe_to_job_removed_async() as job_removed_subscription:
            logger.debug(f"Starting service {unit_name!r}...")
            job_object_path = await self.start_unit_async(
                name=unit_name, mode=b"replace"
            )
            logger.debug(f"Waiting for job {job_object_path!r} to finish...")
            # Signals for other jobs are consumed and discarded.
            removed_job = await job_removed_subscription.wait_until_signal_async(
                lambda removed_job: removed_job.job == job_object_path
            )
            logger.debug(
                f"Job {job_object_path!r} for {unit_name!r} finished "
                f"with result {removed_job.result!r}"
            )
            return removed_job

    async def start_unit_async(
        self, name: DBusString, mode: DBusString
    ) -> DBusObjectPath:
        """Call org.freedesktop.systemd1.Manager.StartUnit.

        Returns the object path that StartUnit returned.
        """

        def go() -> DBusObjectPath:
            path = _pystemd_dynamic(self._manager).Manager.StartUnit(name, mode)
            assert isinstance(path, DBusObjectPath)
            return path

        return await self._run_in_executor_async(go)

    async def subscribe_to_job_removed_async(
        self,
    ) -> "SystemdSignalSubscription[JobRemovedSignal]":
        """Subscribe to org.freedesktop.systemd1.Manager.JobRemoved.

        Returns a subscription to which _on_job_removed posts every received
        signal or error.
        """
        event_loop = asyncio.get_running_loop()
        subscription: SystemdSignalSubscription[
            JobRemovedSignal
        ] = SystemdSignalSubscription(self._manager)
        # Manager.Subscribe and the signal match registration are issued
        # concurrently, each in the default executor.
        await asyncio.gather(
            self._subscribe_async(),
            self._run_in_executor_async(
                lambda: self._dbus.match_signal(
                    sender=b"org.freedesktop.systemd1",
                    path=b"/org/freedesktop/systemd1",
                    interface=b"org.freedesktop.systemd1.Manager",
                    member=b"JobRemoved",
                    callback=self._on_job_removed,
                    userdata=(subscription, event_loop),
                )
            ),
        )
        return subscription

    @staticmethod
    def _on_job_removed(
        msg: "pystemd.dbuslib.DbusMessage",
        error: typing.Optional[Exception],
        userdata: typing.Any,
    ) -> None:
        """Handle a org.freedesktop.systemd1.Manager.JobRemoved signal.

        userdata is the (subscription, event_loop) pair given to match_signal.
        The signal, or any error, is posted to the subscription through a task
        created on event_loop.
        """
        # NOTE(review): event_loop.create_task is called directly, so this
        # presumably runs on the event loop's thread (via
        # _process_queued_messages) -- confirm.
        (subscription, event_loop) = userdata
        assert isinstance(subscription, DBusSignalSubscription)
        assert isinstance(event_loop, asyncio.AbstractEventLoop)
        if error is not None:
            event_loop.create_task(subscription.post_exception_async(error))
            return
        try:
            msg.process_reply(False)
            (id, job, unit, result) = msg.body
            event_loop.create_task(
                subscription.post_signal_async(
                    JobRemovedSignal(id=id, job=job, unit=unit, result=result)
                )
            )
        except Exception as e:
            # Forward failures to whoever awaits the subscription instead of
            # raising inside the callback.
            event_loop.create_task(subscription.post_exception_async(e))

    async def _subscribe_async(self) -> None:
        """Call org.freedesktop.systemd1.Manager.Subscribe."""

        def go() -> None:
            _pystemd_dynamic(self._manager).Manager.Subscribe()

        await self._run_in_executor_async(go)

    async def _run_in_executor_async(self, func: typing.Callable[[], "_T"]) -> "_T":
        """Run func in the running loop's default executor and return its result."""
        return await asyncio.get_running_loop().run_in_executor(
            executor=None, func=func
        )

    async def __aenter__(self) -> "SystemdUserBus":
        """Open the connection and start dispatching its messages on the running loop."""
        loop = asyncio.get_running_loop()
        self._cleanups.enter_context(self._dbus)
        self._manager.load()
        dbus_fd = self._dbus.get_fd()
        loop.add_reader(dbus_fd, self._process_queued_messages)
        # ExitStack runs cleanups in reverse order, so the reader is removed
        # before the _dbus context is exited.
        self._cleanups.callback(lambda: loop.remove_reader(dbus_fd))
        return self

    async def __aexit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> None:
        """Run the cleanups registered by __aenter__."""
        self._cleanups.close()
_DBusSignal = typing.TypeVar("_DBusSignal")


class DBusSignalSubscription(typing.Generic[_DBusSignal]):
    """A queue of received D-Bus signals, or errors, for an asyncio consumer.

    Producers post signals and exceptions; the consumer awaits them in the
    order they were posted. Used as a context manager, the subscription calls
    unsubscribe() on exit.
    """

    _queue: "asyncio.Queue[typing.Union[_DBusSignal, Exception]]"

    def __init__(self) -> None:
        super().__init__()
        self._queue = asyncio.Queue()

    async def post_signal_async(self, signal: _DBusSignal) -> None:
        """Enqueue a received signal."""
        await self._queue.put(signal)

    async def post_exception_async(self, exception: Exception) -> None:
        """Enqueue an exception to be raised from get_next_signal_async."""
        await self._queue.put(exception)

    async def get_next_signal_async(self) -> _DBusSignal:
        """Wait for the next queued item; return a signal or raise a queued exception."""
        queued_item = await self._queue.get()
        if not isinstance(queued_item, Exception):
            return queued_item
        raise queued_item

    async def wait_until_signal_async(
        self, predicate: typing.Callable[[_DBusSignal], bool]
    ) -> _DBusSignal:
        """Return the first signal accepted by predicate, discarding earlier ones."""
        candidate = await self.get_next_signal_async()
        while not predicate(candidate):
            candidate = await self.get_next_signal_async()
        return candidate

    def unsubscribe(self) -> None:
        """Log that the underlying D-Bus match is leaked; nothing is cancelled."""
        # TODO(strager): Add an API in pystemd to cancel a match_signal request.
        logger.debug("Leaking D-Bus signal subscription")

    def __enter__(self):
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exc_value: typing.Optional[BaseException],
        traceback: typing.Optional[types.TracebackType],
    ) -> None:
        self.unsubscribe()
class SystemdSignalSubscription(DBusSignalSubscription[_DBusSignal]):
    """A signal subscription that also calls Manager.Unsubscribe when it ends."""

    _manager: "pystemd.SDManager"

    def __init__(self, manager: "pystemd.SDManager") -> None:
        super().__init__()
        self._manager = manager

    def unsubscribe(self) -> None:
        """Run the base unsubscribe, then call org.freedesktop.systemd1.Manager.Unsubscribe."""
        super().unsubscribe()
        dynamic_manager = _pystemd_dynamic(self._manager)
        dynamic_manager.Manager.Unsubscribe()
class JobRemovedSignal(typing.NamedTuple):
    """A org.freedesktop.systemd1.Manager.JobRemoved signal.

    https://www.freedesktop.org/wiki/Software/systemd/dbus/#signals

    The fields are, in order, the four values of the signal's message body.
    """

    # Numeric job id.
    id: DBusUint32
    # Object path of the job; compared against the path returned by StartUnit.
    job: DBusObjectPath
    # presumably the name of the unit the job belonged to -- see the systemd
    # D-Bus documentation linked above.
    unit: DBusString
    # Job result; b"done" is the value treated as success in this module.
    result: DBusString
def _pystemd_dynamic(
    object: typing.Union[
        "pystemd.systemd1.manager.Manager", "pystemd.systemd1.unit.Unit"
    ]
) -> typing.Any:
    """Return the given pystemd object unchanged, typed as Any.

    This silences mypy and Pyre for pystemd's dynamically-typed attributes.

    TODO(strager): Add type annotations to pystemd.
    """
    untyped: typing.Any = object
    return untyped
def escape_dbus_address(input: bytes) -> bytes:
    """Escape a string for inclusion in DBUS_SESSION_BUS_ADDRESS.

    Bytes in the unescaped set are copied as-is; every other byte becomes '%'
    followed by its two-digit lowercase hexadecimal value.

    For more details, see the D-Bus specification:
    https://dbus.freedesktop.org/doc/dbus-specification.html#addresses
    """
    unescaped_set = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-./\\"
    return b"".join(
        bytes((byte,)) if byte in unescaped_set else f"%{byte:02x}".encode()
        for byte in input
    )
# Aliases for pystemd's D-Bus "connection refused" and "file not found"
# exceptions, so code can refer to them without importing pystemd itself.
#
# When pystemd could not be imported there is nothing to alias, and both names
# fall back to Exception.
# NOTE(review): with the fallback, an `except SystemdConnectionRefusedError:`
# clause catches every Exception -- confirm such handlers are only reached
# when pystemd is available.
if pystemd_import_error is None:
    SystemdConnectionRefusedError = (
        pystemd.dbusexc.DBusConnectionRefusedError  # pyre-ignore[16]: T32805591
    )
    SystemdFileNotFoundError = (
        pystemd.dbusexc.DBusFileNotFoundError  # pyre-ignore[16]: T32805591
    )
else:
    SystemdConnectionRefusedError = Exception
    SystemdFileNotFoundError = Exception
class SystemdServiceFailedToStartError(Exception):
    """Raised when a systemd service did not start successfully.

    Carries the start job's result and the service's name, active state and
    result as attributes. Only the name and result appear in the message.
    """

    def __init__(
        self,
        start_job_result: str,
        service_active_state: str,
        service_name: str,
        service_result: str,
    ) -> None:
        super().__init__()
        self.start_job_result = start_job_result
        self.service_name = service_name
        self.service_active_state = service_active_state
        self.service_result = service_result

    def __str__(self) -> str:
        service, reason = self.service_name, self.service_result
        return f"Starting the {service} systemd service failed (reason: {reason})"