From d178f4c75d06988223368fa905e64f1845568972 Mon Sep 17 00:00:00 2001 From: Kovid Goyal Date: Mon, 23 Aug 2021 11:33:03 +0530 Subject: [PATCH] Implement canceling of file transmission --- kitty/file_transmission.py | 18 +++++++++++++++++- kitty_tests/file_transmission.py | 30 +++++++++++++++++++++--------- 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/kitty/file_transmission.py b/kitty/file_transmission.py index 71c2c348e..4909da59f 100644 --- a/kitty/file_transmission.py +++ b/kitty/file_transmission.py @@ -7,6 +7,7 @@ import errno import os import tempfile from base64 import standard_b64decode, standard_b64encode +from contextlib import suppress from enum import Enum, auto from functools import partial from typing import IO, TYPE_CHECKING, Any, Dict, List, Optional, Union @@ -25,6 +26,7 @@ class Action(Enum): end_data = auto() receive = auto() invalid = auto() + cancel = auto() class Container(Enum): @@ -206,6 +208,15 @@ class ActiveCommand: self.file.close() self.file = None + def cancel(self) -> None: + needs_delete = self.file is not None and self.file.name and self.ftc.container_fmt is Container.none + if needs_delete: + fname = getattr(self.file, 'name') + self.close() + if needs_delete: + with suppress(FileNotFoundError): + os.unlink(fname) + class FileTransmission: @@ -226,13 +237,18 @@ class FileTransmission: except Exception as e: log_error(f'Failed to parse file transmission command with error: {e}') return - if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data): + if cmd.id in self.active_cmds and cmd.action not in (Action.data, Action.end_data, Action.cancel): log_error('File transmission command received while another is in flight, aborting') + self.active_cmds[cmd.id].close() del self.active_cmds[cmd.id] if cmd.action is Action.send: self.active_cmds[cmd.id] = ActiveCommand(cmd) self.start_send(cmd) + elif cmd.action is Action.cancel: + ac = self.active_cmds.pop(cmd.id, None) + if ac is not None: + ac.cancel() elif cmd.action in (Action.data, Action.end_data): if cmd.id not in self.active_cmds: log_error('File transmission data command received with unknown id') diff --git a/kitty_tests/file_transmission.py b/kitty_tests/file_transmission.py index 571363710..6af37ee6a 100644 --- a/kitty_tests/file_transmission.py +++ b/kitty_tests/file_transmission.py @@ -43,17 +43,29 @@ class TestFileTransmission(BaseTest): self.ae(ft.test_responses, [] if quiet == 2 else [{'status': 'EPERM:User refused the transfer', 'id': 'x'}]) self.assertFalse(ft.active_cmds) # simple single file send + for quiet in (0, 1, 2): + ft = FileTransmission() + dest = os.path.join(self.tdir, '1.bin') + ft.handle_serialized_command(serialized_cmd(action='send', dest=dest, quiet=quiet)) + self.assertIn('', ft.active_cmds) + self.ae(os.path.basename(ft.active_cmds[''].dest), '1.bin') + self.assertIsNone(ft.active_cmds[''].file) + self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}]) + ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) + self.ae(ft.active_cmds[''].file.name, dest) + ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) + self.assertFalse(ft.active_cmds) + self.ae(ft.test_responses, [] if quiet else [{'status': 'OK'}, {'status': 'COMPLETED'}]) + with open(dest) as f: + self.ae(f.read(), 'abcd123') + # cancel a send ft = FileTransmission() - dest = os.path.join(self.tdir, '1.bin') + dest = os.path.join(self.tdir, '2.bin') ft.handle_serialized_command(serialized_cmd(action='send', dest=dest)) - self.assertIn('', ft.active_cmds) - self.ae(os.path.basename(ft.active_cmds[''].dest), '1.bin') - self.assertIsNone(ft.active_cmds[''].file) self.ae(ft.test_responses, [{'status': 'OK'}]) ft.handle_serialized_command(serialized_cmd(action='data', data='abcd')) - self.assertIsNotNone(ft.active_cmds[''].file) - ft.handle_serialized_command(serialized_cmd(action='end_data', data='123')) + self.assertTrue(os.path.exists(dest)) + ft.handle_serialized_command(serialized_cmd(action='cancel')) + self.ae(ft.test_responses, [{'status': 'OK'}]) + self.assertFalse(os.path.exists(dest)) self.assertFalse(ft.active_cmds) - self.ae(ft.test_responses, [{'status': 'OK'}, {'status': 'COMPLETED'}]) - with open(dest) as f: - self.ae(f.read(), 'abcd123')