From 15978c9c608fb3994017655538014b622417791e Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Thu, 20 Jul 2023 19:54:54 +0530 Subject: [PATCH] Infrastructure for testing the transfer kitten --- kitty_tests/__init__.py | 23 ++++++++++++++++ kitty_tests/file_transmission.py | 46 +++++++++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index 017c796aa..1bebdcfc0 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -28,6 +28,7 @@ class Callbacks: def __init__(self, pty=None) -> None: self.clear() self.pty = pty + self.ftc = None def write(self, data) -> None: self.wtcbuf += data @@ -123,6 +124,10 @@ class Callbacks: data = standard_b64decode(msg) self.pty.write_to_child(data) + def file_transmission(self, data): + if self.ftc: + self.ftc.handle_serialized_command(data) + def filled_line_buf(ynum=5, xnum=5, cursor=Cursor()): ans = LineBuf(ynum, xnum) @@ -211,6 +216,7 @@ class PTY: else: self.child_pid, self.master_fd = fork() self.is_child = self.child_pid == CHILD + self.child_waited_for = False if self.is_child: while read_screen_size().width != columns * cell_width: time.sleep(0.01) @@ -253,6 +259,9 @@ class PTY: if hasattr(self, 'slave_fd'): os.close(self.slave_fd) del self.slave_fd + if self.child_pid > 0 and not self.child_waited_for: + os.waitpid(self.child_pid, 0) + self.child_waited_for = True def write_to_child(self, data, flush=False): if isinstance(data, str): @@ -295,6 +304,20 @@ class PTY: if not q(): raise TimeoutError(f'The condition was not met. Screen contents: \n {repr(self.screen_contents())}') + def wait_till_child_exits(self, timeout=10, require_exit_code=None): + end_time = time.monotonic() + timeout + while time.monotonic() <= end_time: + status = os.waitid(os.P_PID, self.child_pid, os.WNOHANG | os.WEXITED) + if status is not None and status.si_pid == self.child_pid: + self.child_waited_for = True + if require_exit_code is not None and os.waitstatus_to_exitcode(status.si_status) != require_exit_code: + raise AssertionError( + f'Child exited with exit status: {status} code: {os.waitstatus_to_exitcode(status.si_status)} != {require_exit_code}.' + f' Screen contents:\n{self.screen_contents()}') + return status + self.process_input_from_child(timeout=0.02) + raise AssertionError(f'Child did not exit in {timeout} seconds. Screen contents:\n{self.screen_contents()}') + def set_window_size(self, rows=25, columns=80, send_signal=True): if hasattr(self, 'screen'): self.screen.resize(rows, columns) diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index 9ba1fe318..c397fafdb 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -7,13 +7,15 @@ import shutil import stat import tempfile import zlib +from contextlib import contextmanager from kittens.transfer.rsync import Differ, Hasher, Patcher, decode_utf8_buffer, parse_ftc from kittens.transfer.utils import set_paths +from kitty.constants import kitten_exe from kitty.file_transmission import Action, Compression, FileTransmissionCommand, FileType, TransmissionType, ZlibDecompressor from kitty.file_transmission import TestFileTransmission as FileTransmission -from . import BaseTest +from . import PTY, BaseTest def response(id='test', msg='', file_id='', name='', action='status', status='', size=-1): @@ -154,6 +156,24 @@ def test_rsync_roundtrip(self: 'TestFileTransmission') -> None: run_roundtrip_test(self, src_data, changed + b"xyz...", num_of_patches, total_patch_size) +class PtyFileTransmission(FileTransmission): + + def __init__(self, pty, allow=True): + self.pty = pty + super().__init__(allow=allow) + self.pty.callbacks.ftc = self + + def write_ftc_to_child(self, payload: FileTransmissionCommand, appendleft: bool = False, use_pending: bool = True) -> bool: + self.pty.write_to_child('\x1b]' + payload.serialize(prefix_with_osc_code=True) + '\x1b\\', flush=True) + + +class TransferPTY(PTY): + + def __init__(self, cmd, cwd, allow=True): + super().__init__(cmd, cwd=cwd) + self.fc = PtyFileTransmission(self, allow=allow) + + class TestFileTransmission(BaseTest): def setUp(self): @@ -445,3 +465,27 @@ class TestFileTransmission(BaseTest): h128 = Hasher("xxh3-128") h128.update(b'abcd') self.assertEqual(h128.hexdigest(), '8d6b60383dfa90c21be79eecd1b1353d') + + @contextmanager + def run_kitten(self, cmd, home_dir=''): + homedir_ephemeral = not home_dir + home_dir = self.home_dir = home_dir or os.path.realpath(tempfile.mkdtemp()) + cmd = [kitten_exe(), 'transfer'] + cmd + try: + pty = TransferPTY(cmd, cwd=home_dir) + i = 10 + while i > 0 and not pty.screen_contents().strip(): + pty.process_input_from_child() + i -= 1 + yield pty + finally: + if homedir_ephemeral and os.path.exists(home_dir): + shutil.rmtree(home_dir) + + def test_transfer_send(self): + src = os.path.join(self.tdir, 'src') + with open(src, 'wb') as s: + s.write(os.urandom(813)) + dest = os.path.join(self.tdir, 'dest') + with self.run_kitten([src, dest]) as pty: + pty.wait_till_child_exits(require_exit_code=0)