From a6c3c57d40096d565818f051b5df17d316bebe11 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Tue, 5 Jul 2022 18:54:20 +0530 Subject: [PATCH] Test stdio redirection with socket prewarm --- kitty/prewarm.py | 2 ++ kitty_tests/__init__.py | 24 ++++++++++++++++++++---- kitty_tests/prewarm.py | 19 +++++++++++++++++-- 3 files changed, 39 insertions(+), 6 deletions(-) diff --git a/kitty/prewarm.py b/kitty/prewarm.py index 5ccde0e2a..0c3eda0bd 100644 --- a/kitty/prewarm.py +++ b/kitty/prewarm.py @@ -362,6 +362,8 @@ class SocketClosed(Exception): def verify_socket_creds(conn: socket.socket) -> bool: + # needed as abstract unix sockets used on Linux have no permissions and + # older BSDs ignore socket file permissions uid, gid = getpeereid(conn.fileno()) return uid == os.geteuid() and gid == os.getegid() diff --git a/kitty_tests/__init__.py b/kitty_tests/__init__.py index 294a7b5dc..4fec7807d 100644 --- a/kitty_tests/__init__.py +++ b/kitty_tests/__init__.py @@ -10,7 +10,7 @@ import struct import sys import termios import time -from pty import CHILD, fork +from pty import CHILD, fork, STDIN_FILENO, STDOUT_FILENO from unittest import TestCase from kitty.config import finalize_keys, finalize_mouse_mappings @@ -180,9 +180,12 @@ class BaseTest(TestCase): s = Screen(c, lines, cols, scrollback, cell_width, cell_height, 0, c) return s - def create_pty(self, argv=None, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, options=None, cwd=None, env=None): + def create_pty( + self, argv=None, cols=80, lines=100, scrollback=100, cell_width=10, cell_height=20, + options=None, cwd=None, env=None, stdin_fd=None, stdout_fd=None + ): self.set_options(options) - return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env) + return PTY(argv, lines, cols, scrollback, cell_width, cell_height, cwd, env, stdin_fd=stdin_fd, stdout_fd=stdout_fd) def assertEqualAttributes(self, c1, c2): x1, y1, c1.x, c1.y = c1.x, c1.y, 0, 0 @@ -195,7 +198,10 @@ class BaseTest(TestCase): class PTY: - def __init__(self, argv=None, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, cwd=None, env=None): + def __init__( + self, argv=None, rows=25, columns=80, scrollback=100, cell_width=10, cell_height=20, + cwd=None, env=None, stdin_fd=None, stdout_fd=None + ): if isinstance(argv, str): argv = shlex.split(argv) self.write_buf = b'' @@ -211,7 +217,17 @@ class PTY: time.sleep(0.01) if cwd: os.chdir(cwd) + if stdin_fd is not None: + os.dup2(stdin_fd, STDIN_FILENO) + os.close(stdin_fd) + if stdout_fd is not None: + os.dup2(stdout_fd, STDOUT_FILENO) + os.close(stdout_fd) os.execvpe(argv[0], argv, env or os.environ) + if stdin_fd is not None: + os.close(stdin_fd) + if stdout_fd is not None: + os.close(stdout_fd) os.set_blocking(self.master_fd, False) self.cell_width = cell_width self.cell_height = cell_height diff --git a/kitty_tests/prewarm.py b/kitty_tests/prewarm.py index ef31531de..ae0f9b029 100644 --- a/kitty_tests/prewarm.py +++ b/kitty_tests/prewarm.py @@ -32,11 +32,13 @@ def socket_child_main(exit_code=0): 'test_env': os.environ.get('TEST_ENV_PASS', ''), 'cwd': os.getcwd(), 'font_family': get_options().font_family, - 'cols': read_screen_size().cols, + 'cols': read_screen_size(fd=sys.stderr.fileno()).cols, + 'stdin_data': sys.stdin.read(), 'done': 'hello', } print(json.dumps(output, indent=2), file=sys.stderr, flush=True) + print('testing stdout', end='') raise SystemExit(exit_code) # END_socket_child_main @@ -61,7 +63,18 @@ class Prewarm(BaseTest): return env = {'TEST_ENV_PASS': 'xyz', 'KITTY_PREWARM_SOCKET': p.socket_env_var()} cols = 117 - pty = self.create_pty(argv=[kitty_exe(), '+runpy', src + f'socket_child_main({exit_code})'], cols=cols, env=env, cwd=cwd) + stdin_r, stdin_w = os.pipe() + os.set_inheritable(stdin_w, False) + stdout_r, stdout_w = os.pipe() + os.set_inheritable(stdout_r, False) + pty = self.create_pty( + argv=[kitty_exe(), '+runpy', src + f'socket_child_main({exit_code})'], cols=cols, env=env, cwd=cwd, + stdin_fd=stdin_r, stdout_fd=stdout_w) + stdin_data = 'testing--stdin-read' + with open(stdin_w, 'w') as f: + f.write(stdin_data) + with open(stdout_r) as f: + stdout_data = f.read() status = os.waitpid(pty.child_pid, 0)[1] with suppress(AttributeError): self.assertEqual(os.waitstatus_to_exitcode(status), exit_code) @@ -71,6 +84,8 @@ class Prewarm(BaseTest): self.assertEqual(output['cwd'], cwd) self.assertEqual(output['font_family'], 'prewarm') self.assertEqual(output['cols'], cols) + self.assertEqual(output['stdin_data'], stdin_data) + self.assertEqual(stdout_data, 'testing stdout') def test_prewarming(self): from kitty.prewarm import fork_prewarm_process