linting fixes

Jörg Thalheim 2023-11-29 12:22:04 +01:00
parent c25b5a3086
commit 2d71dd8d94
5 changed files with 94 additions and 97 deletions

View File

@@ -1,5 +1,6 @@
 import argparse
 import asyncio
+import contextlib
 import json
 import logging
 import multiprocessing
@@ -13,10 +14,12 @@ from abc import ABC
 from asyncio import Queue, TaskGroup
 from asyncio.subprocess import Process
 from collections import defaultdict
+from collections.abc import AsyncIterator, Coroutine
 from contextlib import AsyncExitStack, asynccontextmanager
 from dataclasses import dataclass, field
 from tempfile import TemporaryDirectory
-from typing import IO, Any, AsyncIterator, Coroutine, DefaultDict, NoReturn, TypeVar
+from types import TracebackType
+from typing import IO, Any, NoReturn, TypeVar


 logger = logging.getLogger(__name__)
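
A note on the import churn above: the abstract types (AsyncIterator, Coroutine) now come from collections.abc, TracebackType from types, and the deprecated typing aliases (DefaultDict, etc.) give way to builtin generics. A minimal standalone sketch of the modern spellings (assuming Python 3.10+, which PEP 604 unions require):

from collections.abc import AsyncIterator  # ABCs live in collections.abc, not typing
from types import TracebackType

# Builtin generics replace typing.List / typing.DefaultDict (PEP 585):
names: list[str] = []
counts: dict[str, int] = {}

# PEP 604 unions replace Optional / Union:
tb: TracebackType | None = None
events: AsyncIterator[str] | None = None
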
@@ -35,13 +38,18 @@ class Pipe:
     def __enter__(self) -> "Pipe":
         return self

-    def __exit__(self, _exc_type: Any, _exc_value: Any, _traceback: Any) -> None:
+    def __exit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        traceback: TracebackType | None,
+    ) -> None:
         self.read_file.close()
         self.write_file.close()


 def nix_command(args: list[str]) -> list[str]:
-    return ["nix", "--experimental-features", "nix-command flakes"] + args
+    return ["nix", "--experimental-features", "nix-command flakes", *args]


 @dataclass
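
The widened __exit__ signature above is the standard fully-typed form for context managers. A self-contained toy example of the same shape (illustrative only, not project code):

from types import TracebackType


class Resource:
    def __enter__(self) -> "Resource":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Returning None (falsy) lets any exception propagate to the caller.
        print("released")


with Resource():
    pass
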
@@ -55,7 +63,7 @@ class Options:
     systems: set[str] = field(default_factory=set)
     eval_max_memory_size: int = 4096
     skip_cached: bool = False
-    eval_workers: int = multiprocessing.cpu_count()
+    eval_workers: int = field(default_factory=multiprocessing.cpu_count)
     max_jobs: int = 0
     retries: int = 0
     debug: bool = False
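
The eval_workers change swaps a function call evaluated once at class-definition time for a default_factory evaluated per instance; ruff flags calls in dataclass defaults (RUF009-style, assuming ruff is the linter here). A standalone illustration:

import multiprocessing
from dataclasses import dataclass, field


@dataclass
class Example:
    # multiprocessing.cpu_count() is now called when Example() is constructed,
    # not when the class body is executed at import time.
    workers: int = field(default_factory=multiprocessing.cpu_count)


print(Example().workers)
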
@@ -76,9 +84,8 @@ def _maybe_remote(
     cmd: list[str], remote: str | None, remote_ssh_options: list[str]
 ) -> list[str]:
     if remote:
-        return ["ssh", remote] + remote_ssh_options + ["--", shlex.join(cmd)]
-    else:
-        return cmd
+        return ["ssh", remote, *remote_ssh_options, "--", shlex.join(cmd)]
+    return cmd


 def maybe_remote(cmd: list[str], opts: Options) -> list[str]:
@@ -245,7 +252,7 @@ async def parse_args(args: list[str]) -> Options:
     else:
         a.no_nom = shutil.which("nom") is None
     if a.systems is None:
-        systems = set([nix_config.get("system", "")])
+        systems = {nix_config.get("system", "")}
     else:
         systems = set(a.systems.split(" "))
@@ -280,7 +287,7 @@ def nix_flake_metadata(flake_url: str) -> dict[str, Any]:
         ]
     )
     logger.info(f"run {shlex.join(cmd)}")
-    proc = subprocess.run(cmd, stdout=subprocess.PIPE)
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
     if proc.returncode != 0:
         die(
             f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
@@ -288,7 +295,7 @@ def nix_flake_metadata(flake_url: str) -> dict[str, Any]:
     try:
         data = json.loads(proc.stdout)
-    except Exception as e:
+    except (json.JSONDecodeError, OSError) as e:
         die(
             f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout.decode('utf-8', 'replace')}"
         )
@@ -303,10 +310,7 @@ def is_path_input(node: dict[str, dict[str, str]]) -> bool:


 def check_for_path_inputs(data: dict[str, Any]) -> bool:
-    for node in data["locks"]["nodes"].values():
-        if is_path_input(node):
-            return True
-    return False
+    return any(is_path_input(node) for node in data["locks"]["nodes"].values())


 def upload_sources(opts: Options) -> str:
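
check_for_path_inputs collapses the for/if/return-True/return-False loop into any(), which short-circuits exactly like the original early return. A self-contained version with a stand-in predicate (the predicate body is illustrative, not the project's logic):

def is_path_input(node: dict) -> bool:
    # Stand-in for the real check.
    return node.get("locked", {}).get("type") == "path"


def check_for_path_inputs(nodes: dict[str, dict]) -> bool:
    return any(is_path_input(node) for node in nodes.values())


print(check_for_path_inputs({"a": {"locked": {"type": "path"}}}))  # True
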
@@ -333,7 +337,7 @@ def upload_sources(opts: Options) -> str:
            path,
        ]
    )
-    proc = subprocess.run(cmd, stdout=subprocess.PIPE, env=env)
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, env=env, check=False)
     if proc.returncode != 0:
         die(
             f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
@@ -354,14 +358,14 @@ def upload_sources(opts: Options) -> str:
    )
    print("run " + shlex.join(cmd))
    logger.info("run %s", shlex.join(cmd))
-    proc = subprocess.run(cmd, stdout=subprocess.PIPE)
+    proc = subprocess.run(cmd, stdout=subprocess.PIPE, check=False)
    if proc.returncode != 0:
        die(
            f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
        )
    try:
        return json.loads(proc.stdout)["path"]
-    except Exception as e:
+    except (json.JSONDecodeError, OSError) as e:
        die(
            f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout.decode('utf-8', 'replace')}"
        )
@@ -375,9 +379,9 @@ def nix_shell(packages: list[str]) -> list[str]:
            "nix-command",
            "--extra-experimental-features",
            "flakes",
+            *packages,
+            "-c",
        ]
-        + packages
-        + ["-c"]
    )
@@ -389,24 +393,21 @@ async def ensure_stop(
         yield proc
     finally:
         if proc.returncode is not None:
-            return
-        proc.send_signal(signal_no)
+            with contextlib.suppress(ProcessLookupError):
+                proc.send_signal(signal_no)
         try:
             await asyncio.wait_for(proc.wait(), timeout=timeout)
         except asyncio.TimeoutError:
             print(f"Failed to stop process {shlex.join(cmd)}. Killing it.")
-            try:
-                proc.kill()
-            except ProcessLookupError:
-                pass
-            await proc.wait()
+            proc.kill()
+            await proc.wait()


 @asynccontextmanager
 async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
     assert opts.remote
-    ssh_cmd = ["ssh", opts.remote] + opts.remote_ssh_options + ["--"]
-    cmd = ssh_cmd + ["mktemp", "-d"]
+    ssh_cmd = ["ssh", opts.remote, *opts.remote_ssh_options, "--"]
+    cmd = [*ssh_cmd, "mktemp", "-d"]
     proc = await asyncio.create_subprocess_exec(*cmd, stdout=subprocess.PIPE)
     assert proc.stdout is not None
     line = await proc.stdout.readline()
@@ -419,7 +420,7 @@ async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
     try:
         yield tempdir
     finally:
-        cmd = ssh_cmd + ["rm", "-rf", tempdir]
+        cmd = [*ssh_cmd, "rm", "-rf", tempdir]
         logger.info("run %s", shlex.join(cmd))
         proc = await asyncio.create_subprocess_exec(*cmd)
         await proc.wait()
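
Several hunks in this file replace list concatenation with unpacking, e.g. ssh_cmd + ["mktemp", "-d"] becoming [*ssh_cmd, "mktemp", "-d"]. The two spellings build the same list; the latter is simply what the linter prefers (ruff's RUF005, assuming that rule is enabled):

ssh_cmd = ["ssh", "example.org", "--"]          # hypothetical host, for illustration

cmd_concat = ssh_cmd + ["mktemp", "-d"]          # old spelling
cmd_unpack = [*ssh_cmd, "mktemp", "-d"]          # new spelling
assert cmd_concat == cmd_unpack
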
@@ -443,7 +444,8 @@ async def nix_eval_jobs(stack: AsyncExitStack, opts: Options) -> AsyncIterator[Process]:
        str(opts.eval_workers),
        "--flake",
        f"{opts.flake_url}#{opts.flake_fragment}",
-    ] + opts.options
+        *opts.options,
+    ]
    if opts.skip_cached:
        args.append("--check-cache-status")
    if opts.remote:
@@ -462,7 +464,7 @@ async def nix_eval_jobs(stack: AsyncExitStack, opts: Options) -> AsyncIterator[Process]:
 @asynccontextmanager
 async def nix_output_monitor(pipe: Pipe, opts: Options) -> AsyncIterator[Process]:
-    cmd = maybe_remote(nix_shell(["nixpkgs#nix-output-monitor"]) + ["nom"], opts)
+    cmd = maybe_remote([*nix_shell(["nixpkgs#nix-output-monitor"]), "nom"], opts)
     proc = await asyncio.create_subprocess_exec(*cmd, stdin=pipe.read_file)
     try:
         yield proc
@@ -471,10 +473,8 @@ async def nix_output_monitor(pipe: Pipe, opts: Options) -> AsyncIterator[Process]:
         try:
             pipe.write_file.close()
             pipe.read_file.close()
-            try:
+            with contextlib.suppress(ProcessLookupError):
                 proc.kill()
-            except ProcessLookupError:
-                pass
             await proc.wait()
         finally:
             print("\033[?25h")
@@ -504,7 +504,7 @@ class Build:
     async def nix_copy(
         self, args: list[str], exit_stack: AsyncExitStack, opts: Options
     ) -> int:
-        cmd = maybe_remote(nix_command(["copy", "--log-format", "raw"] + args), opts)
+        cmd = maybe_remote(nix_command(["copy", "--log-format", "raw", *args]), opts)
         logger.debug("run %s", shlex.join(cmd))
         proc = await asyncio.create_subprocess_exec(*cmd)
         await exit_stack.enter_async_context(ensure_stop(proc, cmd))
@@ -514,8 +514,14 @@ class Build:
         if not opts.copy_to:
             return 0
         cmd = nix_command(
-            ["copy", "--log-format", "raw", "--to", opts.copy_to]
-            + list(self.outputs.values())
+            [
+                "copy",
+                "--log-format",
+                "raw",
+                "--to",
+                opts.copy_to,
+                *list(self.outputs.values()),
+            ]
         )
         cmd = maybe_remote(cmd, opts)
         logger.debug("run %s", shlex.join(cmd))
@@ -534,8 +540,8 @@ class Build:
                 "--no-check-sigs",
                 "--from",
                 opts.remote_url,
+                *list(self.outputs.values()),
             ]
-            + list(self.outputs.values())
         )
         logger.debug("run %s", shlex.join(cmd))
         env = os.environ.copy()
@@ -590,11 +596,7 @@ class QueueWithContext(Queue[T]):
 async def nix_build(
     attr: str, installable: str, stderr: IO[Any] | None, opts: Options
 ) -> AsyncIterator[Process]:
-    args = [
-        "nix-build",
-        installable,
-        "--keep-going",
-    ] + opts.options
+    args = ["nix-build", installable, "--keep-going", *opts.options]
     if opts.no_link:
         args += ["--no-link"]
     else:
@@ -609,10 +611,8 @@ async def nix_build(
     try:
         yield proc
     finally:
-        try:
+        with contextlib.suppress(ProcessLookupError):
             proc.kill()
-        except ProcessLookupError:
-            pass


 @dataclass
@@ -756,7 +756,7 @@ async def run(stack: AsyncExitStack, opts: Options) -> int:
     output_monitor: Process | None = None
     if output_monitor_future:
         output_monitor = await output_monitor_future
-    failures: DefaultDict[type, list[Failure]] = defaultdict(list)
+    failures: defaultdict[type, list[Failure]] = defaultdict(list)
     build_queue: QueueWithContext[Job | StopTask] = QueueWithContext()
     upload_queue: QueueWithContext[Build | StopTask] = QueueWithContext()
     download_queue: QueueWithContext[Build | StopTask] = QueueWithContext()

View File

@@ -1,25 +1,29 @@
+import contextlib
 import os
 import signal
 import subprocess
-from typing import IO, Any, Dict, Iterator, List, Union
+from collections.abc import Iterator
+from typing import IO, Any

 import pytest

-_FILE = Union[None, int, IO[Any]]
+_FILE = None | int | IO[Any]


 class Command:
     def __init__(self) -> None:
-        self.processes: List[subprocess.Popen[str]] = []
+        self.processes: list[subprocess.Popen[str]] = []

     def run(
         self,
-        command: List[str],
-        extra_env: Dict[str, str] = {},
+        command: list[str],
+        extra_env: dict[str, str] | None = None,
         stdin: _FILE = None,
         stdout: _FILE = None,
         stderr: _FILE = None,
     ) -> subprocess.Popen[str]:
+        if extra_env is None:
+            extra_env = {}
         env = os.environ.copy()
         env.update(extra_env)
         # We start a new session here so that we can than more reliably kill all childs as well
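
The extra_env change fixes the classic mutable-default pitfall (flake8-bugbear B006): a literal {} or [] default is created once and shared by every call. A minimal, standalone illustration of the failure mode and the None-sentinel fix used above:

def broken(item: str, seen: list[str] = []) -> list[str]:
    seen.append(item)              # the same list object leaks between calls
    return seen


def fixed(item: str, seen: list[str] | None = None) -> list[str]:
    if seen is None:
        seen = []                  # a fresh list per call
    seen.append(item)
    return seen


print(broken("a"), broken("b"))    # ['a', 'b'] ['a', 'b'] - surprising
print(fixed("a"), fixed("b"))      # ['a'] ['b']
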
@@ -40,13 +44,11 @@ class Command:
         # We just kill all processes as quickly as possible because we don't
         # care about corrupted state and want to make tests fasts.
         for p in reversed(self.processes):
-            try:
+            with contextlib.suppress(OSError):
                 os.killpg(os.getpgid(p.pid), signal.SIGKILL)
-            except OSError:
-                pass


-@pytest.fixture
+@pytest.fixture()
 def command() -> Iterator[Command]:
     """
     Starts a background command. The process is automatically terminated in the end.

View File

@@ -1,11 +1,7 @@
-#!/usr/bin/env python3
 import socket

 import pytest


-NEXT_PORT = 10000
-
-
 def check_port(port: int) -> bool:
     tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@@ -14,33 +10,33 @@ def check_port(port: int) -> bool:
     try:
         tcp.bind(("127.0.0.1", port))
         udp.bind(("127.0.0.1", port))
-        return True
     except OSError:
         return False
+    else:
+        return True


 def check_port_range(port_range: range) -> bool:
-    for port in port_range:
-        if not check_port(port):
-            return False
-    return True
+    return all(check_port(port) for port in port_range)


 class Ports:
+    NEXT_PORT = 10000
+
     def allocate(self, num: int) -> int:
         """
         Allocates
         """
-        global NEXT_PORT
-        while NEXT_PORT + num <= 65535:
-            start = NEXT_PORT
-            NEXT_PORT += num
-            if not check_port_range(range(start, NEXT_PORT)):
+        while Ports.NEXT_PORT + num <= 65535:
+            start = Ports.NEXT_PORT
+            Ports.NEXT_PORT += num
+            if not check_port_range(range(start, Ports.NEXT_PORT)):
                 continue
             return start
-        raise Exception("cannot find enough free port")
+        msg = "cannot find enough free port"
+        raise OSError(msg)


-@pytest.fixture
+@pytest.fixture()
 def ports() -> Ports:
     return Ports()
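
Two smaller patterns show up in this file: the success return moves out of the try body into an else: clause so that only the bind calls are guarded, and the raised exception gets a concrete type plus a message variable instead of a bare Exception with an inline string (TRY300/EM101-style rules; the exact rule ids enabled here are an assumption). A standalone sketch:

import socket


def port_is_free(port: int) -> bool:
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.bind(("127.0.0.1", port))
    except OSError:
        return False
    else:
        return True                    # success path lives in else, not in try
    finally:
        sock.close()


def pick_free_port(candidates: range) -> int:
    for port in candidates:
        if port_is_free(port):
            return port
    msg = "cannot find a free port"    # message in a variable, concrete exception type
    raise OSError(msg)


print(pick_free_port(range(20000, 20010)))
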

View File

@@ -2,10 +2,10 @@ import os
 import shutil
 import subprocess
 import time
+from collections.abc import Iterator
 from pathlib import Path
 from sys import platform
 from tempfile import TemporaryDirectory
-from typing import Iterator, Optional

 import pytest
 from command import Command
@@ -20,7 +20,7 @@ class Sshd:
 class SshdConfig:
-    def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None:
+    def __init__(self, path: str, key: str, preload_lib: str | None) -> None:
         self.path = path
         self.key = key
         self.preload_lib = preload_lib
@@ -30,8 +30,8 @@ class SshdConfig:
 def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
     # FIXME, if any parent of `project_root` is world-writable than sshd will refuse it.
     with TemporaryDirectory(dir=project_root) as _dir:
-        dir = Path(_dir)
-        host_key = dir / "host_ssh_host_ed25519_key"
+        directory = Path(_dir)
+        host_key = directory / "host_ssh_host_ed25519_key"
         subprocess.run(
             [
                 "ssh-keygen",
@@ -45,7 +45,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
             check=True,
         )

-        sshd_config = dir / "sshd_config"
+        sshd_config = directory / "sshd_config"
         sshd_config.write_text(
             f"""
 HostKey {host_key}
@@ -60,7 +60,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
         lib_path = None
         if platform == "linux":
             # This enforces a login shell by overriding the login shell of `getpwnam(3)`
-            lib_path = str(dir / "libgetpwnam-preload.so")
+            lib_path = str(directory / "libgetpwnam-preload.so")
             subprocess.run(
                 [
                     os.environ.get("CC", "cc"),
@@ -75,7 +75,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
         yield SshdConfig(str(sshd_config), str(host_key), lib_path)


-@pytest.fixture
+@pytest.fixture()
 def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
     port = ports.allocate(1)
     sshd = shutil.which("sshd")
@@ -104,7 +104,8 @@ def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
                     "-p",
                     str(port),
                     "true",
-                ]
+                ],
+                check=False,
             ).returncode
             == 0
         ):
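
The added check=False (and the check=True above) make the subprocess.run call sites explicit about error handling: linters flag run() calls that neither raise on failure nor obviously inspect the exit status (the specific rule id is an assumption). Either spelling is fine as long as the intent is stated:

import subprocess

subprocess.run(["true"], check=True)           # raise CalledProcessError on failure

proc = subprocess.run(["false"], check=False)  # we inspect the status ourselves
if proc.returncode != 0:
    print("command failed; handling it here")
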
@@ -113,5 +114,6 @@ def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
         else:
             rc = proc.poll()
             if rc is not None:
-                raise Exception(f"sshd processes was terminated with {rc}")
+                msg = f"sshd processes was terminated with {rc}"
+                raise OSError(msg)
             time.sleep(0.1)

View File

@@ -2,6 +2,7 @@ import asyncio
 import os
 import pwd

+import pytest
 from sshd import Sshd

 from nix_fast_build import async_main
@@ -12,29 +13,26 @@ def cli(args: list[str]) -> None:


 def test_help() -> None:
-    try:
+    with pytest.raises(SystemExit) as e:
         cli(["--help"])
-    except SystemExit as e:
-        assert e.code == 0
+    assert e.value.code == 0


 def test_build() -> None:
-    try:
+    with pytest.raises(SystemExit) as e:
         cli(["--option", "builders", ""])
-    except SystemExit as e:
-        assert e.code == 0
+    assert e.value.code == 1


 def test_eval_error() -> None:
-    try:
+    with pytest.raises(SystemExit) as e:
         cli(["--option", "builders", "", "--flake", ".#legacyPackages"])
-    except SystemExit as e:
-        assert e.code == 1
+    assert e.value.code == 1


 def test_remote(sshd: Sshd) -> None:
     login = pwd.getpwuid(os.getuid()).pw_name
-    try:
+    with pytest.raises(SystemExit) as e:
         cli(
             [
                 "--option",
@@ -56,5 +54,4 @@ def test_remote(sshd: Sshd) -> None:
                 "/dev/null",
             ]
         )
-    except SystemExit as e:
-        assert e.code == 0
+    assert e.value.code == 0