More typing work

This commit is contained in:
Kovid Goyal 2020-03-08 22:08:18 +05:30
parent 353db678a2
commit 9beae321d7
No known key found for this signature in database
GPG Key ID: 06BC317B515ACE7C
6 changed files with 57 additions and 41 deletions

View File

@ -7,6 +7,7 @@ import re
import shlex
import subprocess
import sys
from typing import List, Tuple
SHELL_SCRIPT = '''\
#!/bin/sh
@ -43,7 +44,8 @@ exec -a "-$shell_name" "$0"
def get_ssh_cli():
other_ssh_args, boolean_ssh_args = [], []
other_ssh_args: List[str] = []
boolean_ssh_args: List[str] = []
raw = subprocess.Popen(['ssh'], stderr=subprocess.PIPE).stderr.read().decode('utf-8')
for m in re.finditer(r'\[(.+?)\]', raw):
q = m.group(1)
@ -56,11 +58,11 @@ def get_ssh_cli():
return set('-' + x for x in boolean_ssh_args), set('-' + x for x in other_ssh_args)
def parse_ssh_args(args):
def parse_ssh_args(args: List[str]) -> Tuple[List[str], List[str], bool]:
boolean_ssh_args, other_ssh_args = get_ssh_cli()
passthrough_args = {'-' + x for x in 'Nnf'}
ssh_args = []
server_args = []
server_args: List[str] = []
expecting_option_val = False
passthrough = False
for arg in args:
@ -97,7 +99,7 @@ def parse_ssh_args(args):
return ssh_args, server_args, passthrough
def quote(x):
def quote(x: str) -> str:
# we have to escape unbalanced quotes and other unparsable
# args as they will break the shell script
# But we do not want to quote things like * or 'echo hello'
@ -117,8 +119,8 @@ def main(args):
terminfo = subprocess.check_output(['infocmp']).decode('utf-8')
sh_script = SHELL_SCRIPT.replace('TERMINFO', terminfo, 1)
if len(server_args) > 1:
command_to_execute = [quote(c) for c in server_args[1:]]
command_to_execute = 'exec ' + ' '.join(command_to_execute)
command_to_executeg = (quote(c) for c in server_args[1:])
command_to_execute = 'exec ' + ' '.join(command_to_executeg)
else:
command_to_execute = ''
sh_script = sh_script.replace('EXEC_CMD', command_to_execute)

View File

@ -5,7 +5,7 @@
import sys
from contextlib import contextmanager
from functools import wraps
from typing import List
from typing import List, Optional, Union
from kitty.rgb import Color, color_as_sharp, to_color
@ -234,12 +234,18 @@ def alternate_screen(f=None):
def set_default_colors(fg=None, bg=None, cursor=None, select_bg=None, select_fg=None) -> str:
ans = ''
def item(which, num):
def item(which: Optional[Union[Color, str]], num: int) -> None:
nonlocal ans
if which is None:
ans += '\x1b]1{}\x1b\\'.format(num)
else:
ans += '\x1b]{};{}\x1b\\'.format(num, color_as_sharp(which if isinstance(which, Color) else to_color(which)))
if isinstance(which, Color):
q = color_as_sharp(which)
else:
x = to_color(which)
assert x is not None
q = color_as_sharp(x)
ans += '\x1b]{};{}\x1b\\'.format(num, q)
item(fg, 10)
item(bg, 11)
@ -249,7 +255,7 @@ def set_default_colors(fg=None, bg=None, cursor=None, select_bg=None, select_fg=
return ans
def write_to_clipboard(data, use_primary=False) -> str:
def write_to_clipboard(data: Union[str, bytes], use_primary=False) -> str:
if isinstance(data, str):
data = data.encode('utf-8')
from base64 import standard_b64encode
@ -260,8 +266,8 @@ def write_to_clipboard(data, use_primary=False) -> str:
ans = esc('!') # clear clipboard buffer
for chunk in (data[i:i+512] for i in range(0, len(data), 512)):
chunk = standard_b64encode(chunk).decode('ascii')
ans += esc(chunk)
s = standard_b64encode(chunk).decode('ascii')
ans += esc(s)
return ans

View File

@ -11,12 +11,13 @@ import math
from functools import partial as p
from itertools import repeat
from typing import (
Callable, Dict, Generator, Iterable, List, Optional, Sequence, Tuple, cast
Callable, Dict, Generator, Iterable, List, MutableSequence, Optional,
Sequence, Tuple, cast
)
scale = (0.001, 1., 1.5, 2.)
_dpi = 96.0
BufType = bytearray
BufType = MutableSequence[int]
def set_scale(new_scale: Sequence[float]) -> None:

View File

@ -6,7 +6,7 @@ import ctypes
import sys
from functools import partial
from math import ceil, cos, floor, pi
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, cast
from kitty.config import defaults
from kitty.constants import is_macos
@ -15,7 +15,9 @@ from kitty.fast_data_types import (
set_options, set_send_sprite_to_gpu, sprite_map_set_limits,
test_render_line, test_shape
)
from kitty.fonts.box_drawing import render_box_char, render_missing_glyph
from kitty.fonts.box_drawing import (
BufType, render_box_char, render_missing_glyph
)
from kitty.options_stub import Options as OptionsStub
from kitty.utils import log_error
@ -238,10 +240,11 @@ def prerender_function(
return tuple(map(ctypes.addressof, cells)) + (cells,)
def render_box_drawing(codepoint, cell_width, cell_height, dpi):
def render_box_drawing(codepoint: int, cell_width: int, cell_height: int, dpi: float):
CharTexture = ctypes.c_ubyte * (cell_width * cell_height)
buf = render_box_char(
chr(codepoint), CharTexture(), cell_width, cell_height, dpi
buf = CharTexture()
render_box_char(
chr(codepoint), cast(BufType, buf), cell_width, cell_height, dpi
)
return ctypes.addressof(buf), buf

View File

@ -16,6 +16,7 @@ import sys
import tempfile
import time
from contextlib import suppress
from typing import IO, Optional, cast
import requests
@ -33,10 +34,12 @@ appname = re.search(r"^appname\s+=\s+'([^']+)'", raw, flags=re.MULTILINE).group(
ALL_ACTIONS = 'man html build tag sdist upload website'.split()
def call(*cmd, cwd=None):
def call(*cmd: str, cwd: Optional[str] = None) -> None:
if len(cmd) == 1:
cmd = shlex.split(cmd[0])
ret = subprocess.Popen(cmd, cwd=cwd).wait()
q = shlex.split(cmd[0])
else:
q = list(cmd)
ret = subprocess.Popen(q, cwd=cwd).wait()
if ret != 0:
raise SystemExit(ret)
@ -118,25 +121,25 @@ def run_sdist(args):
subprocess.check_call(['xz', '-9', dest])
class ReadFileWithProgressReporting(io.BufferedReader): # {{{
def __init__(self, path, mode='rb'):
io.BufferedReader.__init__(self, open(path, mode))
class ReadFileWithProgressReporting(io.FileIO): # {{{
def __init__(self, path):
io.FileIO.__init__(self, path, 'rb')
self.seek(0, os.SEEK_END)
self._total = self.tell()
self.seek(0)
self.start_time = time.time()
self.start_time = time.monotonic()
def __len__(self):
def __len__(self) -> int:
return self._total
def read(self, size):
data = io.BufferedReader.read(self, size)
def read(self, size: int = -1) -> Optional[bytes]:
data = io.FileIO.read(self, size)
if data:
self.report_progress(len(data))
return data
def report_progress(self, size):
def write(*args):
def report_progress(self, size: int) -> None:
def write(*args: str) -> None:
print(*args, end='')
write('\x1b[s\x1b[K')
@ -144,7 +147,7 @@ class ReadFileWithProgressReporting(io.BufferedReader): # {{{
mb_pos = self.tell() / float(1024**2)
mb_tot = self._total / float(1024**2)
kb_pos = self.tell() / 1024.0
kb_rate = kb_pos / (time.time() - self.start_time)
kb_rate = kb_pos / (time.monotonic() - self.start_time)
bit_rate = kb_rate * 1024
eta = int((self._total - self.tell()) / bit_rate) + 1
eta_m, eta_s = eta / 60, eta % 60
@ -264,7 +267,7 @@ class GitHub(Base): # {{{
'Content-Length': str(f._total)
},
params={'name': fname},
data=f)
data=cast(IO[bytes], f))
def fail(self, r, msg):
print(msg, ' Status Code: %s' % r.status_code, file=sys.stderr)
@ -281,7 +284,7 @@ class GitHub(Base): # {{{
self.username, self.reponame, release_id)
r = self.requests.get(url)
if r.status_code != 200:
self.fail('Failed to get assets for release')
self.fail(r, 'Failed to get assets for release')
return {asset['name']: asset['id'] for asset in r.json()}
def releases(self):

View File

@ -734,15 +734,16 @@ def compile_python(base_path):
for f in files:
if f.rpartition('.')[-1] in ('pyc', 'pyo'):
os.remove(os.path.join(root, f))
for optimize in (0, 1, 2):
def c(base_path: str, **kw) -> None:
try:
kw = {'invalidation_mode': py_compile.PycInvalidationMode.UNCHECKED_HASH}
kw['invalidation_mode'] = py_compile.PycInvalidationMode.UNCHECKED_HASH
except AttributeError:
kw = {}
compileall.compile_dir(
base_path, ddir='', force=True, optimize=optimize, quiet=1,
workers=num_workers, **kw
)
pass
compileall.compile_dir(base_path, **kw)
for optimize in (0, 1, 2):
c(base_path, ddir='', force=True, optimize=optimize, quiet=1, workers=num_workers)
def create_linux_bundle_gunk(ddir, libdir_name):