diff --git a/nominatim/db/connection.py b/nominatim/db/connection.py index 10327725..3bfc582d 100644 --- a/nominatim/db/connection.py +++ b/nominatim/db/connection.py @@ -77,7 +77,7 @@ class _Cursor(psycopg2.extras.DictCursor): self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore -class _Connection(psycopg2.extensions.connection): +class Connection(psycopg2.extensions.connection): """ A connection that provides the specialised cursor by default and adds convenience functions for administrating the database. """ @@ -174,19 +174,22 @@ class _Connection(psycopg2.extensions.connection): return (int(version_parts[0]), int(version_parts[1])) -class _ConnectionContext(ContextManager[_Connection]): - connection: _Connection +class ConnectionContext(ContextManager[Connection]): + """ Context manager of the connection that also provides direct access + to the underlying connection. + """ + connection: Connection -def connect(dsn: str) -> _ConnectionContext: +def connect(dsn: str) -> ConnectionContext: """ Open a connection to the database using the specialised connection factory. The returned object may be used in conjunction with 'with'. When used outside a context manager, use the `connection` attribute to get the connection. """ try: - conn = psycopg2.connect(dsn, connection_factory=_Connection) - ctxmgr = cast(_ConnectionContext, contextlib.closing(conn)) - ctxmgr.connection = cast(_Connection, conn) + conn = psycopg2.connect(dsn, connection_factory=Connection) + ctxmgr = cast(ConnectionContext, contextlib.closing(conn)) + ctxmgr.connection = cast(Connection, conn) return ctxmgr except psycopg2.OperationalError as err: raise UsageError(f"Cannot connect to database: {err}") from err diff --git a/nominatim/tools/freeze.py b/nominatim/tools/freeze.py index b0ebb2c0..39c3279d 100644 --- a/nominatim/tools/freeze.py +++ b/nominatim/tools/freeze.py @@ -7,10 +7,13 @@ """ Functions for removing unnecessary data from the database. """ +from typing import Optional from pathlib import Path from psycopg2 import sql as pysql +from nominatim.db.connection import Connection + UPDATE_TABLES = [ 'address_levels', 'gb_postcode', @@ -25,7 +28,7 @@ UPDATE_TABLES = [ 'wikipedia_%' ] -def drop_update_tables(conn): +def drop_update_tables(conn: Connection) -> None: """ Drop all tables only necessary for updating the database from OSM replication data. """ @@ -42,7 +45,7 @@ def drop_update_tables(conn): conn.commit() -def drop_flatnode_file(fpath): +def drop_flatnode_file(fpath: Optional[Path]) -> None: """ Remove the flatnode file if it exists. """ if fpath and fpath.exists():