diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 587492043..71c2c348e 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -201,6 +201,11 @@ class ActiveCommand: def __init__(self, ftc: FileTransmissionCommand) -> None: self.ftc = ftc + def close(self) -> None: + if self.file is not None: + self.file.close() + self.file = None + class FileTransmission: @@ -210,6 +215,11 @@ class FileTransmission: self.window_id = window_id self.active_cmds = {} + def __del__(self) -> None: + for cmd in self.active_cmds.values(): + cmd.close() + self.active_cmds = {} + def handle_serialized_command(self, data: str) -> None: try: cmd = parse_command(data) @@ -227,9 +237,16 @@ class FileTransmission: if cmd.id not in self.active_cmds: log_error('File transmission data command received with unknown id') return - self.add_data(cmd) + try: + self.add_data(cmd) + except Exception: + self.abort_in_flight(cmd.id) + raise if cmd.action is Action.end_data and cmd.id in self.active_cmds: - self.commit(cmd.id) + try: + self.commit(cmd.id) + except Exception: + self.abort_in_flight(cmd.id) def send_response(self, ac: Optional[FileTransmissionCommand], **fields: str) -> None: if ac is None: @@ -263,6 +280,7 @@ class FileTransmission: if cmd.ftc.quiet: return else: + cmd.close() del self.active_cmds[cmd_id] if cmd.ftc.quiet > 1: return @@ -274,13 +292,18 @@ class FileTransmission: errname = errno.errorcode.get(err.errno, 'EFAIL') self.send_response(ac, status=f'{errname}:{msg}') + def abort_in_flight(self, cmd_id: str) -> None: + c = self.active_cmds.pop(cmd_id, None) + if c is not None: + c.close() + def add_data(self, cmd: FileTransmissionCommand) -> None: ac = self.active_cmds.get(cmd.id) def abort_in_flight() -> None: - self.active_cmds.pop(cmd.id, None) + self.abort_in_flight(cmd.id) - if ac is None or not ac.dest: + if ac is None or not ac.dest or ac.ftc.action is not Action.send: return abort_in_flight() if ac.file is None: @@ -310,9 +333,20 @@ class FileTransmission: def commit(self, cmd_id: str) -> None: cmd = self.active_cmds.pop(cmd_id, None) - if cmd is not None and cmd.ftc.container_fmt and cmd.file is not None: - cmd.file.seek(0, os.SEEK_SET) - Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest) + if cmd is not None: + try: + if cmd.ftc.container_fmt is not Container.none and cmd.file is not None: + cmd.file.seek(0, os.SEEK_SET) + try: + Container.extractor_for_container_fmt(cmd.file, cmd.ftc.container_fmt)(cmd.dest) + except OSError as e: + self.send_fail_on_os_error(cmd.ftc, e, 'Failed to extract files from container') + except Exception: + self.send_response(cmd.ftc, status='EFAIL:Failed to extract files from container') + if not cmd.ftc.quiet: + self.send_response(cmd.ftc, status='COMPLETED') + finally: + cmd.close() class TestFileTransmission(FileTransmission): diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index f09880f49..571363710 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -36,8 +36,24 @@ class TestFileTransmission(BaseTest): shutil.rmtree(self.tdir) def test_file_put(self): + # send refusal + for quiet in (0, 1, 2): + ft = FileTransmission() + ft.handle_serialized_command(serialized_cmd(action='send', id='x', quiet=quiet)) + self.ae(ft.test_responses, [] if quiet == 2 else [{'status': 'EPERM:User refused the transfer', 'id': 'x'}]) + self.assertFalse(ft.active_cmds) + # simple single file send ft = FileTransmission() - ft.handle_serialized_command(serialized_cmd(action='send', id='1', dest=os.path.join(self.tdir, '1.bin'))) - self.assertIn('1', ft.active_cmds) - self.ae(os.path.basename(ft.active_cmds['1'].dest), '1.bin') - self.assertIsNone(ft.active_cmds['1'].file) + dest = os.path.join(self.tdir, '1.bin') + ft.handle_serialized_command(serialized_cmd(action='send', dest=dest)) + self.assertIn('', ft.active_cmds) + self.ae(os.path.basename(ft.active_cmds[''].dest), '1.bin') + self.assertIsNone(ft.active_cmds[''].file) + self.ae(ft.test_responses, [{'status': 'OK'}]) + ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) + self.assertIsNotNone(ft.active_cmds[''].file) + ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) + self.assertFalse(ft.active_cmds) + self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}]) + with open(dest) as f: + self.ae(f.read(), 'abcd123')