linting fixes

This commit is contained in:
Jörg Thalheim 2023-11-29 12:22:04 +01:00
parent c25b5a3086
commit 2d71dd8d94
5 changed files with 94 additions and 97 deletions

View File

@ -1,5 +1,6 @@
import argparse
import asyncio
import contextlib
import json
import logging
import multiprocessing
@ -13,10 +14,12 @@ from abc import ABC
from asyncio import Queue, TaskGroup
from asyncio.subprocess import Process
from collections import defaultdict
from collections.abc import AsyncIterator, Coroutine
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass, field
from tempfile import TemporaryDirectory
from typing import IO, Any, AsyncIterator, Coroutine, DefaultDict, NoReturn, TypeVar
from types import TracebackType
from typing import IO, Any, NoReturn, TypeVar
logger = logging.getLogger(__name__)
@ -35,13 +38,18 @@ class Pipe:
def __enter__(self) -> "Pipe":
return self
def __exit__(self, _exc_type: Any, _exc_value: Any, _traceback: Any) -> None:
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
self.read_file.close()
self.write_file.close()
def nix_command(args: list[str]) -> list[str]:
return ["nix", "--experimental-features", "nix-command flakes"] + args
return ["nix", "--experimental-features", "nix-command flakes", *args]
@dataclass
@ -55,7 +63,7 @@ class Options:
systems: set[str] = field(default_factory=set)
eval_max_memory_size: int = 4096
skip_cached: bool = False
eval_workers: int = multiprocessing.cpu_count()
eval_workers: int = field(default_factory=multiprocessing.cpu_count)
max_jobs: int = 0
retries: int = 0
debug: bool = False
@ -76,9 +84,8 @@ def _maybe_remote(
cmd: list[str], remote: str | None, remote_ssh_options: list[str]
) -> list[str]:
if remote:
return ["ssh", remote] + remote_ssh_options + ["--", shlex.join(cmd)]
else:
return cmd
return ["ssh", remote, *remote_ssh_options, "--", shlex.join(cmd)]
return cmd
def maybe_remote(cmd: list[str], opts: Options) -> list[str]:
@ -245,7 +252,7 @@ async def parse_args(args: list[str]) -> Options:
else:
a.no_nom = shutil.which("nom") is None
if a.systems is None:
systems = set([nix_config.get("system", "")])
systems = {nix_config.get("system", "")}
else:
systems = set(a.systems.split(" "))
@ -280,7 +287,7 @@ def nix_flake_metadata(flake_url: str) -> dict[str, Any]:
]
)
logger.info(f"run {shlex.join(cmd)}")
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
proc = subprocess.run(cmd, stdout=subprocess.PIPE, check=True)
if proc.returncode != 0:
die(
f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
@ -288,7 +295,7 @@ def nix_flake_metadata(flake_url: str) -> dict[str, Any]:
try:
data = json.loads(proc.stdout)
except Exception as e:
except (json.JSONDecodeError, OSError) as e:
die(
f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout.decode('utf-8', 'replace')}"
)
@ -303,10 +310,7 @@ def is_path_input(node: dict[str, dict[str, str]]) -> bool:
def check_for_path_inputs(data: dict[str, Any]) -> bool:
for node in data["locks"]["nodes"].values():
if is_path_input(node):
return True
return False
return any(is_path_input(node) for node in data["locks"]["nodes"].values())
def upload_sources(opts: Options) -> str:
@ -333,7 +337,7 @@ def upload_sources(opts: Options) -> str:
path,
]
)
proc = subprocess.run(cmd, stdout=subprocess.PIPE, env=env)
proc = subprocess.run(cmd, stdout=subprocess.PIPE, env=env, check=False)
if proc.returncode != 0:
die(
f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
@ -354,14 +358,14 @@ def upload_sources(opts: Options) -> str:
)
print("run " + shlex.join(cmd))
logger.info("run %s", shlex.join(cmd))
proc = subprocess.run(cmd, stdout=subprocess.PIPE)
proc = subprocess.run(cmd, stdout=subprocess.PIPE, check=False)
if proc.returncode != 0:
die(
f"failed to upload sources: {shlex.join(cmd)} failed with {proc.returncode}"
)
try:
return json.loads(proc.stdout)["path"]
except Exception as e:
except (json.JSONDecodeError, OSError) as e:
die(
f"failed to parse output of {shlex.join(cmd)}: {e}\nGot: {proc.stdout.decode('utf-8', 'replace')}"
)
@ -375,9 +379,9 @@ def nix_shell(packages: list[str]) -> list[str]:
"nix-command",
"--extra-experimental-features",
"flakes",
*packages,
"-c",
]
+ packages
+ ["-c"]
)
@ -389,24 +393,21 @@ async def ensure_stop(
yield proc
finally:
if proc.returncode is not None:
return
proc.send_signal(signal_no)
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except asyncio.TimeoutError:
print(f"Failed to stop process {shlex.join(cmd)}. Killing it.")
try:
proc.kill()
except ProcessLookupError:
pass
await proc.wait()
with contextlib.suppress(ProcessLookupError):
proc.send_signal(signal_no)
try:
await asyncio.wait_for(proc.wait(), timeout=timeout)
except asyncio.TimeoutError:
print(f"Failed to stop process {shlex.join(cmd)}. Killing it.")
proc.kill()
await proc.wait()
@asynccontextmanager
async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
assert opts.remote
ssh_cmd = ["ssh", opts.remote] + opts.remote_ssh_options + ["--"]
cmd = ssh_cmd + ["mktemp", "-d"]
ssh_cmd = ["ssh", opts.remote, *opts.remote_ssh_options, "--"]
cmd = [*ssh_cmd, "mktemp", "-d"]
proc = await asyncio.create_subprocess_exec(*cmd, stdout=subprocess.PIPE)
assert proc.stdout is not None
line = await proc.stdout.readline()
@ -419,7 +420,7 @@ async def remote_temp_dir(opts: Options) -> AsyncIterator[str]:
try:
yield tempdir
finally:
cmd = ssh_cmd + ["rm", "-rf", tempdir]
cmd = [*ssh_cmd, "rm", "-rf", tempdir]
logger.info("run %s", shlex.join(cmd))
proc = await asyncio.create_subprocess_exec(*cmd)
await proc.wait()
@ -443,7 +444,8 @@ async def nix_eval_jobs(stack: AsyncExitStack, opts: Options) -> AsyncIterator[P
str(opts.eval_workers),
"--flake",
f"{opts.flake_url}#{opts.flake_fragment}",
] + opts.options
*opts.options,
]
if opts.skip_cached:
args.append("--check-cache-status")
if opts.remote:
@ -462,7 +464,7 @@ async def nix_eval_jobs(stack: AsyncExitStack, opts: Options) -> AsyncIterator[P
@asynccontextmanager
async def nix_output_monitor(pipe: Pipe, opts: Options) -> AsyncIterator[Process]:
cmd = maybe_remote(nix_shell(["nixpkgs#nix-output-monitor"]) + ["nom"], opts)
cmd = maybe_remote([*nix_shell(["nixpkgs#nix-output-monitor"]), "nom"], opts)
proc = await asyncio.create_subprocess_exec(*cmd, stdin=pipe.read_file)
try:
yield proc
@ -471,10 +473,8 @@ async def nix_output_monitor(pipe: Pipe, opts: Options) -> AsyncIterator[Process
try:
pipe.write_file.close()
pipe.read_file.close()
try:
with contextlib.suppress(ProcessLookupError):
proc.kill()
except ProcessLookupError:
pass
await proc.wait()
finally:
print("\033[?25h")
@ -504,7 +504,7 @@ class Build:
async def nix_copy(
self, args: list[str], exit_stack: AsyncExitStack, opts: Options
) -> int:
cmd = maybe_remote(nix_command(["copy", "--log-format", "raw"] + args), opts)
cmd = maybe_remote(nix_command(["copy", "--log-format", "raw", *args]), opts)
logger.debug("run %s", shlex.join(cmd))
proc = await asyncio.create_subprocess_exec(*cmd)
await exit_stack.enter_async_context(ensure_stop(proc, cmd))
@ -514,8 +514,14 @@ class Build:
if not opts.copy_to:
return 0
cmd = nix_command(
["copy", "--log-format", "raw", "--to", opts.copy_to]
+ list(self.outputs.values())
[
"copy",
"--log-format",
"raw",
"--to",
opts.copy_to,
*list(self.outputs.values()),
]
)
cmd = maybe_remote(cmd, opts)
logger.debug("run %s", shlex.join(cmd))
@ -534,8 +540,8 @@ class Build:
"--no-check-sigs",
"--from",
opts.remote_url,
*list(self.outputs.values()),
]
+ list(self.outputs.values())
)
logger.debug("run %s", shlex.join(cmd))
env = os.environ.copy()
@ -590,11 +596,7 @@ class QueueWithContext(Queue[T]):
async def nix_build(
attr: str, installable: str, stderr: IO[Any] | None, opts: Options
) -> AsyncIterator[Process]:
args = [
"nix-build",
installable,
"--keep-going",
] + opts.options
args = ["nix-build", installable, "--keep-going", *opts.options]
if opts.no_link:
args += ["--no-link"]
else:
@ -609,10 +611,8 @@ async def nix_build(
try:
yield proc
finally:
try:
with contextlib.suppress(ProcessLookupError):
proc.kill()
except ProcessLookupError:
pass
@dataclass
@ -756,7 +756,7 @@ async def run(stack: AsyncExitStack, opts: Options) -> int:
output_monitor: Process | None = None
if output_monitor_future:
output_monitor = await output_monitor_future
failures: DefaultDict[type, list[Failure]] = defaultdict(list)
failures: defaultdict[type, list[Failure]] = defaultdict(list)
build_queue: QueueWithContext[Job | StopTask] = QueueWithContext()
upload_queue: QueueWithContext[Build | StopTask] = QueueWithContext()
download_queue: QueueWithContext[Build | StopTask] = QueueWithContext()

View File

@ -1,25 +1,29 @@
import contextlib
import os
import signal
import subprocess
from typing import IO, Any, Dict, Iterator, List, Union
from collections.abc import Iterator
from typing import IO, Any
import pytest
_FILE = Union[None, int, IO[Any]]
_FILE = None | int | IO[Any]
class Command:
def __init__(self) -> None:
self.processes: List[subprocess.Popen[str]] = []
self.processes: list[subprocess.Popen[str]] = []
def run(
self,
command: List[str],
extra_env: Dict[str, str] = {},
command: list[str],
extra_env: dict[str, str] | None = None,
stdin: _FILE = None,
stdout: _FILE = None,
stderr: _FILE = None,
) -> subprocess.Popen[str]:
if extra_env is None:
extra_env = {}
env = os.environ.copy()
env.update(extra_env)
# We start a new session here so that we can than more reliably kill all childs as well
@ -40,13 +44,11 @@ class Command:
# We just kill all processes as quickly as possible because we don't
# care about corrupted state and want to make tests fasts.
for p in reversed(self.processes):
try:
with contextlib.suppress(OSError):
os.killpg(os.getpgid(p.pid), signal.SIGKILL)
except OSError:
pass
@pytest.fixture
@pytest.fixture()
def command() -> Iterator[Command]:
"""
Starts a background command. The process is automatically terminated in the end.

View File

@ -1,11 +1,7 @@
#!/usr/bin/env python3
import socket
import pytest
NEXT_PORT = 10000
def check_port(port: int) -> bool:
tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
@ -14,33 +10,33 @@ def check_port(port: int) -> bool:
try:
tcp.bind(("127.0.0.1", port))
udp.bind(("127.0.0.1", port))
return True
except OSError:
return False
else:
return True
def check_port_range(port_range: range) -> bool:
for port in port_range:
if not check_port(port):
return False
return True
return all(check_port(port) for port in port_range)
class Ports:
NEXT_PORT = 10000
def allocate(self, num: int) -> int:
"""
Allocates
"""
global NEXT_PORT
while NEXT_PORT + num <= 65535:
start = NEXT_PORT
NEXT_PORT += num
if not check_port_range(range(start, NEXT_PORT)):
while Ports.NEXT_PORT + num <= 65535:
start = Ports.NEXT_PORT
Ports.NEXT_PORT += num
if not check_port_range(range(start, Ports.NEXT_PORT)):
continue
return start
raise Exception("cannot find enough free port")
msg = "cannot find enough free port"
raise OSError(msg)
@pytest.fixture
@pytest.fixture()
def ports() -> Ports:
return Ports()

View File

@ -2,10 +2,10 @@ import os
import shutil
import subprocess
import time
from collections.abc import Iterator
from pathlib import Path
from sys import platform
from tempfile import TemporaryDirectory
from typing import Iterator, Optional
import pytest
from command import Command
@ -20,7 +20,7 @@ class Sshd:
class SshdConfig:
def __init__(self, path: str, key: str, preload_lib: Optional[str]) -> None:
def __init__(self, path: str, key: str, preload_lib: str | None) -> None:
self.path = path
self.key = key
self.preload_lib = preload_lib
@ -30,8 +30,8 @@ class SshdConfig:
def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
# FIXME, if any parent of `project_root` is world-writable than sshd will refuse it.
with TemporaryDirectory(dir=project_root) as _dir:
dir = Path(_dir)
host_key = dir / "host_ssh_host_ed25519_key"
directory = Path(_dir)
host_key = directory / "host_ssh_host_ed25519_key"
subprocess.run(
[
"ssh-keygen",
@ -45,7 +45,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
check=True,
)
sshd_config = dir / "sshd_config"
sshd_config = directory / "sshd_config"
sshd_config.write_text(
f"""
HostKey {host_key}
@ -60,7 +60,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
lib_path = None
if platform == "linux":
# This enforces a login shell by overriding the login shell of `getpwnam(3)`
lib_path = str(dir / "libgetpwnam-preload.so")
lib_path = str(directory / "libgetpwnam-preload.so")
subprocess.run(
[
os.environ.get("CC", "cc"),
@ -75,7 +75,7 @@ def sshd_config(project_root: Path, test_root: Path) -> Iterator[SshdConfig]:
yield SshdConfig(str(sshd_config), str(host_key), lib_path)
@pytest.fixture
@pytest.fixture()
def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Sshd]:
port = ports.allocate(1)
sshd = shutil.which("sshd")
@ -104,7 +104,8 @@ def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Ss
"-p",
str(port),
"true",
]
],
check=False,
).returncode
== 0
):
@ -113,5 +114,6 @@ def sshd(sshd_config: SshdConfig, command: Command, ports: Ports) -> Iterator[Ss
else:
rc = proc.poll()
if rc is not None:
raise Exception(f"sshd processes was terminated with {rc}")
msg = f"sshd processes was terminated with {rc}"
raise OSError(msg)
time.sleep(0.1)

View File

@ -2,6 +2,7 @@ import asyncio
import os
import pwd
import pytest
from sshd import Sshd
from nix_fast_build import async_main
@ -12,29 +13,26 @@ def cli(args: list[str]) -> None:
def test_help() -> None:
try:
with pytest.raises(SystemExit) as e:
cli(["--help"])
except SystemExit as e:
assert e.code == 0
assert e.value.code == 0
def test_build() -> None:
try:
with pytest.raises(SystemExit) as e:
cli(["--option", "builders", ""])
except SystemExit as e:
assert e.code == 0
assert e.value.code == 1
def test_eval_error() -> None:
try:
with pytest.raises(SystemExit) as e:
cli(["--option", "builders", "", "--flake", ".#legacyPackages"])
except SystemExit as e:
assert e.code == 1
assert e.value.code == 1
def test_remote(sshd: Sshd) -> None:
login = pwd.getpwuid(os.getuid()).pw_name
try:
with pytest.raises(SystemExit) as e:
cli(
[
"--option",
@ -56,5 +54,4 @@ def test_remote(sshd: Sshd) -> None:
"/dev/null",
]
)
except SystemExit as e:
assert e.code == 0
assert e.value.code == 0