diff --git a/src/nominatim_db/clicmd/admin.py b/src/nominatim_db/clicmd/admin.py index 7744595b..1edff174 100644 --- a/src/nominatim_db/clicmd/admin.py +++ b/src/nominatim_db/clicmd/admin.py @@ -12,7 +12,7 @@ import argparse import random from ..errors import UsageError -from ..db.connection import connect +from ..db.connection import connect, table_exists from .args import NominatimArgs # Do not repeat documentation of subcommand classes. @@ -115,7 +115,7 @@ class AdminFuncs: tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config) with connect(args.config.get_libpq_dsn()) as conn: - if conn.table_exists('search_name'): + if table_exists(conn, 'search_name'): words = tokenizer.most_frequent_words(conn, 1000) else: words = [] diff --git a/src/nominatim_db/clicmd/refresh.py b/src/nominatim_db/clicmd/refresh.py index d5acf54b..363bad78 100644 --- a/src/nominatim_db/clicmd/refresh.py +++ b/src/nominatim_db/clicmd/refresh.py @@ -13,7 +13,7 @@ import logging from pathlib import Path from ..config import Configuration -from ..db.connection import connect +from ..db.connection import connect, table_exists from ..tokenizer.base import AbstractTokenizer from .args import NominatimArgs @@ -124,7 +124,7 @@ class UpdateRefresh: with connect(args.config.get_libpq_dsn()) as conn: # If the table did not exist before, then the importance code # needs to be enabled. - if not conn.table_exists('secondary_importance'): + if not table_exists(conn, 'secondary_importance'): args.functions = True LOG.warning('Import secondary importance raster data from %s', args.project_dir) diff --git a/src/nominatim_db/data/country_info.py b/src/nominatim_db/data/country_info.py index c8002ee7..e2bf5133 100644 --- a/src/nominatim_db/data/country_info.py +++ b/src/nominatim_db/data/country_info.py @@ -9,10 +9,9 @@ Functions for importing and managing static country information. """ from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload from pathlib import Path -import psycopg2.extras from ..db import utils as db_utils -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, register_hstore from ..errors import UsageError from ..config import Configuration from ..tokenizer.base import AbstractTokenizer @@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals params.append((ccode, props['names'], lang, partition)) with connect(dsn) as conn: + register_hstore(conn) with conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute( """ CREATE TABLE public.country_name ( country_code character varying(2), @@ -157,8 +156,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer, return ':' not in key or not languages or \ key[key.index(':') + 1:] in languages + register_hstore(conn) with conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute("""SELECT country_code, name FROM country_name WHERE country_code is not null""") diff --git a/src/nominatim_db/db/connection.py b/src/nominatim_db/db/connection.py index 8faa3f93..629fad6a 100644 --- a/src/nominatim_db/db/connection.py +++ b/src/nominatim_db/db/connection.py @@ -7,7 +7,8 @@ """ Specialised connection and cursor functions. 
""" -from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable +from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\ + Tuple, Iterable import contextlib import logging import os @@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor): psycopg2.extras.execute_values(self, sql, argslist, template=template) - def scalar(self, sql: Query, args: Any = None) -> Any: - """ Execute query that returns a single value. The value is returned. - If the query yields more than one row, a ValueError is raised. - """ - self.execute(sql, args) - - if self.rowcount != 1: - raise RuntimeError("Query did not return a single row.") - - result = self.fetchone() - assert result is not None - - return result[0] - - - def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: - """ Drop the table with the given name. - Set `if_exists` to False if a non-existent table should raise - an exception instead of just being ignored. If 'cascade' is set - to True then all dependent tables are deleted as well. - """ - sql = 'DROP TABLE ' - if if_exists: - sql += 'IF EXISTS ' - sql += '{}' - if cascade: - sql += ' CASCADE' - - self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) - - class Connection(psycopg2.extensions.connection): """ A connection that provides the specialised cursor by default and adds convenience functions for administrating the database. @@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection): return super().cursor(cursor_factory=cursor_factory, **kwargs) - def table_exists(self, table: str) -> bool: - """ Check that a table with the given name exists in the database. - """ - with self.cursor() as cur: - num = cur.scalar("""SELECT count(*) FROM pg_tables - WHERE tablename = %s and schemaname = 'public'""", (table, )) - return num == 1 if isinstance(num, int) else False +def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any: + """ Execute query that returns a single value. The value is returned. + If the query yields more than one row, a ValueError is raised. + """ + with conn.cursor() as cur: + cur.execute(sql, args) + + if cur.rowcount != 1: + raise RuntimeError("Query did not return a single row.") + + result = cur.fetchone() + + assert result is not None + return result[0] - def table_has_column(self, table: str, column: str) -> bool: - """ Check if the table 'table' exists and has a column with name 'column'. - """ - with self.cursor() as cur: - has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns - WHERE table_name = %s - and column_name = %s""", - (table, column)) - return has_column > 0 if isinstance(has_column, int) else False +def table_exists(conn: Connection, table: str) -> bool: + """ Check that a table with the given name exists in the database. + """ + num = execute_scalar(conn, + """SELECT count(*) FROM pg_tables + WHERE tablename = %s and schemaname = 'public'""", (table, )) + return num == 1 if isinstance(num, int) else False - def index_exists(self, index: str, table: Optional[str] = None) -> bool: - """ Check that an index with the given name exists in the database. - If table is not None then the index must relate to the given - table. 
- """ - with self.cursor() as cur: - cur.execute("""SELECT tablename FROM pg_indexes - WHERE indexname = %s and schemaname = 'public'""", (index, )) - if cur.rowcount == 0: +def table_has_column(conn: Connection, table: str, column: str) -> bool: + """ Check if the table 'table' exists and has a column with name 'column'. + """ + has_column = execute_scalar(conn, + """SELECT count(*) FROM information_schema.columns + WHERE table_name = %s and column_name = %s""", + (table, column)) + return has_column > 0 if isinstance(has_column, int) else False + + +def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool: + """ Check that an index with the given name exists in the database. + If table is not None then the index must relate to the given + table. + """ + with conn.cursor() as cur: + cur.execute("""SELECT tablename FROM pg_indexes + WHERE indexname = %s and schemaname = 'public'""", (index, )) + if cur.rowcount == 0: + return False + + if table is not None: + row = cur.fetchone() + if row is None or not isinstance(row[0], str): return False + return row[0] == table - if table is not None: - row = cur.fetchone() - if row is None or not isinstance(row[0], str): - return False - return row[0] == table + return True - return True +def drop_tables(conn: Connection, *names: str, + if_exists: bool = True, cascade: bool = False) -> None: + """ Drop one or more tables with the given names. + Set `if_exists` to False if a non-existent table should raise + an exception instead of just being ignored. `cascade` will cause + depended objects to be dropped as well. + The caller needs to take care of committing the change. + """ + sql = pysql.SQL('DROP TABLE%s{}%s' % ( + ' IF EXISTS ' if if_exists else ' ', + ' CASCADE' if cascade else '')) + + with conn.cursor() as cur: + for name in names: + cur.execute(sql.format(pysql.Identifier(name))) - def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None: - """ Drop the table with the given name. - Set `if_exists` to False if a non-existent table should raise - an exception instead of just being ignored. - """ - with self.cursor() as cur: - cur.drop_table(name, if_exists, cascade) - self.commit() +def server_version_tuple(conn: Connection) -> Tuple[int, int]: + """ Return the server version as a tuple of (major, minor). + Converts correctly for pre-10 and post-10 PostgreSQL versions. + """ + version = conn.server_version + if version < 100000: + return (int(version / 10000), int((version % 10000) / 100)) + + return (int(version / 10000), version % 10000) - def server_version_tuple(self) -> Tuple[int, int]: - """ Return the server version as a tuple of (major, minor). - Converts correctly for pre-10 and post-10 PostgreSQL versions. - """ - version = self.server_version - if version < 100000: - return (int(version / 10000), int((version % 10000) / 100)) +def postgis_version_tuple(conn: Connection) -> Tuple[int, int]: + """ Return the postgis version installed in the database as a + tuple of (major, minor). Assumes that the PostGIS extension + has been installed already. + """ + version = execute_scalar(conn, 'SELECT postgis_lib_version()') - return (int(version / 10000), version % 10000) + version_parts = version.split('.') + if len(version_parts) < 2: + raise UsageError(f"Error fetching Postgis version. 
Bad format: {version}") + return (int(version_parts[0]), int(version_parts[1])) - def postgis_version_tuple(self) -> Tuple[int, int]: - """ Return the postgis version installed in the database as a - tuple of (major, minor). Assumes that the PostGIS extension - has been installed already. - """ - with self.cursor() as cur: - version = cur.scalar('SELECT postgis_lib_version()') - - version_parts = version.split('.') - if len(version_parts) < 2: - raise UsageError(f"Error fetching Postgis version. Bad format: {version}") - - return (int(version_parts[0]), int(version_parts[1])) +def register_hstore(conn: Connection) -> None: + """ Register the hstore type with psycopg for the connection. + """ + psycopg2.extras.register_hstore(conn) class ConnectionContext(ContextManager[Connection]): diff --git a/src/nominatim_db/db/properties.py b/src/nominatim_db/db/properties.py index 3549382f..0e017ead 100644 --- a/src/nominatim_db/db/properties.py +++ b/src/nominatim_db/db/properties.py @@ -9,7 +9,7 @@ Query and access functions for the in-database property table. """ from typing import Optional, cast -from .connection import Connection +from .connection import Connection, table_exists def set_property(conn: Connection, name: str, value: str) -> None: """ Add or replace the property with the given name. @@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]: """ Return the current value of the given property or None if the property is not set. """ - if not conn.table_exists('nominatim_properties'): + if not table_exists(conn, 'nominatim_properties'): return None with conn.cursor() as cur: diff --git a/src/nominatim_db/db/sql_preprocessor.py b/src/nominatim_db/db/sql_preprocessor.py index 468f3510..691ab6c5 100644 --- a/src/nominatim_db/db/sql_preprocessor.py +++ b/src/nominatim_db/db/sql_preprocessor.py @@ -10,7 +10,7 @@ Preprocessing of SQL files. from typing import Set, Dict, Any, cast import jinja2 -from .connection import Connection +from .connection import Connection, server_version_tuple, postgis_version_tuple from .async_connection import WorkerPool from ..config import Configuration @@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]: """ Set up a dictionary with various optional Postgresql/Postgis features that depend on the database version. """ - pg_version = conn.server_version_tuple() - postgis_version = conn.postgis_version_tuple() + pg_version = server_version_tuple(conn) + postgis_version = postgis_version_tuple(conn) pg11plus = pg_version >= (11, 0, 0) ps3 = postgis_version >= (3, 0) return { diff --git a/src/nominatim_db/db/status.py b/src/nominatim_db/db/status.py index 1278359c..1d2b3bec 100644 --- a/src/nominatim_db/db/status.py +++ b/src/nominatim_db/db/status.py @@ -12,7 +12,7 @@ import datetime as dt import logging import re -from .connection import Connection +from .connection import Connection, table_exists, execute_scalar from ..utils.url_utils import get_url from ..errors import UsageError from ..typing import TypedDict @@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim data base. """ # If there is a date from osm2pgsql available, use that. 
- if conn.table_exists('osm2pgsql_properties'): + if table_exists(conn, 'osm2pgsql_properties'): with conn.cursor() as cur: cur.execute(""" SELECT value FROM osm2pgsql_properties WHERE property = 'current_timestamp' """) @@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim raise UsageError("Cannot determine database date from data in offline mode.") # Else, find the node with the highest ID in the database - with conn.cursor() as cur: - if conn.table_exists('place'): - osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'") - else: - osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'") + if table_exists(conn, 'place'): + osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'") + else: + osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'") - if osmid is None: - LOG.fatal("No data found in the database.") - raise UsageError("No data found in the database.") + if osmid is None: + LOG.fatal("No data found in the database.") + raise UsageError("No data found in the database.") LOG.info("Using node id %d for timestamp lookup", osmid) # Get the node from the API to find the timestamp when it was created. diff --git a/src/nominatim_db/indexer/indexer.py b/src/nominatim_db/indexer/indexer.py index 5a219f6b..b4c9732c 100644 --- a/src/nominatim_db/indexer/indexer.py +++ b/src/nominatim_db/indexer/indexer.py @@ -15,7 +15,7 @@ import psycopg2.extras from ..typing import DictCursorResults from ..db.async_connection import DBConnection, WorkerPool -from ..db.connection import connect, Connection, Cursor +from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore from ..tokenizer.base import AbstractTokenizer from .progress import ProgressLogger from . import runners @@ -32,15 +32,15 @@ class PlaceFetcher: self.conn: Optional[DBConnection] = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor) - with setup_conn.cursor() as cur: - # need to fetch those manually because register_hstore cannot - # fetch them on an asynchronous connection below. - hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid") - hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid") + # need to fetch those manually because register_hstore cannot + # fetch them on an asynchronous connection below. + hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid") + hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid") psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid, array_oid=hstore_array_oid) + def close(self) -> None: """ Close the underlying asynchronous connection. 
""" @@ -205,10 +205,9 @@ class Indexer: LOG.warning("Starting %s (using batch size %s)", runner.name(), batch) with connect(self.dsn) as conn: - psycopg2.extras.register_hstore(conn) - with conn.cursor() as cur: - total_tuples = cur.scalar(runner.sql_count_objects()) - LOG.debug("Total number of rows: %i", total_tuples) + register_hstore(conn) + total_tuples = execute_scalar(conn, runner.sql_count_objects()) + LOG.debug("Total number of rows: %i", total_tuples) conn.commit() diff --git a/src/nominatim_db/tokenizer/icu_tokenizer.py b/src/nominatim_db/tokenizer/icu_tokenizer.py index 22e2d048..70c5c27a 100644 --- a/src/nominatim_db/tokenizer/icu_tokenizer.py +++ b/src/nominatim_db/tokenizer/icu_tokenizer.py @@ -16,7 +16,8 @@ import logging from pathlib import Path from textwrap import dedent -from ..db.connection import connect, Connection, Cursor +from ..db.connection import connect, Connection, Cursor, server_version_tuple,\ + drop_tables, table_exists, execute_scalar from ..config import Configuration from ..db.utils import CopyBuffer from ..db.sql_preprocessor import SQLPreprocessor @@ -108,7 +109,7 @@ class ICUTokenizer(AbstractTokenizer): """ Recompute frequencies for all name words. """ with connect(self.dsn) as conn: - if not conn.table_exists('search_name'): + if not table_exists(conn, 'search_name'): return with conn.cursor() as cur: @@ -117,10 +118,9 @@ class ICUTokenizer(AbstractTokenizer): cur.execute('SET max_parallel_workers_per_gather TO %s', (min(threads, 6),)) - if conn.server_version_tuple() < (12, 0): + if server_version_tuple(conn) < (12, 0): LOG.info('Computing word frequencies') - cur.drop_table('word_frequencies') - cur.drop_table('addressword_frequencies') + drop_tables(conn, 'word_frequencies', 'addressword_frequencies') cur.execute("""CREATE TEMP TABLE word_frequencies AS SELECT unnest(name_vector) as id, count(*) FROM search_name GROUP BY id""") @@ -152,17 +152,16 @@ class ICUTokenizer(AbstractTokenizer): $$ LANGUAGE plpgsql IMMUTABLE; """) LOG.info('Update word table with recomputed frequencies') - cur.drop_table('tmp_word') + drop_tables(conn, 'tmp_word') cur.execute("""CREATE TABLE tmp_word AS SELECT word_id, word_token, type, word, word_freq_update(word_id, info) as info FROM word """) - cur.drop_table('word_frequencies') - cur.drop_table('addressword_frequencies') + drop_tables(conn, 'word_frequencies', 'addressword_frequencies') else: LOG.info('Computing word frequencies') - cur.drop_table('word_frequencies') + drop_tables(conn, 'word_frequencies') cur.execute(""" CREATE TEMP TABLE word_frequencies AS WITH word_freq AS MATERIALIZED ( @@ -182,7 +181,7 @@ class ICUTokenizer(AbstractTokenizer): cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)') cur.execute('ANALYSE word_frequencies') LOG.info('Update word table with recomputed frequencies') - cur.drop_table('tmp_word') + drop_tables(conn, 'tmp_word') cur.execute("""CREATE TABLE tmp_word AS SELECT word_id, word_token, type, word, (CASE WHEN wf.info is null THEN word.info @@ -191,7 +190,7 @@ class ICUTokenizer(AbstractTokenizer): FROM word LEFT JOIN word_frequencies wf ON word.word_id = wf.id """) - cur.drop_table('word_frequencies') + drop_tables(conn, 'word_frequencies') with conn.cursor() as cur: cur.execute('SET max_parallel_workers_per_gather TO 0') @@ -210,7 +209,7 @@ class ICUTokenizer(AbstractTokenizer): """ Remove unused house numbers. 
""" with connect(self.dsn) as conn: - if not conn.table_exists('search_name'): + if not table_exists(conn, 'search_name'): return with conn.cursor(name="hnr_counter") as cur: cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token) @@ -311,8 +310,7 @@ class ICUTokenizer(AbstractTokenizer): frequencies. """ with connect(self.dsn) as conn: - with conn.cursor() as cur: - cur.drop_table('word') + drop_tables(conn, 'word') sqlp = SQLPreprocessor(conn, config) sqlp.run_string(conn, """ CREATE TABLE word ( @@ -370,8 +368,8 @@ class ICUTokenizer(AbstractTokenizer): """ Rename all tables and indexes used by the tokenizer. """ with connect(self.dsn) as conn: + drop_tables(conn, 'word') with conn.cursor() as cur: - cur.drop_table('word') cur.execute(f"ALTER TABLE {old} RENAME TO word") for idx in ('word_token', 'word_id'): cur.execute(f"""ALTER INDEX idx_{old}_{idx} @@ -733,11 +731,10 @@ class ICUNameAnalyzer(AbstractAnalyzer): if norm_name: result = self._cache.housenumbers.get(norm_name, result) if result[0] is None: - with self.conn.cursor() as cur: - hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, )) + hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, )) - result = hid, norm_name - self._cache.housenumbers[norm_name] = result + result = hid, norm_name + self._cache.housenumbers[norm_name] = result else: # Otherwise use the analyzer to determine the canonical name. # Per convention we use the first variant as the 'lookup name', the @@ -748,11 +745,10 @@ class ICUNameAnalyzer(AbstractAnalyzer): if result[0] is None: variants = analyzer.compute_variants(word_id) if variants: - with self.conn.cursor() as cur: - hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)", + hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)", (word_id, list(variants))) - result = hid, variants[0] - self._cache.housenumbers[word_id] = result + result = hid, variants[0] + self._cache.housenumbers[word_id] = result return result diff --git a/src/nominatim_db/tokenizer/legacy_tokenizer.py b/src/nominatim_db/tokenizer/legacy_tokenizer.py index 136a7331..0e8dfcf9 100644 --- a/src/nominatim_db/tokenizer/legacy_tokenizer.py +++ b/src/nominatim_db/tokenizer/legacy_tokenizer.py @@ -18,10 +18,10 @@ from textwrap import dedent from icu import Transliterator import psycopg2 -import psycopg2.extras from ..errors import UsageError -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, drop_tables, table_exists,\ + execute_scalar, register_hstore from ..config import Configuration from ..db import properties from ..db import utils as db_utils @@ -179,11 +179,10 @@ class LegacyTokenizer(AbstractTokenizer): * Can nominatim.so be accessed by the database user? """ with connect(self.dsn) as conn: - with conn.cursor() as cur: - try: - out = cur.scalar("SELECT make_standard_name('a')") - except psycopg2.Error as err: - return hint.format(error=str(err)) + try: + out = execute_scalar(conn, "SELECT make_standard_name('a')") + except psycopg2.Error as err: + return hint.format(error=str(err)) if out != 'a': return hint.format(error='Unexpected result for make_standard_name()') @@ -214,9 +213,9 @@ class LegacyTokenizer(AbstractTokenizer): """ Recompute the frequency of full words. 
""" with connect(self.dsn) as conn: - if conn.table_exists('search_name'): + if table_exists(conn, 'search_name'): + drop_tables(conn, "word_frequencies") with conn.cursor() as cur: - cur.drop_table("word_frequencies") LOG.info("Computing word frequencies") cur.execute("""CREATE TEMP TABLE word_frequencies AS SELECT unnest(name_vector) as id, count(*) @@ -226,7 +225,7 @@ class LegacyTokenizer(AbstractTokenizer): cur.execute("""UPDATE word SET search_name_count = count FROM word_frequencies WHERE word_token like ' %' and word_id = id""") - cur.drop_table("word_frequencies") + drop_tables(conn, "word_frequencies") conn.commit() @@ -316,7 +315,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer): self.conn: Optional[Connection] = connect(dsn).connection self.conn.autocommit = True self.normalizer = normalizer - psycopg2.extras.register_hstore(self.conn) + register_hstore(self.conn) self._cache = _TokenCache(self.conn) @@ -536,9 +535,8 @@ class _TokenInfo: def add_names(self, conn: Connection, names: Mapping[str, str]) -> None: """ Add token information for the names of the place. """ - with conn.cursor() as cur: - # Create the token IDs for all names. - self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text", + # Create the token IDs for all names. + self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text", (names, )) @@ -576,9 +574,8 @@ class _TokenInfo: """ Add addr:street match terms. """ def _get_street(name: str) -> Optional[str]: - with conn.cursor() as cur: - return cast(Optional[str], - cur.scalar("SELECT word_ids_from_name(%s)::text", (name, ))) + return cast(Optional[str], + execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, ))) tokens = self.cache.streets.get(street, _get_street) self.data['street'] = tokens or '{}' diff --git a/src/nominatim_db/tools/admin.py b/src/nominatim_db/tools/admin.py index cea2ad66..3e502199 100644 --- a/src/nominatim_db/tools/admin.py +++ b/src/nominatim_db/tools/admin.py @@ -10,12 +10,12 @@ Functions for database analysis and maintenance. 
from typing import Optional, Tuple, Any, cast import logging -from psycopg2.extras import Json, register_hstore +from psycopg2.extras import Json from psycopg2 import DataError from ..typing import DictCursorResult from ..config import Configuration -from ..db.connection import connect, Cursor +from ..db.connection import connect, Cursor, register_hstore from ..errors import UsageError from ..tokenizer import factory as tokenizer_factory from ..data.place_info import PlaceInfo diff --git a/src/nominatim_db/tools/check_database.py b/src/nominatim_db/tools/check_database.py index ef28a0e5..946f9291 100644 --- a/src/nominatim_db/tools/check_database.py +++ b/src/nominatim_db/tools/check_database.py @@ -12,7 +12,8 @@ from enum import Enum from textwrap import dedent from ..config import Configuration -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, server_version_tuple,\ + index_exists, table_exists, execute_scalar from ..db import properties from ..errors import UsageError from ..tokenizer import factory as tokenizer_factory @@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]: 'idx_postcode_id', 'idx_postcode_postcode' ] - if conn.table_exists('search_name'): + if table_exists(conn, 'search_name'): indexes.extend(('idx_search_name_nameaddress_vector', 'idx_search_name_name_vector', 'idx_search_name_centroid')) - if conn.server_version_tuple() >= (11, 0, 0): + if server_version_tuple(conn) >= (11, 0, 0): indexes.extend(('idx_placex_housenumber', 'idx_osmline_parent_osm_id_with_hnr')) - if conn.table_exists('place'): + if table_exists(conn, 'place'): indexes.extend(('idx_location_area_country_place_id', 'idx_place_osm_unique', 'idx_placex_rank_address_sector', @@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult: Hints: * Are you connecting to the correct database? - + {instruction} Check the Migration chapter of the Administration Guide. 
@@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu """ Checking database_version matches Nominatim software version """ - if conn.table_exists('nominatim_properties'): + if table_exists(conn, 'nominatim_properties'): db_version_str = properties.get_property(conn, 'database_version') else: db_version_str = None @@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu def check_placex_table(conn: Connection, config: Configuration) -> CheckResult: """ Checking for placex table """ - if conn.table_exists('placex'): + if table_exists(conn, 'placex'): return CheckState.OK return CheckState.FATAL, dict(config=config) @@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult: def check_placex_size(conn: Connection, _: Configuration) -> CheckResult: """ Checking for placex content """ - with conn.cursor() as cur: - cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x') + cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x') return CheckState.OK if cnt > 0 else CheckState.FATAL @@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult: def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult: """ Checking for wikipedia/wikidata data """ - if not conn.table_exists('search_name') or not conn.table_exists('place'): + if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'): return CheckState.NOT_APPLICABLE - with conn.cursor() as cur: - if conn.table_exists('wikimedia_importance'): - cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance') - else: - cnt = cur.scalar('SELECT count(*) FROM wikipedia_article') + if table_exists(conn, 'wikimedia_importance'): + cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance') + else: + cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article') - return CheckState.WARN if cnt == 0 else CheckState.OK + return CheckState.WARN if cnt == 0 else CheckState.OK @_check(hint="""\ @@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult def check_indexing(conn: Connection, _: Configuration) -> CheckResult: """ Checking indexing status """ - with conn.cursor() as cur: - cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0') + cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0') if cnt == 0: return CheckState.OK @@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult: Low counts of unindexed places are fine.""" return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd) - if conn.index_exists('idx_placex_rank_search'): + if index_exists(conn, 'idx_placex_rank_search'): # Likely just an interrupted update. 
index_cmd = 'nominatim index' else: @@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult: """ missing = [] for index in _get_indexes(conn): - if not conn.index_exists(index): + if not index_exists(conn, index): missing.append(index) if missing: @@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult: if not config.get_bool('USE_US_TIGER_DATA'): return CheckState.NOT_APPLICABLE - if not conn.table_exists('location_property_tiger'): + if not table_exists(conn, 'location_property_tiger'): return CheckState.FAIL, dict(error='TIGER data table not found.') - with conn.cursor() as cur: - if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0: - return CheckState.FAIL, dict(error='TIGER data table is empty.') + if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0: + return CheckState.FAIL, dict(error='TIGER data table is empty.') return CheckState.OK diff --git a/src/nominatim_db/tools/collect_os_info.py b/src/nominatim_db/tools/collect_os_info.py index e1f8b166..db3e773d 100644 --- a/src/nominatim_db/tools/collect_os_info.py +++ b/src/nominatim_db/tools/collect_os_info.py @@ -12,21 +12,16 @@ import os import subprocess import sys from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union import psutil -from psycopg2.extensions import make_dsn, parse_dsn +from psycopg2.extensions import make_dsn from ..config import Configuration -from ..db.connection import connect +from ..db.connection import connect, server_version_tuple, execute_scalar from ..version import NOMINATIM_VERSION -def convert_version(ver_tup: Tuple[int, int]) -> str: - """converts tuple version (ver_tup) to a string representation""" - return ".".join(map(str, ver_tup)) - - def friendly_memory_string(mem: float) -> str: """Create a user friendly string for the amount of memory specified as mem""" mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB") @@ -103,16 +98,16 @@ def report_system_information(config: Configuration) -> None: storage, and database configuration.""" with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn: - postgresql_ver: str = convert_version(conn.server_version_tuple()) + postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn))) with conn.cursor() as cur: - num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s", - (parse_dsn(config.get_libpq_dsn())['dbname'], )) - nominatim_db_exists = num == 1 if isinstance(num, int) else False + cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s", + (config.get_database_params()['dbname'], )) + nominatim_db_exists = cur.rowcount > 0 if nominatim_db_exists: with connect(config.get_libpq_dsn()) as conn: - postgis_ver: str = convert_version(conn.postgis_version_tuple()) + postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()') else: postgis_ver = "Unable to connect to database" diff --git a/src/nominatim_db/tools/database_import.py b/src/nominatim_db/tools/database_import.py index c4b3023a..2398d404 100644 --- a/src/nominatim_db/tools/database_import.py +++ b/src/nominatim_db/tools/database_import.py @@ -19,7 +19,8 @@ from psycopg2 import sql as pysql from ..errors import UsageError from ..config import Configuration -from ..db.connection import connect, get_pg_env, Connection +from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\ + postgis_version_tuple, 
drop_tables, table_exists, execute_scalar from ..db.async_connection import DBConnection from ..db.sql_preprocessor import SQLPreprocessor from .exec_utils import run_osm2pgsql @@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None: """ Check that the database has the required plugins installed.""" with connect(dsn) as conn: _require_version('PostgreSQL server', - conn.server_version_tuple(), + server_version_tuple(conn), POSTGRESQL_REQUIRED_VERSION) _require_version('PostGIS', - conn.postgis_version_tuple(), + postgis_version_tuple(conn), POSTGIS_REQUIRED_VERSION) _require_loaded('hstore', conn) @@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None: with connect(dsn) as conn: _require_version('PostgreSQL server', - conn.server_version_tuple(), + server_version_tuple(conn), POSTGRESQL_REQUIRED_VERSION) if rouser is not None: - with conn.cursor() as cur: - cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s', + cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s', (rouser, )) - if cnt == 0: - LOG.fatal("Web user '%s' does not exist. Create it with:\n" - "\n createuser %s", rouser, rouser) - raise UsageError('Missing read-only user.') + if cnt == 0: + LOG.fatal("Web user '%s' does not exist. Create it with:\n" + "\n createuser %s", rouser, rouser) + raise UsageError('Missing read-only user.') # Create extensions. with conn.cursor() as cur: cur.execute('CREATE EXTENSION IF NOT EXISTS hstore') cur.execute('CREATE EXTENSION IF NOT EXISTS postgis') - postgis_version = conn.postgis_version_tuple() + postgis_version = postgis_version_tuple(conn) if postgis_version[0] >= 3: cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster') conn.commit() _require_version('PostGIS', - conn.postgis_version_tuple(), + postgis_version_tuple(conn), POSTGIS_REQUIRED_VERSION) @@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]], raise UsageError('No data imported by osm2pgsql.') if drop: - conn.drop_table('planet_osm_nodes') + drop_tables(conn, 'planet_osm_nodes') + conn.commit() if drop and options['flatnode_file']: Path(options['flatnode_file']).unlink() @@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None: cur.execute('TRUNCATE location_property_tiger') cur.execute('TRUNCATE location_property_osmline') cur.execute('TRUNCATE location_postcode') - if conn.table_exists('search_name'): + if table_exists(conn, 'search_name'): cur.execute('TRUNCATE search_name') cur.execute('DROP SEQUENCE IF EXISTS seq_place') cur.execute('CREATE SEQUENCE seq_place start 100000') diff --git a/src/nominatim_db/tools/freeze.py b/src/nominatim_db/tools/freeze.py index bd52ba9a..e6d80e1e 100644 --- a/src/nominatim_db/tools/freeze.py +++ b/src/nominatim_db/tools/freeze.py @@ -12,7 +12,7 @@ from pathlib import Path from psycopg2 import sql as pysql -from ..db.connection import Connection +from ..db.connection import Connection, drop_tables, table_exists UPDATE_TABLES = [ 'address_levels', @@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None: + pysql.SQL(' or ').join(parts)) tables = [r[0] for r in cur] - for table in tables: - cur.drop_table(table, cascade=True) - + drop_tables(conn, *tables, cascade=True) conn.commit() @@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool: """ Returns true if database is in a frozen state """ - return conn.table_exists('place') is False + return table_exists(conn, 'place') is False diff --git a/src/nominatim_db/tools/migration.py 
b/src/nominatim_db/tools/migration.py index e6803c7d..46ba0125 100644 --- a/src/nominatim_db/tools/migration.py +++ b/src/nominatim_db/tools/migration.py @@ -15,7 +15,8 @@ from psycopg2 import sql as pysql from ..errors import UsageError from ..config import Configuration from ..db import properties -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, server_version_tuple,\ + table_has_column, table_exists, execute_scalar, register_hstore from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version from ..tokenizer import factory as tokenizer_factory from . import refresh @@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int: if necesssary. """ with connect(config.get_libpq_dsn()) as conn: - if conn.table_exists('nominatim_properties'): + register_hstore(conn) + if table_exists(conn, 'nominatim_properties'): db_version_str = properties.get_property(conn, 'database_version') else: db_version_str = None @@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion: Only migrations for 3.6 and later are supported, so bail out when the version seems older. """ - with conn.cursor() as cur: - # In version 3.6, the country_name table was updated. Check for that. - cnt = cur.scalar("""SELECT count(*) FROM - (SELECT svals(name) FROM country_name - WHERE country_code = 'gb')x; - """) - if cnt < 100: - LOG.fatal('It looks like your database was imported with a version ' - 'prior to 3.6.0. Automatic migration not possible.') - raise UsageError('Migration not possible.') + # In version 3.6, the country_name table was updated. Check for that. + cnt = execute_scalar(conn, """SELECT count(*) FROM + (SELECT svals(name) FROM country_name + WHERE country_code = 'gb')x; + """) + if cnt < 100: + LOG.fatal('It looks like your database was imported with a version ' + 'prior to 3.6.0. Automatic migration not possible.') + raise UsageError('Migration not possible.') return NominatimVersion(3, 5, 0, 99) @@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None: def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None: """ Add nominatim_property table. """ - if not conn.table_exists('nominatim_properties'): + if not table_exists(conn, 'nominatim_properties'): with conn.cursor() as cur: cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties ( property TEXT, @@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any) configuration for the backwards-compatible legacy tokenizer """ if properties.get_property(conn, 'tokenizer') is None: - with conn.cursor() as cur: - for table in ('placex', 'location_property_osmline'): - has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns - WHERE table_name = %s - and column_name = 'token_info'""", - (table, )) - if has_column == 0: + for table in ('placex', 'location_property_osmline'): + if not table_has_column(conn, table, 'token_info'): + with conn.cursor() as cur: cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB') .format(pysql.Identifier(table))) tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False, @@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None: The inclusion is needed for efficient lookup of housenumbers in full address searches. 
""" - if conn.server_version_tuple() >= (11, 0, 0): + if server_version_tuple(conn) >= (11, 0, 0): with conn.cursor() as cur: cur.execute(""" CREATE INDEX IF NOT EXISTS idx_location_property_tiger_housenumber_migrated @@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None: Also converts the data into the stricter format which requires that startnumbers comply with the odd/even requirements. """ - if conn.table_has_column('location_property_osmline', 'step'): + if table_has_column(conn, 'location_property_osmline', 'step'): return with conn.cursor() as cur: @@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None: def add_step_column_for_tiger(conn: Connection, **_: Any) -> None: """ Add a new column 'step' to the tiger data table. """ - if conn.table_has_column('location_property_tiger', 'step'): + if table_has_column(conn, 'location_property_tiger', 'step'): return with conn.cursor() as cur: @@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non """ Add a new column 'derived_name' which in the future takes the country names as imported from OSM data. """ - if not conn.table_has_column('country_name', 'derived_name'): + if not table_has_column(conn, 'country_name', 'derived_name'): with conn.cursor() as cur: cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE") @@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An """ Names from the country table should be marked as internal to prevent them from being deleted. Only necessary for ICU tokenizer. """ - import psycopg2.extras # pylint: disable=import-outside-toplevel - tokenizer = tokenizer_factory.get_tokenizer_for_db(config) with tokenizer.name_analyzer() as analyzer: with conn.cursor() as cur: - psycopg2.extras.register_hstore(cur) cur.execute("SELECT country_code, name FROM country_name") for country_code, names in cur: @@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None: The table is only necessary when updates are possible, i.e. the database is not in freeze mode. """ - if conn.table_exists('place'): + if table_exists(conn, 'place'): with conn.cursor() as cur: cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted ( osm_type CHAR(1), @@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None: def split_pending_index(conn: Connection, **_: Any) -> None: """ Reorganise indexes for pending updates. """ - if conn.table_exists('place'): + if table_exists(conn, 'place'): with conn.cursor() as cur: cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector ON placex USING BTREE (rank_address, geometry_sector) @@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None: def enable_forward_dependencies(conn: Connection, **_: Any) -> None: """ Create indexes for updates with forward dependency tracking (long-running). """ - if conn.table_exists('planet_osm_ways'): + if table_exists(conn, 'planet_osm_ways'): with conn.cursor() as cur: cur.execute("""SELECT * FROM pg_indexes WHERE tablename = 'planet_osm_ways' @@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None: def create_postcode_parent_index(conn: Connection, **_: Any) -> None: """ Create index needed for updating postcodes when a parent changes. 
""" - if conn.table_exists('planet_osm_ways'): + if table_exists(conn, 'planet_osm_ways'): with conn.cursor() as cur: cur.execute("""CREATE INDEX IF NOT EXISTS idx_location_postcode_parent_place_id diff --git a/src/nominatim_db/tools/postcodes.py b/src/nominatim_db/tools/postcodes.py index 8dc5bdbd..a5d8ef8b 100644 --- a/src/nominatim_db/tools/postcodes.py +++ b/src/nominatim_db/tools/postcodes.py @@ -18,7 +18,7 @@ from math import isfinite from psycopg2 import sql as pysql -from ..db.connection import connect, Connection +from ..db.connection import connect, Connection, table_exists from ..utils.centroid import PointsCentroid from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer @@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool: postcodes can be computed. """ with connect(dsn) as conn: - return conn.table_exists('place') + return table_exists(conn, 'place') diff --git a/src/nominatim_db/tools/refresh.py b/src/nominatim_db/tools/refresh.py index 6a40c0a7..2e2ffabd 100644 --- a/src/nominatim_db/tools/refresh.py +++ b/src/nominatim_db/tools/refresh.py @@ -17,7 +17,8 @@ from pathlib import Path from psycopg2 import sql as pysql from ..config import Configuration -from ..db.connection import Connection, connect +from ..db.connection import Connection, connect, postgis_version_tuple,\ + drop_tables, table_exists from ..db.utils import execute_file, CopyBuffer from ..db.sql_preprocessor import SQLPreprocessor from ..version import NOMINATIM_VERSION @@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s for entry in levels: _add_address_level_rows_from_entry(rows, entry) - with conn.cursor() as cur: - cur.drop_table(table) + drop_tables(conn, table) + with conn.cursor() as cur: cur.execute(pysql.SQL("""CREATE TABLE {} ( country_code varchar(2), class TEXT, @@ -159,10 +160,8 @@ def import_importance_csv(dsn: str, data_file: Path) -> int: wd_done = set() with connect(dsn) as conn: + drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance') with conn.cursor() as cur: - cur.drop_table('wikipedia_article') - cur.drop_table('wikipedia_redirect') - cur.drop_table('wikimedia_importance') cur.execute("""CREATE TABLE wikimedia_importance ( language TEXT NOT NULL, title TEXT NOT NULL, @@ -228,7 +227,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool = return 1 with connect(dsn) as conn: - postgis_version = conn.postgis_version_tuple() + postgis_version = postgis_version_tuple(conn) if postgis_version[0] < 3: LOG.error('PostGIS version is too old for using OSM raster data.') return 2 @@ -309,7 +308,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non template = "\nrequire_once(CONST_LibDir.'/website/{}');\n" - search_name_table_exists = bool(conn and conn.table_exists('search_name')) + search_name_table_exists = bool(conn and table_exists(conn, 'search_name')) for script in WEBSITE_SCRIPTS: if not search_name_table_exists and script == 'search.php': diff --git a/src/nominatim_db/tools/replication.py b/src/nominatim_db/tools/replication.py index bf1189df..2b1d444f 100644 --- a/src/nominatim_db/tools/replication.py +++ b/src/nominatim_db/tools/replication.py @@ -20,7 +20,7 @@ import requests from ..errors import UsageError from ..db import status -from ..db.connection import Connection, connect +from ..db.connection import Connection, connect, server_version_tuple from .exec_utils 
import run_osm2pgsql try: @@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) - # Consume updates with osm2pgsql. options['append'] = True - options['disable_jit'] = conn.server_version_tuple() >= (11, 0) + options['disable_jit'] = server_version_tuple(conn) >= (11, 0) run_osm2pgsql(options) # Handle deletions diff --git a/src/nominatim_db/tools/special_phrases/sp_importer.py b/src/nominatim_db/tools/special_phrases/sp_importer.py index 1bdcdaf1..4a63ff14 100644 --- a/src/nominatim_db/tools/special_phrases/sp_importer.py +++ b/src/nominatim_db/tools/special_phrases/sp_importer.py @@ -21,7 +21,7 @@ from psycopg2.sql import Identifier, SQL from ...typing import Protocol from ...config import Configuration -from ...db.connection import Connection +from ...db.connection import Connection, drop_tables, index_exists from .importer_statistics import SpecialPhrasesImporterStatistics from .special_phrase import SpecialPhrase from ...tokenizer.base import AbstractTokenizer @@ -233,7 +233,7 @@ class SPImporter(): index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_' base_table = _classtype_table(phrase_class, phrase_type) # Index on centroid - if not self.db_connection.index_exists(index_prefix + 'centroid'): + if not index_exists(self.db_connection, index_prefix + 'centroid'): with self.db_connection.cursor() as db_cursor: db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}") .format(Identifier(index_prefix + 'centroid'), @@ -241,7 +241,7 @@ class SPImporter(): SQL(sql_tablespace))) # Index on place_id - if not self.db_connection.index_exists(index_prefix + 'place_id'): + if not index_exists(self.db_connection, index_prefix + 'place_id'): with self.db_connection.cursor() as db_cursor: db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}") .format(Identifier(index_prefix + 'place_id'), @@ -259,6 +259,7 @@ class SPImporter(): .format(Identifier(table_name), Identifier(self.config.DATABASE_WEBUSER))) + def _remove_non_existent_tables_from_db(self) -> None: """ Remove special phrases which doesn't exist on the wiki anymore. @@ -268,7 +269,6 @@ class SPImporter(): # Delete place_classtype tables corresponding to class/type which # are not on the wiki anymore. - with self.db_connection.cursor() as db_cursor: - for table in self.table_phrases_to_delete: - self.statistics_handler.notify_one_table_deleted() - db_cursor.drop_table(table) + drop_tables(self.db_connection, *self.table_phrases_to_delete) + for _ in self.table_phrases_to_delete: + self.statistics_handler.notify_one_table_deleted() diff --git a/test/python/db/test_async_connection.py b/test/python/db/test_async_connection.py index fff695e5..9647bedc 100644 --- a/test/python/db/test_async_connection.py +++ b/test/python/db/test_async_connection.py @@ -33,13 +33,13 @@ def simple_conns(temp_db): conn2.close() -def test_simple_query(conn, temp_db_conn): +def test_simple_query(conn, temp_db_cursor): conn.connect() conn.perform('CREATE TABLE foo (id INT)') conn.wait() - temp_db_conn.table_exists('foo') + assert temp_db_cursor.table_exists('foo') def test_wait_for_query(conn): diff --git a/test/python/db/test_connection.py b/test/python/db/test_connection.py index 8b4cc62f..9f1442f3 100644 --- a/test/python/db/test_connection.py +++ b/test/python/db/test_connection.py @@ -10,61 +10,74 @@ Tests for specialised connection and cursor classes. 
import pytest import psycopg2 -from nominatim_db.db.connection import connect, get_pg_env +import nominatim_db.db.connection as nc @pytest.fixture def db(dsn): - with connect(dsn) as conn: + with nc.connect(dsn) as conn: yield conn def test_connection_table_exists(db, table_factory): - assert not db.table_exists('foobar') + assert not nc.table_exists(db, 'foobar') table_factory('foobar') - assert db.table_exists('foobar') + assert nc.table_exists(db, 'foobar') def test_has_column_no_table(db): - assert not db.table_has_column('sometable', 'somecolumn') + assert not nc.table_has_column(db, 'sometable', 'somecolumn') @pytest.mark.parametrize('name,result', [('tram', True), ('car', False)]) def test_has_column(db, table_factory, name, result): table_factory('stuff', 'tram TEXT') - assert db.table_has_column('stuff', name) == result + assert nc.table_has_column(db, 'stuff', name) == result def test_connection_index_exists(db, table_factory, temp_db_cursor): - assert not db.index_exists('some_index') + assert not nc.index_exists(db, 'some_index') table_factory('foobar') temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)') - assert db.index_exists('some_index') - assert db.index_exists('some_index', table='foobar') - assert not db.index_exists('some_index', table='bar') + assert nc.index_exists(db, 'some_index') + assert nc.index_exists(db, 'some_index', table='foobar') + assert not nc.index_exists(db, 'some_index', table='bar') def test_drop_table_existing(db, table_factory): table_factory('dummy') - assert db.table_exists('dummy') + assert nc.table_exists(db, 'dummy') - db.drop_table('dummy') - assert not db.table_exists('dummy') + nc.drop_tables(db, 'dummy') + assert not nc.table_exists(db, 'dummy') -def test_drop_table_non_existsing(db): - db.drop_table('dfkjgjriogjigjgjrdghehtre') +def test_drop_table_non_existing(db): + nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre') + + +def test_drop_many_tables(db, table_factory): + tables = [f'table{n}' for n in range(5)] + + for t in tables: + table_factory(t) + assert nc.table_exists(db, t) + + nc.drop_tables(db, *tables) + + for t in tables: + assert not nc.table_exists(db, t) def test_drop_table_non_existing_force(db): with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'): - db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False) + nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False) def test_connection_server_version_tuple(db): - ver = db.server_version_tuple() + ver = nc.server_version_tuple(db) assert isinstance(ver, tuple) assert len(ver) == 2 @@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db): def test_connection_postgis_version_tuple(db, temp_db_with_extensions): - ver = db.postgis_version_tuple() + ver = nc.postgis_version_tuple(db) assert isinstance(ver, tuple) assert len(ver) == 2 @@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions): def test_cursor_scalar(db, table_factory): table_factory('dummy') - with db.cursor() as cur: - assert cur.scalar('SELECT count(*) FROM dummy') == 0 + assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0 def test_cursor_scalar_many_rows(db): - with db.cursor() as cur: - with pytest.raises(RuntimeError): - cur.scalar('SELECT * FROM pg_tables') + with pytest.raises(RuntimeError, match='Query did not return a single row.'): + nc.execute_scalar(db, 'SELECT * FROM pg_tables') def test_cursor_scalar_no_rows(db, table_factory): table_factory('dummy') - with db.cursor() as cur: - with pytest.raises(RuntimeError): 
- cur.scalar('SELECT id FROM dummy') + with pytest.raises(RuntimeError, match='Query did not return a single row.'): + nc.execute_scalar(db, 'SELECT id FROM dummy') def test_get_pg_env_add_variable(monkeypatch): monkeypatch.delenv('PGPASSWORD', raising=False) - env = get_pg_env('user=fooF') + env = nc.get_pg_env('user=fooF') assert env['PGUSER'] == 'fooF' assert 'PGPASSWORD' not in env @@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch): def test_get_pg_env_overwrite_variable(monkeypatch): monkeypatch.setenv('PGUSER', 'some default') - env = get_pg_env('user=overwriter') + env = nc.get_pg_env('user=overwriter') assert env['PGUSER'] == 'overwriter' def test_get_pg_env_ignore_unknown(): - env = get_pg_env('client_encoding=stuff', base_env={}) + env = nc.get_pg_env('client_encoding=stuff', base_env={}) assert env == {} diff --git a/test/python/mock_icu_word_table.py b/test/python/mock_icu_word_table.py index 5c465e8b..67be1892 100644 --- a/test/python/mock_icu_word_table.py +++ b/test/python/mock_icu_word_table.py @@ -8,6 +8,7 @@ Legacy word table for testing with functions to prefil and test contents of the table. """ +from nominatim_db.db.connection import execute_scalar class MockIcuWordTable: """ A word table for testing using legacy word table structure. @@ -77,18 +78,15 @@ class MockIcuWordTable: def count(self): - with self.conn.cursor() as cur: - return cur.scalar("SELECT count(*) FROM word") + return execute_scalar(self.conn, "SELECT count(*) FROM word") def count_special(self): - with self.conn.cursor() as cur: - return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'") + return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'") def count_housenumbers(self): - with self.conn.cursor() as cur: - return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'") + return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'") def get_special(self): diff --git a/test/python/mock_legacy_word_table.py b/test/python/mock_legacy_word_table.py index 9804341f..d1c523eb 100644 --- a/test/python/mock_legacy_word_table.py +++ b/test/python/mock_legacy_word_table.py @@ -8,6 +8,7 @@ Legacy word table for testing with functions to prefil and test contents of the table. """ +from nominatim_db.db.connection import execute_scalar class MockLegacyWordTable: """ A word table for testing using legacy word table structure. 
@@ -58,13 +59,11 @@ class MockLegacyWordTable: def count(self): - with self.conn.cursor() as cur: - return cur.scalar("SELECT count(*) FROM word") + return execute_scalar(self.conn, "SELECT count(*) FROM word") def count_special(self): - with self.conn.cursor() as cur: - return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'") + return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'") def get_special(self): diff --git a/test/python/tokenizer/test_icu.py b/test/python/tokenizer/test_icu.py index 357b7d4a..a2bf6766 100644 --- a/test/python/tokenizer/test_icu.py +++ b/test/python/tokenizer/test_icu.py @@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor, assert test_content == set((('1133', ), )) -def test_finalize_import(tokenizer_factory, temp_db_conn, - temp_db_cursor, test_config, sql_preprocessor_cfg): +def test_finalize_import(tokenizer_factory, temp_db_cursor, + test_config, sql_preprocessor_cfg): tok = tokenizer_factory() tok.init_new_db(test_config) - assert not temp_db_conn.index_exists('idx_word_word_id') + assert not temp_db_cursor.index_exists('word', 'idx_word_word_id') tok.finalize_import(test_config) - assert temp_db_conn.index_exists('idx_word_word_id') + assert temp_db_cursor.index_exists('word', 'idx_word_word_id') def test_check_database(test_config, tokenizer_factory, diff --git a/test/python/tools/test_database_import.py b/test/python/tools/test_database_import.py index 548ec800..9d56efa0 100644 --- a/test/python/tools/test_database_import.py +++ b/test/python/tools/test_database_import.py @@ -132,7 +132,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options) ignore_errors=True) -def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options): +def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options): table_factory('place', content=((1, ), )) table_factory('planet_osm_nodes') @@ -144,7 +144,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True) assert not flatfile.exists() - assert not temp_db_conn.table_exists('planet_osm_nodes') + assert not temp_db_cursor.table_exists('planet_osm_nodes') def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd): diff --git a/test/python/tools/test_import_special_phrases.py b/test/python/tools/test_import_special_phrases.py index 64eb7b18..0d33e6e0 100644 --- a/test/python/tools/test_import_special_phrases.py +++ b/test/python/tools/test_import_special_phrases.py @@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer): assert isinstance(black_list, dict) and isinstance(white_list, dict) -def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn, +def test_create_place_classtype_indexes(temp_db_with_extensions, + temp_db_conn, temp_db_cursor, table_factory, sp_importer): """ Test that _create_place_classtype_indexes() create the @@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn, table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY') sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type) + temp_db_conn.commit() - assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type) + assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type) -def test_create_place_classtype_table(temp_db_conn, placex_table, 
-def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
+def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
     """
     Test that _create_place_classtype_table() create the right
     place_classtype table.
@@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
     phrase_class = 'class'
     phrase_type = 'type'
     sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
+    temp_db_conn.commit()

-    assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
+    assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)

-def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
+def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
+                                  def_config, sp_importer):
     """
     Test that _grant_access_to_webuser() give right access to the web user.
     """
@@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
     table_factory(table_name)

     sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
+    temp_db_conn.commit()

-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)

 def test_create_place_classtype_table_and_indexes(
-        temp_db_conn, def_config, placex_table,
-        sp_importer):
+        temp_db_cursor, def_config, placex_table,
+        sp_importer, temp_db_conn):
     """
     Test that _create_place_classtype_table_and_indexes()
     create the right place_classtype tables and place_id indexes
@@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
     pairs = set([('class1', 'type1'), ('class2', 'type2')])

     sp_importer._create_classtype_table_and_indexes(pairs)
+    temp_db_conn.commit()

     for pair in pairs:
-        assert check_table_exist(temp_db_conn, pair[0], pair[1])
-        assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
-        assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
+        assert check_table_exist(temp_db_cursor, pair[0], pair[1])
+        assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
+        assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])

 def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
-                                            temp_db_conn):
+                                            temp_db_conn, temp_db_cursor):
     """
     Check for the remove_non_existent_phrases_from_db() method.
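The temp_db_conn.commit() calls added above are not cosmetic: the checks now go through temp_db_cursor, which lives on a separate connection, and PostgreSQL does not show uncommitted DDL to other connections. A self-contained sketch of that visibility rule; the DSN and table name are illustrative:

import psycopg2

conn_a = psycopg2.connect('dbname=test_nominatim')   # placeholder DSN
conn_b = psycopg2.connect('dbname=test_nominatim')

with conn_a.cursor() as cur:
    cur.execute('CREATE TABLE place_classtype_demo (place_id BIGINT)')

with conn_b.cursor() as cur:
    cur.execute("SELECT count(*) FROM pg_tables WHERE tablename = 'place_classtype_demo'")
    assert cur.fetchone()[0] == 0   # not yet visible on the second connection

conn_a.commit()

with conn_b.cursor() as cur:
    cur.execute("SELECT count(*) FROM pg_tables WHERE tablename = 'place_classtype_demo'")
    assert cur.fetchone()[0] == 1   # visible once the DDL is committed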
@@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
     """
     sp_importer._remove_non_existent_tables_from_db()
+    temp_db_conn.commit()

-    # Changes are not committed yet. Use temp_db_conn for checking results.
-    with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
-        assert cur.row_set(query_tables) \
+    assert temp_db_cursor.row_set(query_tables) \
            == {('place_classtype_testclasstypetable_to_keep', )}


 @pytest.mark.parametrize("should_replace", [(True), (False)])
-def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
+def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
                         placex_table, table_factory, tokenizer_mock,
                         xml_wiki_content, should_replace):
     """
@@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
     class_test = 'aerialway'
     type_test = 'zip_line'

-    assert check_table_exist(temp_db_conn, class_test, type_test)
-    assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
-    assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
-    assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
+    assert check_table_exist(temp_db_cursor, class_test, type_test)
+    assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
+    assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
+    assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
     if should_replace:
-        assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
+        assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')

-    assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
+    assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
     if should_replace:
-        assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
+        assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')


-def check_table_exist(temp_db_conn, phrase_class, phrase_type):
+def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
     """
         Verify that the place_classtype table exists for the given
         phrase_class and phrase_type.
     """
-    return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
+    return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))


-def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
+def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
     """
         Check that the web user has been granted right access to the
         place_classtype table of the given phrase_class and phrase_type.
     """
     table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)

-    with temp_db_conn.cursor() as temp_db_cursor:
-        temp_db_cursor.execute("""
-                SELECT * FROM information_schema.role_table_grants
-                WHERE table_name='{}'
-                AND grantee='{}'
-                AND privilege_type='SELECT'""".format(table_name, user))
-        return temp_db_cursor.fetchone()
+    temp_db_cursor.execute("""
+            SELECT * FROM information_schema.role_table_grants
+            WHERE table_name='{}'
+            AND grantee='{}'
+            AND privilege_type='SELECT'""".format(table_name, user))
+    return temp_db_cursor.fetchone()
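The rewritten check_grant_access() splices the table name and user into the SQL via str.format(), which is fine for these fixed test values. A hypothetical alternative using bind parameters and returning an explicit boolean, not what the suite actually uses:

def check_grant_access_param(cur, user, phrase_class, phrase_type):
    # Same catalog query, but with %s placeholders so psycopg2 quotes the values.
    cur.execute("""SELECT 1 FROM information_schema.role_table_grants
                   WHERE table_name = %s
                     AND grantee = %s
                     AND privilege_type = 'SELECT'""",
                ('place_classtype_{}_{}'.format(phrase_class, phrase_type), user))
    return cur.fetchone() is not None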
""" + table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type) index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type) return ( - temp_db_conn.index_exists(index_prefix + 'centroid') + temp_db_cursor.index_exists(table_name, index_prefix + 'centroid') and - temp_db_conn.index_exists(index_prefix + 'place_id') + temp_db_cursor.index_exists(table_name, index_prefix + 'place_id') ) diff --git a/test/python/tools/test_migration.py b/test/python/tools/test_migration.py index 3a849adb..8821f694 100644 --- a/test/python/tools/test_migration.py +++ b/test/python/tools/test_migration.py @@ -12,6 +12,7 @@ import psycopg2.extras from nominatim_db.tools import migration from nominatim_db.errors import UsageError +from nominatim_db.db.connection import server_version_tuple import nominatim_db.version from mock_legacy_word_table import MockLegacyWordTable @@ -63,7 +64,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor, WHERE property = 'database_version'""") -def test_already_at_version(def_config, property_table): +def test_already_at_version(temp_db_with_extensions, def_config, property_table): property_table.set('database_version', str(nominatim_db.version.NOMINATIM_VERSION)) @@ -71,8 +72,8 @@ def test_already_at_version(def_config, property_table): assert migration.migrate(def_config, {}) == 0 -def test_run_single_migration(def_config, temp_db_cursor, property_table, - monkeypatch, postprocess_mock): +def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor, + property_table, monkeypatch, postprocess_mock): oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION] oldversion[0] -= 1 property_table.set('database_version', @@ -226,7 +227,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact migration.create_tiger_housenumber_index(temp_db_conn) temp_db_conn.commit() - if temp_db_conn.server_version_tuple() >= (11, 0, 0): + if server_version_tuple(temp_db_conn) >= (11, 0, 0): assert temp_db_cursor.index_exists('location_property_tiger', 'idx_location_property_tiger_housenumber_migrated') diff --git a/test/python/tools/test_refresh.py b/test/python/tools/test_refresh.py index 50ff6398..1f1968cf 100644 --- a/test/python/tools/test_refresh.py +++ b/test/python/tools/test_refresh.py @@ -12,6 +12,7 @@ from pathlib import Path import pytest from nominatim_db.tools import refresh +from nominatim_db.db.connection import postgis_version_tuple def test_refresh_import_wikipedia_not_existing(dsn): assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1 @@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn): def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor): temp_db_cursor.execute('CREATE EXTENSION postgis') - if temp_db_conn.postgis_version_tuple()[0] < 3: + if postgis_version_tuple(temp_db_conn)[0] < 3: assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0 else: temp_db_cursor.execute('CREATE EXTENSION postgis_raster') assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0 - assert temp_db_conn.table_exists('secondary_importance') + assert temp_db_cursor.table_exists('secondary_importance') @pytest.mark.parametrize("replace", (True, False)) diff --git a/test/python/tools/test_tiger_data.py b/test/python/tools/test_tiger_data.py index 7ef6a1e6..fc01f22f 100644 --- a/test/python/tools/test_tiger_data.py +++ b/test/python/tools/test_tiger_data.py @@ -12,6 
diff --git a/test/python/tools/test_tiger_data.py b/test/python/tools/test_tiger_data.py
index 7ef6a1e6..fc01f22f 100644
--- a/test/python/tools/test_tiger_data.py
+++ b/test/python/tools/test_tiger_data.py
@@ -12,6 +12,7 @@ from textwrap import dedent

 import pytest

+from nominatim_db.db.connection import execute_scalar
 from nominatim_db.tools import tiger_data, freeze
 from nominatim_db.errors import UsageError

@@ -31,8 +32,7 @@ class MockTigerTable:
             cur.execute("CREATE TABLE place (number INTEGER)")

     def count(self):
-        with self.conn.cursor() as cur:
-            return cur.scalar("SELECT count(*) FROM tiger")
+        return execute_scalar(self.conn, "SELECT count(*) FROM tiger")

     def row(self):
         with self.conn.cursor() as cur:
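The resulting shape of the mock tables is worth noting: the count helpers no longer manage a cursor at all. A condensed sketch of the pattern, with an illustrative class name and the table from MockTigerTable above:

from nominatim_db.db.connection import execute_scalar

class MockTableSketch:
    # Illustrative mock in the style of MockTigerTable; not part of the suite.
    def __init__(self, conn):
        self.conn = conn

    def count(self):
        # One call replaces the old with-block around cur.scalar().
        return execute_scalar(self.conn, "SELECT count(*) FROM tiger")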