add type annotations to freeze functions

This commit is contained in:
Sarah Hoffmann 2022-07-03 19:04:05 +02:00
parent aaf2b6032e
commit 845c43137a
2 changed files with 15 additions and 9 deletions

View File

@ -77,7 +77,7 @@ class _Cursor(psycopg2.extras.DictCursor):
self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore
class _Connection(psycopg2.extensions.connection):
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
"""
@ -174,19 +174,22 @@ class _Connection(psycopg2.extensions.connection):
return (int(version_parts[0]), int(version_parts[1]))
class _ConnectionContext(ContextManager[_Connection]):
connection: _Connection
class ConnectionContext(ContextManager[Connection]):
""" Context manager of the connection that also provides direct access
to the underlying connection.
"""
connection: Connection
def connect(dsn: str) -> _ConnectionContext:
def connect(dsn: str) -> ConnectionContext:
""" Open a connection to the database using the specialised connection
factory. The returned object may be used in conjunction with 'with'.
When used outside a context manager, use the `connection` attribute
to get the connection.
"""
try:
conn = psycopg2.connect(dsn, connection_factory=_Connection)
ctxmgr = cast(_ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = cast(_Connection, conn)
conn = psycopg2.connect(dsn, connection_factory=Connection)
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
ctxmgr.connection = cast(Connection, conn)
return ctxmgr
except psycopg2.OperationalError as err:
raise UsageError(f"Cannot connect to database: {err}") from err

View File

@ -7,10 +7,13 @@
"""
Functions for removing unnecessary data from the database.
"""
from typing import Optional
from pathlib import Path
from psycopg2 import sql as pysql
from nominatim.db.connection import Connection
UPDATE_TABLES = [
'address_levels',
'gb_postcode',
@ -25,7 +28,7 @@ UPDATE_TABLES = [
'wikipedia_%'
]
def drop_update_tables(conn):
def drop_update_tables(conn: Connection) -> None:
""" Drop all tables only necessary for updating the database from
OSM replication data.
"""
@ -42,7 +45,7 @@ def drop_update_tables(conn):
conn.commit()
def drop_flatnode_file(fpath):
def drop_flatnode_file(fpath: Optional[Path]) -> None:
""" Remove the flatnode file if it exists.
"""
if fpath and fpath.exists():