add type annotation to DB utils

As a cursor is needed as type, make this a public type.
This commit is contained in:
Sarah Hoffmann 2022-07-05 10:46:55 +02:00
parent e6775e713c
commit 26f30bff28
2 changed files with 12 additions and 12 deletions

View File

@ -22,7 +22,7 @@ from nominatim.errors import UsageError
LOG = logging.getLogger()
class _Cursor(psycopg2.extras.DictCursor):
class Cursor(psycopg2.extras.DictCursor):
""" A cursor returning dict-like objects and providing specialised
execution functions.
"""
@ -82,18 +82,18 @@ class Connection(psycopg2.extensions.connection):
adds convenience functions for administrating the database.
"""
@overload # type: ignore[override]
def cursor(self) -> _Cursor:
def cursor(self) -> Cursor:
...
@overload
def cursor(self, name: str) -> _Cursor:
def cursor(self, name: str) -> Cursor:
...
@overload
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
...
def cursor(self, cursor_factory = _Cursor, **kwargs): # type: ignore
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
""" Return a new cursor. By default the specialised cursor is returned.
"""
return super().cursor(cursor_factory=cursor_factory, **kwargs)

View File

@ -7,14 +7,14 @@
"""
Helper functions for handling DB accesses.
"""
from typing import IO, Optional, Union
from typing import IO, Optional, Union, Any, Iterable
import subprocess
import logging
import gzip
import io
from pathlib import Path
from nominatim.db.connection import get_pg_env
from nominatim.db.connection import get_pg_env, Cursor
from nominatim.errors import UsageError
LOG = logging.getLogger()
@ -84,20 +84,20 @@ class CopyBuffer:
""" Data collector for the copy_from command.
"""
def __init__(self):
def __init__(self) -> None:
self.buffer = io.StringIO()
def __enter__(self):
def __enter__(self) -> 'CopyBuffer':
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
if self.buffer is not None:
self.buffer.close()
def add(self, *data):
def add(self, *data: Any) -> None:
""" Add another row of data to the copy buffer.
"""
first = True
@ -113,9 +113,9 @@ class CopyBuffer:
self.buffer.write('\n')
def copy_out(self, cur, table, columns=None):
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
""" Copy all collected data into the given table.
"""
if self.buffer.tell() > 0:
self.buffer.seek(0)
cur.copy_from(self.buffer, table, columns=columns)
cur.copy_from(self.buffer, table, columns=columns) # type: ignore[no-untyped-call]