sapling/eden/fs/cli/systemd.py
John Reese 737849c9dd apply black 20.8b1 formatting update
Summary:
allow-large-files

black_any_style

Reviewed By: zertosh

Differential Revision: D24325133

fbshipit-source-id: b4afe80d1e8b2bc993f4b8e3822c02964df47462
2020-10-14 20:21:52 -07:00

764 lines
26 KiB
Python

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This software may be used and distributed according to the terms of the
# GNU General Public License version 2.
import asyncio
import contextlib
import logging
import os
import pathlib
import re
import subprocess
import types
import typing
pystemd_import_error = None
try:
import pystemd
import pystemd.dbusexc # pyre-ignore[21]: T32805591
import pystemd.dbuslib
import pystemd.systemd1.manager
import pystemd.systemd1.unit
except ModuleNotFoundError as e:
pystemd_import_error = e
logger = logging.getLogger(__name__)
_T = typing.TypeVar("_T")
def edenfs_systemd_service_name(eden_dir: pathlib.Path) -> str:
assert isinstance(eden_dir, pathlib.PosixPath)
instance_name = systemd_escape_path(eden_dir)
return f"fb-edenfs@{instance_name}.service"
def systemd_escape_path(path: pathlib.PurePosixPath) -> str:
"""Escape a path for inclusion in a systemd unit name.
See the 'systemd-escape --path' command for details.
"""
if not path.is_absolute():
raise ValueError("systemd_escape_path can only escape absolute paths")
if ".." in path.parts:
raise ValueError(
"systemd_escape_path can only escape paths without '..' components"
)
stdout: bytes = subprocess.check_output(
["systemd-escape", "--path", "--", str(path)]
)
return stdout.decode("utf-8").rstrip("\n")
class EdenFSSystemdServiceConfig:
__eden_dir: pathlib.Path
__edenfs_executable_path: pathlib.Path
__extra_edenfs_arguments: typing.List[str]
def __init__(
self,
eden_dir: pathlib.Path,
edenfs_executable_path: pathlib.Path,
extra_edenfs_arguments: typing.Sequence[str],
) -> None:
super().__init__()
self.__eden_dir = eden_dir
self.__edenfs_executable_path = edenfs_executable_path
self.__extra_edenfs_arguments = list(extra_edenfs_arguments)
@property
def config_file_path(self) -> pathlib.Path:
return self.__eden_dir / "systemd.conf"
@property
def startup_log_file_path(self) -> pathlib.Path:
# TODO(T33122320): Move this into <eden_dir>/logs/.
return self.__eden_dir / "startup.log"
def write_config_file(self) -> None:
variables = {
b"EDENFS_EXECUTABLE_PATH": bytes(self.__edenfs_executable_path),
b"EDENFS_EXTRA_ARGUMENTS": self.__escape_argument_list(
self.__extra_edenfs_arguments
),
}
self.config_file_path.parent.mkdir(parents=True, exist_ok=True)
self.config_file_path.write_bytes(SystemdEnvironmentFile.dumps(variables))
@staticmethod
def __escape_argument_list(arguments: typing.Sequence[str]) -> bytes:
for argument in arguments:
if "\n" in arguments:
raise ValueError(
f"Newlines in arguments are not supported\nArgument: {argument!r}"
)
return b"\n".join(arg.encode("utf-8") for arg in arguments)
class SystemdEnvironmentFile:
_comment_characters = b"#;"
_escape_characters = b"\\"
_name_characters = (
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_"
)
_newline_characters = b"\n\r"
_quote_characters = b"'\""
_whitespace_characters = b" \t"
def __init__(self, entries: typing.Sequence[typing.Tuple[bytes, bytes]]) -> None:
super().__init__()
self.__entries = list(entries)
@classmethod
def loads(cls, content: bytes) -> "SystemdEnvironmentFile":
content = _truncated_at_null_byte(content)
entries = _EnvironmentFileParser(content).parse_entries()
return cls(entries=entries)
@classmethod
def dumps(cls, variables: typing.Mapping[bytes, bytes]) -> bytes:
output = bytearray()
for name, value in variables.items():
cls._validate_entry(name, value)
output.extend(name)
output.extend(b"=")
output.extend(cls.__escape_value(value))
output.extend(b"\n")
return bytes(output)
@staticmethod
def __escape_value(value: bytes) -> bytes:
return (
b'"'
+ re.sub(b'[\\\\"]', lambda match: b"\\" + match.group(0), value)
+ b'"'
)
@classmethod
def _is_valid_entry(cls, name: bytes, value: bytes) -> bool:
try:
cls._validate_entry(name, value)
return True
except (VariableNameError, VariableValueError):
return False
@classmethod
def _validate_entry(cls, name: bytes, value: bytes) -> None:
if not name:
raise VariableNameError("Variables must have a non-empty name")
if name[0:1].isdigit():
raise VariableNameError("Variable names must not begin with a digit")
for c in name:
if c in cls._whitespace_characters:
raise VariableNameError("Variable names must not contain whitespace")
if c in cls._newline_characters:
raise VariableNameError(
"Variable names must not contain any newline characters"
)
if c < 0x20:
raise VariableNameError(
f"Variable names must not contain any control characters"
)
if c < 0x80 and c not in cls._name_characters:
offending_character = bytes([c]).decode("utf-8")
raise VariableNameError(
f"Variable names must not contain '{offending_character}'"
)
for c in value:
if c in b"\r":
raise VariableValueError(
"Variable values must not contain carriage returns"
)
if c < 0x20 and c not in b"\n\t":
raise VariableValueError(
"Variable values must not contain any control characters"
)
@property
def entries(self) -> typing.List[typing.Tuple[bytes, bytes]]:
return self.__entries
class VariableNameError(ValueError):
pass
class VariableValueError(ValueError):
pass
class _Scanner:
def __init__(self, input: bytes) -> None:
super().__init__()
self.__input = input
self.__index = 0
@property
def at_eof(self) -> bool:
return self.__index == len(self.__input)
def scan_one_byte(self) -> int:
if self.at_eof:
raise ValueError("Cannot scan past end of file")
c = self.__input[self.__index]
self.__index += 1
return c
def peek_one_byte(self) -> int:
if self.at_eof:
raise ValueError("Cannot peek past end of file")
return self.__input[self.__index]
def skip_one_byte(self) -> None:
if self.at_eof:
raise ValueError("Cannot skip past end of file")
self.__index += 1
def scan_while_any(self, scan_bytes: typing.Sequence[int]) -> bytes:
return self.__scan_while(lambda c: c in scan_bytes)
def scan_until_any(self, stop_bytes: typing.Sequence[int]) -> bytes:
return self.__scan_while(lambda c: c not in stop_bytes)
def skip_while_any(self, skip_bytes: typing.Sequence[int]) -> None:
self.__skip_while(lambda c: c in skip_bytes)
def skip_until_any(self, stop_bytes: typing.Sequence[int]) -> None:
self.__skip_while(lambda c: c not in stop_bytes)
def __scan_while(self, scan_predicate: typing.Callable[[int], bool]) -> bytes:
begin_index = self.__index
while not self.at_eof:
if not scan_predicate(self.__input[self.__index]):
break
self.__index += 1
end_index = self.__index
return self.__input[begin_index:end_index]
def __skip_while(self, skip_predicate: typing.Callable[[int], bool]) -> None:
while not self.at_eof:
if not skip_predicate(self.__input[self.__index]):
break
self.__index += 1
class _EnvironmentFileParser(_Scanner):
comment_characters = SystemdEnvironmentFile._comment_characters
escape_characters = SystemdEnvironmentFile._escape_characters
newline_characters = SystemdEnvironmentFile._newline_characters
quote_characters = SystemdEnvironmentFile._quote_characters
whitespace_characters = SystemdEnvironmentFile._whitespace_characters
def parse_entries(self) -> typing.List[typing.Tuple[bytes, bytes]]:
entries = []
while not self.at_eof:
entry = self.parse_entry()
if entry is not None:
entries.append(entry)
return entries
def parse_entry(self) -> typing.Optional[typing.Tuple[bytes, bytes]]:
self.skip_whitespace()
if self.at_eof:
return None
c = self.peek_one_byte()
if c in self.comment_characters:
self.parse_comment()
return None
elif c in self.newline_characters:
self.skip_one_byte()
return None
name = self.parse_entry_name_and_equal_sign()
if name is None:
return None
self.skip_whitespace()
value = self.parse_entry_value()
if not SystemdEnvironmentFile._is_valid_entry(name, value):
return None
return (name, value)
def parse_entry_name_and_equal_sign(self) -> typing.Optional[bytes]:
name = bytearray([self.scan_one_byte()])
name.extend(self.scan_until_any(b"=" + self.newline_characters))
if self.at_eof:
return None
c = self.scan_one_byte()
if c in self.newline_characters:
return None
assert c == b"="[0]
return bytes(name.rstrip(self.whitespace_characters))
def parse_entry_value(self) -> bytes:
value = bytearray()
self.parse_quoted_entry_value(out_value=value)
self.parse_unquoted_entry_value(out_value=value)
return bytes(value)
def parse_quoted_entry_value(self, out_value: bytearray) -> None:
while not self.at_eof:
c = self.peek_one_byte()
if c not in self.quote_characters:
return
terminating_quote_characters = bytes([c])
self.skip_one_byte()
while not self.at_eof:
scanned = self.scan_until_any(
self.escape_characters + terminating_quote_characters
)
out_value.extend(scanned)
if self.at_eof:
return
c = self.scan_one_byte()
if c in self.escape_characters:
if self.at_eof:
return
c = self.scan_one_byte()
if c not in self.newline_characters:
out_value.append(c)
elif c in terminating_quote_characters:
break
else:
raise AssertionError()
self.skip_whitespace()
def parse_unquoted_entry_value(self, out_value: bytearray) -> None:
while not self.at_eof:
scanned = self.scan_until_any(
self.escape_characters
+ self.newline_characters
+ self.whitespace_characters
)
out_value.extend(scanned)
if self.at_eof:
return
c = self.scan_one_byte()
if c in self.escape_characters:
if self.at_eof:
return
c = self.scan_one_byte()
if c not in self.newline_characters:
out_value.append(c)
elif c in self.newline_characters:
return
elif c in self.whitespace_characters:
scanned = self.scan_while_any(self.whitespace_characters)
is_trailing_whitespace = (
self.at_eof or self.peek_one_byte() in self.newline_characters
)
if is_trailing_whitespace:
return
out_value.append(c)
out_value.extend(scanned)
else:
raise AssertionError()
def parse_comment(self) -> None:
c = self.scan_one_byte()
assert c in self.comment_characters
while not self.at_eof:
self.skip_until_any(self.escape_characters + self.newline_characters)
if self.at_eof:
break
c = self.scan_one_byte()
if c in self.escape_characters:
if self.at_eof:
break
self.skip_one_byte()
elif c in self.newline_characters:
break
else:
raise AssertionError()
def skip_whitespace(self) -> None:
self.skip_while_any(self.whitespace_characters)
def _truncated_at_null_byte(data: bytes) -> bytes:
end_of_file_index = data.find(b"\x00")
if end_of_file_index == -1:
return data
return data[:end_of_file_index]
async def print_service_status_using_systemctl_for_diagnostics_async(
service_name: str, xdg_runtime_dir: str
) -> None:
systemctl_environment = dict(os.environ)
systemctl_environment["XDG_RUNTIME_DIR"] = xdg_runtime_dir
status_process = await asyncio.create_subprocess_exec(
"systemctl",
"--no-pager",
"--user",
"status",
"--",
service_name,
env=systemctl_environment,
)
_status_exit_code = await status_process.wait()
# Ignore the status exit code. We only run systemctl for improved
# diagnostics.
_status_exit_code
# Types of parameters for D-Bus methods and signals. For details, see D-Bus'
# documentation:
# https://dbus.freedesktop.org/doc/dbus-specification.html#type-system
DBusObjectPath = bytes
DBusString = bytes
DBusUint32 = int
class SystemdUserBus:
"""A communication channel with systemd.
See systemd's D-Bus documentation:
https://www.freedesktop.org/wiki/Software/systemd/dbus/
"""
_cleanups: contextlib.ExitStack
_dbus: "pystemd.dbuslib.DBus"
_manager: "pystemd.SDManager"
def __init__(self, xdg_runtime_dir: str) -> None:
if pystemd_import_error is not None:
raise pystemd_import_error
super().__init__()
self._cleanups = contextlib.ExitStack()
self._dbus = self._get_dbus(xdg_runtime_dir)
self._manager = pystemd.systemd1.manager.Manager(bus=self._dbus)
@staticmethod
def _get_dbus(xdg_runtime_dir: str) -> "pystemd.dbuslib.DBus":
# HACK(strager): pystemd.dbuslib.DBus(user_mode=True) fails with a
# connection timeout. 'SYSTEMCTL_FORCE_BUS=1 systemctl --user ...' also
# fails, and it seems to use the same C APIs as
# pystemd.dbuslib.DBus(user_mode=True). Work around the issue by doing
# what systemctl's internal bus_connect_user_systemd() function does
# [1].
#
# [1] https://github.com/systemd/systemd/blob/78a562ee4bcbc7b0e8b58b475ff656f646e95e40/src/shared/bus-util.c#L594
socket_path = pathlib.Path(xdg_runtime_dir) / "systemd" / "private"
return pystemd.dbuslib.DBusAddress(
b"unix:path=" + escape_dbus_address(bytes(socket_path)), peer_to_peer=True
)
def _process_queued_messages(self) -> None:
while True:
message = self._dbus.process()
if message.is_empty():
break
async def get_unit_active_state_async(self, unit_name: bytes) -> DBusString:
"""Query org.freedesktop.systemd1.Unit.ActiveState."""
def go() -> DBusString:
unit = pystemd.systemd1.unit.Unit(unit_name, bus=self._dbus)
unit.load()
active_state = _pystemd_dynamic(unit).Unit.ActiveState
assert isinstance(active_state, DBusString)
return active_state
return await self._run_in_executor_async(go)
async def get_service_result_async(self, service_name: bytes) -> DBusString:
"""Query org.freedesktop.systemd1.Service.Result."""
def go() -> DBusString:
unit = pystemd.systemd1.unit.Unit(service_name, bus=self._dbus)
unit.load()
result = _pystemd_dynamic(unit).Service.Result
assert isinstance(result, DBusString)
return result
return await self._run_in_executor_async(go)
async def start_service_and_wait_async(self, service_name: DBusString) -> None:
"""Start a service, waiting for it to successfully start.
If the service or the job fails, this method raises an exception.
"""
start_job = await self.start_unit_job_and_wait_until_job_completes_async(
service_name
)
logger.debug(f"Querying status of service {service_name!r}...")
(service_active_state, service_result) = await asyncio.gather(
self.get_unit_active_state_async(unit_name=service_name),
self.get_service_result_async(service_name=service_name),
)
logger.debug(
f"Service {service_name!r} has active state "
f"{service_active_state!r} and result {service_result!r}"
)
if not (
start_job.result == b"done"
and service_active_state == b"active"
and service_result == b"success"
):
raise SystemdServiceFailedToStartError(
service_name=service_name.decode(errors="replace"),
start_job_result=start_job.result.decode(errors="replace"),
service_active_state=service_active_state.decode(errors="replace"),
service_result=service_result.decode(errors="replace"),
)
async def start_unit_job_and_wait_until_job_completes_async(
self, unit_name: DBusString
) -> "JobRemovedSignal":
"""Call org.freedesktop.systemd1.Manager.StartUnit and wait for the
returned job to complete.
If the job fails, this method does *not* raise an exception.
"""
with await self.subscribe_to_job_removed_async() as job_removed_subscription:
logger.debug(f"Starting service {unit_name!r}...")
job_object_path = await self.start_unit_async(
name=unit_name, mode=b"replace"
)
logger.debug(f"Waiting for job {job_object_path!r} to finish...")
removed_job = await job_removed_subscription.wait_until_signal_async(
lambda removed_job: removed_job.job == job_object_path
)
logger.debug(
f"Job {job_object_path!r} for {unit_name!r} finished "
f"with result {removed_job.result!r}"
)
return removed_job
async def start_unit_async(
self, name: DBusString, mode: DBusString
) -> DBusObjectPath:
"""Call org.freedesktop.systemd1.Manager.StartUnit."""
def go() -> DBusObjectPath:
path = _pystemd_dynamic(self._manager).Manager.StartUnit(name, mode)
assert isinstance(path, DBusObjectPath)
return path
return await self._run_in_executor_async(go)
async def subscribe_to_job_removed_async(
self,
) -> "SystemdSignalSubscription[JobRemovedSignal]":
"""Subscribe to org.freedesktop.systemd1.Manager.JobRemoved."""
event_loop = asyncio.get_running_loop()
subscription: SystemdSignalSubscription[
JobRemovedSignal
] = SystemdSignalSubscription(self._manager)
await asyncio.gather(
self._subscribe_async(),
self._run_in_executor_async(
lambda: self._dbus.match_signal(
sender=b"org.freedesktop.systemd1",
path=b"/org/freedesktop/systemd1",
interface=b"org.freedesktop.systemd1.Manager",
member=b"JobRemoved",
callback=self._on_job_removed,
userdata=(subscription, event_loop),
)
),
)
return subscription
@staticmethod
def _on_job_removed(
msg: "pystemd.dbuslib.DbusMessage",
error: typing.Optional[Exception],
userdata: typing.Any,
) -> None:
"""Handle a org.freedesktop.systemd1.Manager.JobRemoved signal."""
(subscription, event_loop) = userdata
assert isinstance(subscription, DBusSignalSubscription)
assert isinstance(event_loop, asyncio.AbstractEventLoop)
if error is not None:
event_loop.create_task(subscription.post_exception_async(error))
return
try:
msg.process_reply(False)
(id, job, unit, result) = msg.body
event_loop.create_task(
subscription.post_signal_async(
JobRemovedSignal(id=id, job=job, unit=unit, result=result)
)
)
except Exception as e:
event_loop.create_task(subscription.post_exception_async(e))
async def _subscribe_async(self) -> None:
"""Call org.freedesktop.systemd1.Manager.Subscribe."""
def go() -> None:
_pystemd_dynamic(self._manager).Manager.Subscribe()
await self._run_in_executor_async(go)
async def _run_in_executor_async(self, func: typing.Callable[[], "_T"]) -> "_T":
return await asyncio.get_running_loop().run_in_executor(
executor=None, func=func
)
async def __aenter__(self):
loop = asyncio.get_running_loop()
self._cleanups.enter_context(self._dbus)
self._manager.load()
dbus_fd = self._dbus.get_fd()
loop.add_reader(dbus_fd, self._process_queued_messages)
self._cleanups.callback(lambda: loop.remove_reader(dbus_fd))
return self
async def __aexit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> None:
self._cleanups.close()
_DBusSignal = typing.TypeVar("_DBusSignal")
class DBusSignalSubscription(typing.Generic[_DBusSignal]):
_queue: "asyncio.Queue[typing.Union[_DBusSignal, Exception]]"
def __init__(self) -> None:
super().__init__()
self._queue = asyncio.Queue()
async def post_signal_async(self, signal: _DBusSignal) -> None:
await self._queue.put(signal)
async def post_exception_async(self, exception: Exception) -> None:
await self._queue.put(exception)
async def get_next_signal_async(self) -> _DBusSignal:
signal_or_exception = await self._queue.get()
if isinstance(signal_or_exception, Exception):
raise signal_or_exception
return signal_or_exception
async def wait_until_signal_async(
self, predicate: typing.Callable[[_DBusSignal], bool]
) -> _DBusSignal:
while True:
signal = await self.get_next_signal_async()
if predicate(signal):
return signal
def unsubscribe(self) -> None:
# TODO(strager): Add an API in pystemd to cancel a match_signal request.
logger.debug("Leaking D-Bus signal subscription")
def __enter__(self):
return self
def __exit__(
self,
exc_type: typing.Optional[typing.Type[BaseException]],
exc_value: typing.Optional[BaseException],
traceback: typing.Optional[types.TracebackType],
) -> None:
self.unsubscribe()
class SystemdSignalSubscription(DBusSignalSubscription[_DBusSignal]):
_manager: "pystemd.SDManager"
def __init__(self, manager: "pystemd.SDManager") -> None:
super().__init__()
self._manager = manager
def unsubscribe(self) -> None:
super().unsubscribe()
_pystemd_dynamic(self._manager).Manager.Unsubscribe()
class JobRemovedSignal(typing.NamedTuple):
"""A org.freedesktop.systemd1.Manager.JobRemoved signal.
https://www.freedesktop.org/wiki/Software/systemd/dbus/#signals
"""
id: DBusUint32
job: DBusObjectPath
unit: DBusString
result: DBusString
def _pystemd_dynamic(
object: typing.Union[
"pystemd.systemd1.manager.Manager", "pystemd.systemd1.unit.Unit"
]
) -> typing.Any:
"""Silence mypy and Pyre for the given dynamically-typed pystemd object.
TODO(strager): Add type annotations to pystemd.
"""
return typing.cast(typing.Any, object)
def escape_dbus_address(input: bytes) -> bytes:
"""Escape a string for inclusion in DBUS_SESSION_BUS_ADDRESS.
For more details, see the D-Bus specification:
https://dbus.freedesktop.org/doc/dbus-specification.html#addresses
"""
whitelist = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-./\\"
scanner = _Scanner(input)
result_pieces = []
while not scanner.at_eof:
unescaped_bytes = scanner.scan_while_any(whitelist)
result_pieces.append(unescaped_bytes)
if scanner.at_eof:
break
byte_to_escape = scanner.scan_one_byte()
result_pieces.append(f"%{byte_to_escape:02x}".encode())
return b"".join(result_pieces)
if pystemd_import_error is None:
SystemdConnectionRefusedError = (
pystemd.dbusexc.DBusConnectionRefusedError # pyre-ignore[16]: T32805591
)
SystemdFileNotFoundError = (
pystemd.dbusexc.DBusFileNotFoundError # pyre-ignore[16]: T32805591
)
else:
SystemdConnectionRefusedError = Exception
SystemdFileNotFoundError = Exception
class SystemdServiceFailedToStartError(Exception):
def __init__(
self,
start_job_result: str,
service_active_state: str,
service_name: str,
service_result: str,
) -> None:
super().__init__()
self.service_active_state = service_active_state
self.service_name = service_name
self.service_result = service_result
self.start_job_result = start_job_result
def __str__(self) -> str:
return (
f"Starting the {self.service_name} systemd service failed "
f"(reason: {self.service_result})"
)