add type annotations for database import functions

This commit is contained in:
Sarah Hoffmann 2022-07-17 00:29:34 +02:00
parent 4da1f0da6f
commit a21d4d3ac4

View File

@ -7,6 +7,7 @@
"""
Functions for setting up and importing a new Nominatim database.
"""
from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging
import os
import selectors
@ -16,7 +17,8 @@ from pathlib import Path
import psutil
from psycopg2 import sql as pysql
from nominatim.db.connection import connect, get_pg_env
from nominatim.config import Configuration
from nominatim.db.connection import connect, get_pg_env, Connection
from nominatim.db.async_connection import DBConnection
from nominatim.db.sql_preprocessor import SQLPreprocessor
from nominatim.tools.exec_utils import run_osm2pgsql
@ -25,7 +27,7 @@ from nominatim.version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERS
LOG = logging.getLogger()
def _require_version(module, actual, expected):
def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int, int]) -> None:
""" Compares the version for the given module and raises an exception
if the actual version is too old.
"""
@ -36,7 +38,7 @@ def _require_version(module, actual, expected):
raise UsageError(f'{module} is too old.')
def setup_database_skeleton(dsn, rouser=None):
def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
""" Create a new database for Nominatim and populate it with the
essential extensions.
@ -80,7 +82,9 @@ def setup_database_skeleton(dsn, rouser=None):
POSTGIS_REQUIRED_VERSION)
def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
def import_osm_data(osm_files: Union[str, Sequence[str]],
options: MutableMapping[str, Any],
drop: bool = False, ignore_errors: bool = False) -> None:
""" Import the given OSM files. 'options' contains the list of
default settings for osm2pgsql.
"""
@ -91,7 +95,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
if not options['flatnode_file'] and options['osm2pgsql_cache'] == 0:
# Make some educated guesses about cache size based on the size
# of the import file and the available memory.
mem = psutil.virtual_memory()
mem = psutil.virtual_memory() # type: ignore[no-untyped-call]
fsize = 0
if isinstance(osm_files, list):
for fname in osm_files:
@ -117,7 +121,7 @@ def import_osm_data(osm_files, options, drop=False, ignore_errors=False):
Path(options['flatnode_file']).unlink()
def create_tables(conn, config, reverse_only=False):
def create_tables(conn: Connection, config: Configuration, reverse_only: bool = False) -> None:
""" Create the set of basic tables.
When `reverse_only` is True, then the main table for searching will
be skipped and only reverse search is possible.
@ -128,7 +132,7 @@ def create_tables(conn, config, reverse_only=False):
sql.run_sql_file(conn, 'tables.sql')
def create_table_triggers(conn, config):
def create_table_triggers(conn: Connection, config: Configuration) -> None:
""" Create the triggers for the tables. The trigger functions must already
have been imported with refresh.create_functions().
"""
@ -136,14 +140,14 @@ def create_table_triggers(conn, config):
sql.run_sql_file(conn, 'table-triggers.sql')
def create_partition_tables(conn, config):
def create_partition_tables(conn: Connection, config: Configuration) -> None:
""" Create tables that have explicit partitioning.
"""
sql = SQLPreprocessor(conn, config)
sql.run_sql_file(conn, 'partition-tables.src.sql')
def truncate_data_tables(conn):
def truncate_data_tables(conn: Connection) -> None:
""" Truncate all data tables to prepare for a fresh load.
"""
with conn.cursor() as cur:
@ -174,7 +178,7 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
'extratags', 'geometry')))
def load_data(dsn, threads):
def load_data(dsn: str, threads: int) -> None:
""" Copy data into the word and placex table.
"""
sel = selectors.DefaultSelector()
@ -216,12 +220,12 @@ def load_data(dsn, threads):
print('.', end='', flush=True)
print('\n')
with connect(dsn) as conn:
with conn.cursor() as cur:
with connect(dsn) as syn_conn:
with syn_conn.cursor() as cur:
cur.execute('ANALYSE')
def create_search_indices(conn, config, drop=False):
def create_search_indices(conn: Connection, config: Configuration, drop: bool = False) -> None:
""" Create tables that have explicit partitioning.
"""