make DB helper functions free functions

Also changes the drop function so that it can drop multiple tables
at once.
Sarah Hoffmann 2024-07-02 15:15:50 +02:00
parent 71249bd94a
commit 3742fa2929
30 changed files with 347 additions and 364 deletions
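The pattern of the refactoring: helpers that used to be methods on the custom Connection and Cursor classes become free functions that take the connection as their first argument. A minimal before/after sketch from a caller's point of view (the DSN and table names are illustrative, not taken from the diff):

    from nominatim_db.db.connection import connect, table_exists, execute_scalar, drop_tables

    dsn = 'dbname=nominatim'  # illustrative

    with connect(dsn) as conn:
        # before: conn.table_exists('search_name') / cur.scalar(...)
        if table_exists(conn, 'search_name'):
            count = execute_scalar(conn, 'SELECT count(*) FROM search_name')
        # before: conn.drop_table() took one table per call and committed itself;
        # drop_tables() accepts any number of names and leaves the commit to the caller
        drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
        conn.commit()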

View File

@ -12,7 +12,7 @@ import argparse
import random
from ..errors import UsageError
from ..db.connection import connect
from ..db.connection import connect, table_exists
from .args import NominatimArgs
# Do not repeat documentation of subcommand classes.
@ -115,7 +115,7 @@ class AdminFuncs:
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
with connect(args.config.get_libpq_dsn()) as conn:
if conn.table_exists('search_name'):
if table_exists(conn, 'search_name'):
words = tokenizer.most_frequent_words(conn, 1000)
else:
words = []

View File

@ -13,7 +13,7 @@ import logging
from pathlib import Path
from ..config import Configuration
from ..db.connection import connect
from ..db.connection import connect, table_exists
from ..tokenizer.base import AbstractTokenizer
from .args import NominatimArgs
@ -124,7 +124,7 @@ class UpdateRefresh:
with connect(args.config.get_libpq_dsn()) as conn:
# If the table did not exist before, then the importance code
# needs to be enabled.
if not conn.table_exists('secondary_importance'):
if not table_exists(conn, 'secondary_importance'):
args.functions = True
LOG.warning('Import secondary importance raster data from %s', args.project_dir)

View File

@ -9,10 +9,9 @@ Functions for importing and managing static country information.
"""
from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
from pathlib import Path
import psycopg2.extras
from ..db import utils as db_utils
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, register_hstore
from ..errors import UsageError
from ..config import Configuration
from ..tokenizer.base import AbstractTokenizer
@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
params.append((ccode, props['names'], lang, partition))
with connect(dsn) as conn:
register_hstore(conn)
with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute(
""" CREATE TABLE public.country_name (
country_code character varying(2),
@ -157,8 +156,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
return ':' not in key or not languages or \
key[key.index(':') + 1:] in languages
register_hstore(conn)
with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("""SELECT country_code, name FROM country_name
WHERE country_code is not null""")

View File

@ -7,7 +7,8 @@
"""
Specialised connection and cursor functions.
"""
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload,\
Tuple, Iterable
import contextlib
import logging
import os
@ -46,37 +47,6 @@ class Cursor(psycopg2.extras.DictCursor):
psycopg2.extras.execute_values(self, sql, argslist, template=template)
def scalar(self, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a RuntimeError is raised.
"""
self.execute(sql, args)
if self.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = self.fetchone()
assert result is not None
return result[0]
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. If 'cascade' is set
to True then all dependent tables are deleted as well.
"""
sql = 'DROP TABLE '
if if_exists:
sql += 'IF EXISTS '
sql += '{}'
if cascade:
sql += ' CASCADE'
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
class Connection(psycopg2.extensions.connection):
""" A connection that provides the specialised cursor by default and
adds convenience functions for administrating the database.
@ -99,80 +69,105 @@ class Connection(psycopg2.extensions.connection):
return super().cursor(cursor_factory=cursor_factory, **kwargs)
def table_exists(self, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
with self.cursor() as cur:
num = cur.scalar("""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def execute_scalar(conn: Connection, sql: Query, args: Any = None) -> Any:
""" Execute query that returns a single value. The value is returned.
If the query yields more than one row, a RuntimeError is raised.
"""
with conn.cursor() as cur:
cur.execute(sql, args)
if cur.rowcount != 1:
raise RuntimeError("Query did not return a single row.")
result = cur.fetchone()
assert result is not None
return result[0]
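A hedged sketch of the execute_scalar() contract (queries are illustrative): exactly one result row yields its first column, anything else raises RuntimeError.

    count = execute_scalar(conn, 'SELECT count(*) FROM placex')  # one row: returns the value
    # execute_scalar(conn, 'SELECT place_id FROM placex')        # several rows: RuntimeError
    # execute_scalar(conn, 'SELECT 1 WHERE false')               # zero rows: RuntimeError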
def table_has_column(self, table: str, column: str) -> bool:
""" Check if the table 'table' exists and has a column with name 'column'.
"""
with self.cursor() as cur:
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s
and column_name = %s""",
(table, column))
return has_column > 0 if isinstance(has_column, int) else False
def table_exists(conn: Connection, table: str) -> bool:
""" Check that a table with the given name exists in the database.
"""
num = execute_scalar(conn,
"""SELECT count(*) FROM pg_tables
WHERE tablename = %s and schemaname = 'public'""", (table, ))
return num == 1 if isinstance(num, int) else False
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
    """ Check that an index with the given name exists in the database.
        If table is not None then the index must relate to the given
        table.
    """
    with self.cursor() as cur:
        cur.execute("""SELECT tablename FROM pg_indexes
                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
        if cur.rowcount == 0:
            return False
        if table is not None:
            row = cur.fetchone()
            if row is None or not isinstance(row[0], str):
                return False
            return row[0] == table
    return True
def table_has_column(conn: Connection, table: str, column: str) -> bool:
    """ Check if the table 'table' exists and has a column with name 'column'.
    """
    has_column = execute_scalar(conn,
                                """SELECT count(*) FROM information_schema.columns
                                   WHERE table_name = %s and column_name = %s""",
                                (table, column))
    return has_column > 0 if isinstance(has_column, int) else False
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
    """ Check that an index with the given name exists in the database.
        If table is not None then the index must relate to the given
        table.
    """
    with conn.cursor() as cur:
        cur.execute("""SELECT tablename FROM pg_indexes
                       WHERE indexname = %s and schemaname = 'public'""", (index, ))
        if cur.rowcount == 0:
            return False
        if table is not None:
            row = cur.fetchone()
            if row is None or not isinstance(row[0], str):
                return False
            return row[0] == table
    return True
def drop_tables(conn: Connection, *names: str,
if_exists: bool = True, cascade: bool = False) -> None:
""" Drop one or more tables with the given names.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored. `cascade` will cause
dependent objects to be dropped as well.
The caller needs to take care of committing the change.
"""
sql = pysql.SQL('DROP TABLE%s{}%s' % (
' IF EXISTS ' if if_exists else ' ',
' CASCADE' if cascade else ''))
with conn.cursor() as cur:
for name in names:
cur.execute(sql.format(pysql.Identifier(name)))
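A short usage sketch of the widened signature (table names as used elsewhere in this commit); unlike the old drop_table() method it no longer commits by itself:

    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')  # missing tables are ignored
    drop_tables(conn, 'tmp_word', if_exists=False, cascade=True)      # error if absent, drop dependents
    conn.commit()  # committing is now the caller's job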
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
""" Drop the table with the given name.
Set `if_exists` to False if a non-existent table should raise
an exception instead of just being ignored.
"""
with self.cursor() as cur:
cur.drop_table(name, if_exists, cascade)
self.commit()
def server_version_tuple(conn: Connection) -> Tuple[int, int]:
""" Return the server version as a tuple of (major, minor).
Converts correctly for pre-10 and post-10 PostgreSQL versions.
"""
version = conn.server_version
if version < 100000:
return (int(version / 10000), int((version % 10000) / 100))
return (int(version / 10000), version % 10000)
def server_version_tuple(self) -> Tuple[int, int]:
    """ Return the server version as a tuple of (major, minor).
        Converts correctly for pre-10 and post-10 PostgreSQL versions.
    """
    version = self.server_version
    if version < 100000:
        return (int(version / 10000), int((version % 10000) / 100))
    return (int(version / 10000), version % 10000)
def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
    """ Return the postgis version installed in the database as a
        tuple of (major, minor). Assumes that the PostGIS extension
        has been installed already.
    """
    version = execute_scalar(conn, 'SELECT postgis_lib_version()')
    version_parts = version.split('.')
    if len(version_parts) < 2:
        raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
    return (int(version_parts[0]), int(version_parts[1]))
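A hedged sketch of how the two version helpers gate version-dependent SQL further down in this commit:

    if server_version_tuple(conn) < (12, 0):
        ...  # PostgreSQL before 12: use the temp-table based word-frequency path
    with conn.cursor() as cur:
        if postgis_version_tuple(conn)[0] >= 3:
            # from PostGIS 3 on, raster support lives in the separate postgis_raster extension
            cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')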
def postgis_version_tuple(self) -> Tuple[int, int]:
""" Return the postgis version installed in the database as a
tuple of (major, minor). Assumes that the PostGIS extension
has been installed already.
"""
with self.cursor() as cur:
version = cur.scalar('SELECT postgis_lib_version()')
version_parts = version.split('.')
if len(version_parts) < 2:
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
return (int(version_parts[0]), int(version_parts[1]))
def register_hstore(conn: Connection) -> None:
""" Register the hstore type with psycopg for the connection.
"""
psycopg2.extras.register_hstore(conn)
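register_hstore() now registers the hstore adapter once per connection rather than per cursor; a minimal sketch (query as used in country_info.py):

    register_hstore(conn)  # hstore columns now come back as Python dicts
    with conn.cursor() as cur:
        cur.execute('SELECT country_code, name FROM country_name')
        for country_code, names in cur:
            ...  # names is a dict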
class ConnectionContext(ContextManager[Connection]):

View File

@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
"""
from typing import Optional, cast
from .connection import Connection
from .connection import Connection, table_exists
def set_property(conn: Connection, name: str, value: str) -> None:
""" Add or replace the property with the given name.
@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
""" Return the current value of the given property or None if the property
is not set.
"""
if not conn.table_exists('nominatim_properties'):
if not table_exists(conn, 'nominatim_properties'):
return None
with conn.cursor() as cur:

View File

@ -10,7 +10,7 @@ Preprocessing of SQL files.
from typing import Set, Dict, Any, cast
import jinja2
from .connection import Connection
from .connection import Connection, server_version_tuple, postgis_version_tuple
from .async_connection import WorkerPool
from ..config import Configuration
@ -66,8 +66,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
""" Set up a dictionary with various optional Postgresql/Postgis features that
depend on the database version.
"""
pg_version = conn.server_version_tuple()
postgis_version = conn.postgis_version_tuple()
pg_version = server_version_tuple(conn)
postgis_version = postgis_version_tuple(conn)
pg11plus = pg_version >= (11, 0, 0)
ps3 = postgis_version >= (3, 0)
return {

View File

@ -12,7 +12,7 @@ import datetime as dt
import logging
import re
from .connection import Connection
from .connection import Connection, table_exists, execute_scalar
from ..utils.url_utils import get_url
from ..errors import UsageError
from ..typing import TypedDict
@ -34,7 +34,7 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
database.
"""
# If there is a date from osm2pgsql available, use that.
if conn.table_exists('osm2pgsql_properties'):
if table_exists(conn, 'osm2pgsql_properties'):
with conn.cursor() as cur:
cur.execute(""" SELECT value FROM osm2pgsql_properties
WHERE property = 'current_timestamp' """)
@ -47,15 +47,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
raise UsageError("Cannot determine database date from data in offline mode.")
# Else, find the node with the highest ID in the database
with conn.cursor() as cur:
if conn.table_exists('place'):
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if table_exists(conn, 'place'):
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
else:
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
if osmid is None:
LOG.fatal("No data found in the database.")
raise UsageError("No data found in the database.")
LOG.info("Using node id %d for timestamp lookup", osmid)
# Get the node from the API to find the timestamp when it was created.

View File

@ -15,7 +15,7 @@ import psycopg2.extras
from ..typing import DictCursorResults
from ..db.async_connection import DBConnection, WorkerPool
from ..db.connection import connect, Connection, Cursor
from ..db.connection import connect, Connection, Cursor, execute_scalar, register_hstore
from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger
from . import runners
@ -32,15 +32,15 @@ class PlaceFetcher:
self.conn: Optional[DBConnection] = DBConnection(dsn,
cursor_factory=psycopg2.extras.DictCursor)
with setup_conn.cursor() as cur:
# need to fetch those manually because register_hstore cannot
# fetch them on an asynchronous connection below.
hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")
# need to fetch those manually because register_hstore cannot
# fetch them on an asynchronous connection below.
hstore_oid = execute_scalar(setup_conn, "SELECT 'hstore'::regtype::oid")
hstore_array_oid = execute_scalar(setup_conn, "SELECT 'hstore[]'::regtype::oid")
psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
array_oid=hstore_array_oid)
def close(self) -> None:
""" Close the underlying asynchronous connection.
"""
@ -205,10 +205,9 @@ class Indexer:
LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)
with connect(self.dsn) as conn:
psycopg2.extras.register_hstore(conn)
with conn.cursor() as cur:
total_tuples = cur.scalar(runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
register_hstore(conn)
total_tuples = execute_scalar(conn, runner.sql_count_objects())
LOG.debug("Total number of rows: %i", total_tuples)
conn.commit()

View File

@ -16,7 +16,8 @@ import logging
from pathlib import Path
from textwrap import dedent
from ..db.connection import connect, Connection, Cursor
from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
drop_tables, table_exists, execute_scalar
from ..config import Configuration
from ..db.utils import CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor
@ -108,7 +109,7 @@ class ICUTokenizer(AbstractTokenizer):
""" Recompute frequencies for all name words.
"""
with connect(self.dsn) as conn:
if not conn.table_exists('search_name'):
if not table_exists(conn, 'search_name'):
return
with conn.cursor() as cur:
@ -117,10 +118,9 @@ class ICUTokenizer(AbstractTokenizer):
cur.execute('SET max_parallel_workers_per_gather TO %s',
(min(threads, 6),))
if conn.server_version_tuple() < (12, 0):
if server_version_tuple(conn) < (12, 0):
LOG.info('Computing word frequencies')
cur.drop_table('word_frequencies')
cur.drop_table('addressword_frequencies')
drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
cur.execute("""CREATE TEMP TABLE word_frequencies AS
SELECT unnest(name_vector) as id, count(*)
FROM search_name GROUP BY id""")
@ -152,17 +152,16 @@ class ICUTokenizer(AbstractTokenizer):
$$ LANGUAGE plpgsql IMMUTABLE;
""")
LOG.info('Update word table with recomputed frequencies')
cur.drop_table('tmp_word')
drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS
SELECT word_id, word_token, type, word,
word_freq_update(word_id, info) as info
FROM word
""")
cur.drop_table('word_frequencies')
cur.drop_table('addressword_frequencies')
drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
else:
LOG.info('Computing word frequencies')
cur.drop_table('word_frequencies')
drop_tables(conn, 'word_frequencies')
cur.execute("""
CREATE TEMP TABLE word_frequencies AS
WITH word_freq AS MATERIALIZED (
@ -182,7 +181,7 @@ class ICUTokenizer(AbstractTokenizer):
cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
cur.execute('ANALYSE word_frequencies')
LOG.info('Update word table with recomputed frequencies')
cur.drop_table('tmp_word')
drop_tables(conn, 'tmp_word')
cur.execute("""CREATE TABLE tmp_word AS
SELECT word_id, word_token, type, word,
(CASE WHEN wf.info is null THEN word.info
@ -191,7 +190,7 @@ class ICUTokenizer(AbstractTokenizer):
FROM word LEFT JOIN word_frequencies wf
ON word.word_id = wf.id
""")
cur.drop_table('word_frequencies')
drop_tables(conn, 'word_frequencies')
with conn.cursor() as cur:
cur.execute('SET max_parallel_workers_per_gather TO 0')
@ -210,7 +209,7 @@ class ICUTokenizer(AbstractTokenizer):
""" Remove unused house numbers.
"""
with connect(self.dsn) as conn:
if not conn.table_exists('search_name'):
if not table_exists(conn, 'search_name'):
return
with conn.cursor(name="hnr_counter") as cur:
cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@ -311,8 +310,7 @@ class ICUTokenizer(AbstractTokenizer):
frequencies.
"""
with connect(self.dsn) as conn:
with conn.cursor() as cur:
cur.drop_table('word')
drop_tables(conn, 'word')
sqlp = SQLPreprocessor(conn, config)
sqlp.run_string(conn, """
CREATE TABLE word (
@ -370,8 +368,8 @@ class ICUTokenizer(AbstractTokenizer):
""" Rename all tables and indexes used by the tokenizer.
"""
with connect(self.dsn) as conn:
drop_tables(conn, 'word')
with conn.cursor() as cur:
cur.drop_table('word')
cur.execute(f"ALTER TABLE {old} RENAME TO word")
for idx in ('word_token', 'word_id'):
cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@ -733,11 +731,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if norm_name:
result = self._cache.housenumbers.get(norm_name, result)
if result[0] is None:
with self.conn.cursor() as cur:
hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))
result = hid, norm_name
self._cache.housenumbers[norm_name] = result
result = hid, norm_name
self._cache.housenumbers[norm_name] = result
else:
# Otherwise use the analyzer to determine the canonical name.
# Per convention we use the first variant as the 'lookup name', the
@ -748,11 +745,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
if result[0] is None:
variants = analyzer.compute_variants(word_id)
if variants:
with self.conn.cursor() as cur:
hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
(word_id, list(variants)))
result = hid, variants[0]
self._cache.housenumbers[word_id] = result
result = hid, variants[0]
self._cache.housenumbers[word_id] = result
return result

View File

@ -18,10 +18,10 @@ from textwrap import dedent
from icu import Transliterator
import psycopg2
import psycopg2.extras
from ..errors import UsageError
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, drop_tables, table_exists,\
execute_scalar, register_hstore
from ..config import Configuration
from ..db import properties
from ..db import utils as db_utils
@ -179,11 +179,10 @@ class LegacyTokenizer(AbstractTokenizer):
* Can nominatim.so be accessed by the database user?
"""
with connect(self.dsn) as conn:
with conn.cursor() as cur:
try:
out = cur.scalar("SELECT make_standard_name('a')")
except psycopg2.Error as err:
return hint.format(error=str(err))
try:
out = execute_scalar(conn, "SELECT make_standard_name('a')")
except psycopg2.Error as err:
return hint.format(error=str(err))
if out != 'a':
return hint.format(error='Unexpected result for make_standard_name()')
@ -214,9 +213,9 @@ class LegacyTokenizer(AbstractTokenizer):
""" Recompute the frequency of full words.
"""
with connect(self.dsn) as conn:
if conn.table_exists('search_name'):
if table_exists(conn, 'search_name'):
drop_tables(conn, "word_frequencies")
with conn.cursor() as cur:
cur.drop_table("word_frequencies")
LOG.info("Computing word frequencies")
cur.execute("""CREATE TEMP TABLE word_frequencies AS
SELECT unnest(name_vector) as id, count(*)
@ -226,7 +225,7 @@ class LegacyTokenizer(AbstractTokenizer):
cur.execute("""UPDATE word SET search_name_count = count
FROM word_frequencies
WHERE word_token like ' %' and word_id = id""")
cur.drop_table("word_frequencies")
drop_tables(conn, "word_frequencies")
conn.commit()
@ -316,7 +315,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
self.conn: Optional[Connection] = connect(dsn).connection
self.conn.autocommit = True
self.normalizer = normalizer
psycopg2.extras.register_hstore(self.conn)
register_hstore(self.conn)
self._cache = _TokenCache(self.conn)
@ -536,9 +535,8 @@ class _TokenInfo:
def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
""" Add token information for the names of the place.
"""
with conn.cursor() as cur:
# Create the token IDs for all names.
self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
# Create the token IDs for all names.
self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
(names, ))
@ -576,9 +574,8 @@ class _TokenInfo:
""" Add addr:street match terms.
"""
def _get_street(name: str) -> Optional[str]:
with conn.cursor() as cur:
return cast(Optional[str],
cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
return cast(Optional[str],
execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))
tokens = self.cache.streets.get(street, _get_street)
self.data['street'] = tokens or '{}'

View File

@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
from typing import Optional, Tuple, Any, cast
import logging
from psycopg2.extras import Json, register_hstore
from psycopg2.extras import Json
from psycopg2 import DataError
from ..typing import DictCursorResult
from ..config import Configuration
from ..db.connection import connect, Cursor
from ..db.connection import connect, Cursor, register_hstore
from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory
from ..data.place_info import PlaceInfo

View File

@ -12,7 +12,8 @@ from enum import Enum
from textwrap import dedent
from ..config import Configuration
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, server_version_tuple,\
index_exists, table_exists, execute_scalar
from ..db import properties
from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory
@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
'idx_postcode_id',
'idx_postcode_postcode'
]
if conn.table_exists('search_name'):
if table_exists(conn, 'search_name'):
indexes.extend(('idx_search_name_nameaddress_vector',
'idx_search_name_name_vector',
'idx_search_name_centroid'))
if conn.server_version_tuple() >= (11, 0, 0):
if server_version_tuple(conn) >= (11, 0, 0):
indexes.extend(('idx_placex_housenumber',
'idx_osmline_parent_osm_id_with_hnr'))
if conn.table_exists('place'):
if table_exists(conn, 'place'):
indexes.extend(('idx_location_area_country_place_id',
'idx_place_osm_unique',
'idx_placex_rank_address_sector',
@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:
Hints:
* Are you connecting to the correct database?
{instruction}
Check the Migration chapter of the Administration Guide.
@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
""" Checking database_version matches Nominatim software version
"""
if conn.table_exists('nominatim_properties'):
if table_exists(conn, 'nominatim_properties'):
db_version_str = properties.get_property(conn, 'database_version')
else:
db_version_str = None
@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
""" Checking for placex table
"""
if conn.table_exists('placex'):
if table_exists(conn, 'placex'):
return CheckState.OK
return CheckState.FATAL, dict(config=config)
@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
""" Checking for placex content
"""
with conn.cursor() as cur:
cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
return CheckState.OK if cnt > 0 else CheckState.FATAL
@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
""" Checking for wikipedia/wikidata data
"""
if not conn.table_exists('search_name') or not conn.table_exists('place'):
if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
return CheckState.NOT_APPLICABLE
with conn.cursor() as cur:
if conn.table_exists('wikimedia_importance'):
cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance')
else:
cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
if table_exists(conn, 'wikimedia_importance'):
cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
else:
cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')
return CheckState.WARN if cnt == 0 else CheckState.OK
return CheckState.WARN if cnt == 0 else CheckState.OK
@_check(hint="""\
@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
""" Checking indexing status
"""
with conn.cursor() as cur:
cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')
if cnt == 0:
return CheckState.OK
@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
Low counts of unindexed places are fine."""
return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)
if conn.index_exists('idx_placex_rank_search'):
if index_exists(conn, 'idx_placex_rank_search'):
# Likely just an interrupted update.
index_cmd = 'nominatim index'
else:
@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
"""
missing = []
for index in _get_indexes(conn):
if not conn.index_exists(index):
if not index_exists(conn, index):
missing.append(index)
if missing:
@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
if not config.get_bool('USE_US_TIGER_DATA'):
return CheckState.NOT_APPLICABLE
if not conn.table_exists('location_property_tiger'):
if not table_exists(conn, 'location_property_tiger'):
return CheckState.FAIL, dict(error='TIGER data table not found.')
with conn.cursor() as cur:
if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0:
return CheckState.FAIL, dict(error='TIGER data table is empty.')
if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
return CheckState.FAIL, dict(error='TIGER data table is empty.')
return CheckState.OK

View File

@ -12,21 +12,16 @@ import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union
import psutil
from psycopg2.extensions import make_dsn, parse_dsn
from psycopg2.extensions import make_dsn
from ..config import Configuration
from ..db.connection import connect
from ..db.connection import connect, server_version_tuple, execute_scalar
from ..version import NOMINATIM_VERSION
def convert_version(ver_tup: Tuple[int, int]) -> str:
"""converts tuple version (ver_tup) to a string representation"""
return ".".join(map(str, ver_tup))
def friendly_memory_string(mem: float) -> str:
"""Create a user friendly string for the amount of memory specified as mem"""
mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@ -103,16 +98,16 @@ def report_system_information(config: Configuration) -> None:
storage, and database configuration."""
with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
postgresql_ver: str = convert_version(conn.server_version_tuple())
postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))
with conn.cursor() as cur:
num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s",
(parse_dsn(config.get_libpq_dsn())['dbname'], ))
nominatim_db_exists = num == 1 if isinstance(num, int) else False
cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
(config.get_database_params()['dbname'], ))
nominatim_db_exists = cur.rowcount > 0
if nominatim_db_exists:
with connect(config.get_libpq_dsn()) as conn:
postgis_ver: str = convert_version(conn.postgis_version_tuple())
postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
else:
postgis_ver = "Unable to connect to database"

View File

@ -19,7 +19,8 @@ from psycopg2 import sql as pysql
from ..errors import UsageError
from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection
from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.async_connection import DBConnection
from ..db.sql_preprocessor import SQLPreprocessor
from .exec_utils import run_osm2pgsql
@ -51,10 +52,10 @@ def check_existing_database_plugins(dsn: str) -> None:
""" Check that the database has the required plugins installed."""
with connect(dsn) as conn:
_require_version('PostgreSQL server',
conn.server_version_tuple(),
server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION)
_require_version('PostGIS',
conn.postgis_version_tuple(),
postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION)
_require_loaded('hstore', conn)
@ -80,31 +81,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:
with connect(dsn) as conn:
_require_version('PostgreSQL server',
conn.server_version_tuple(),
server_version_tuple(conn),
POSTGRESQL_REQUIRED_VERSION)
if rouser is not None:
with conn.cursor() as cur:
cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
(rouser, ))
if cnt == 0:
LOG.fatal("Web user '%s' does not exist. Create it with:\n"
"\n createuser %s", rouser, rouser)
raise UsageError('Missing read-only user.')
if cnt == 0:
LOG.fatal("Web user '%s' does not exist. Create it with:\n"
"\n createuser %s", rouser, rouser)
raise UsageError('Missing read-only user.')
# Create extensions.
with conn.cursor() as cur:
cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')
postgis_version = conn.postgis_version_tuple()
postgis_version = postgis_version_tuple(conn)
if postgis_version[0] >= 3:
cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')
conn.commit()
_require_version('PostGIS',
conn.postgis_version_tuple(),
postgis_version_tuple(conn),
POSTGIS_REQUIRED_VERSION)
@ -141,7 +141,8 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
raise UsageError('No data imported by osm2pgsql.')
if drop:
conn.drop_table('planet_osm_nodes')
drop_tables(conn, 'planet_osm_nodes')
conn.commit()
if drop and options['flatnode_file']:
Path(options['flatnode_file']).unlink()
@ -184,7 +185,7 @@ def truncate_data_tables(conn: Connection) -> None:
cur.execute('TRUNCATE location_property_tiger')
cur.execute('TRUNCATE location_property_osmline')
cur.execute('TRUNCATE location_postcode')
if conn.table_exists('search_name'):
if table_exists(conn, 'search_name'):
cur.execute('TRUNCATE search_name')
cur.execute('DROP SEQUENCE IF EXISTS seq_place')
cur.execute('CREATE SEQUENCE seq_place start 100000')

View File

@ -12,7 +12,7 @@ from pathlib import Path
from psycopg2 import sql as pysql
from ..db.connection import Connection
from ..db.connection import Connection, drop_tables, table_exists
UPDATE_TABLES = [
'address_levels',
@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
+ pysql.SQL(' or ').join(parts))
tables = [r[0] for r in cur]
for table in tables:
cur.drop_table(table, cascade=True)
drop_tables(conn, *tables, cascade=True)
conn.commit()
@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
""" Returns true if database is in a frozen state
"""
return conn.table_exists('place') is False
return table_exists(conn, 'place') is False

View File

@ -15,7 +15,8 @@ from psycopg2 import sql as pysql
from ..errors import UsageError
from ..config import Configuration
from ..db import properties
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, server_version_tuple,\
table_has_column, table_exists, execute_scalar, register_hstore
from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
from ..tokenizer import factory as tokenizer_factory
from . import refresh
@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
if necessary.
"""
with connect(config.get_libpq_dsn()) as conn:
if conn.table_exists('nominatim_properties'):
register_hstore(conn)
if table_exists(conn, 'nominatim_properties'):
db_version_str = properties.get_property(conn, 'database_version')
else:
db_version_str = None
@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
Only migrations for 3.6 and later are supported, so bail out
when the version seems older.
"""
with conn.cursor() as cur:
# In version 3.6, the country_name table was updated. Check for that.
cnt = cur.scalar("""SELECT count(*) FROM
(SELECT svals(name) FROM country_name
WHERE country_code = 'gb')x;
""")
if cnt < 100:
LOG.fatal('It looks like your database was imported with a version '
'prior to 3.6.0. Automatic migration not possible.')
raise UsageError('Migration not possible.')
# In version 3.6, the country_name table was updated. Check for that.
cnt = execute_scalar(conn, """SELECT count(*) FROM
(SELECT svals(name) FROM country_name
WHERE country_code = 'gb')x;
""")
if cnt < 100:
LOG.fatal('It looks like your database was imported with a version '
'prior to 3.6.0. Automatic migration not possible.')
raise UsageError('Migration not possible.')
return NominatimVersion(3, 5, 0, 99)
@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
""" Add nominatim_property table.
"""
if not conn.table_exists('nominatim_properties'):
if not table_exists(conn, 'nominatim_properties'):
with conn.cursor() as cur:
cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
property TEXT,
@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
configuration for the backwards-compatible legacy tokenizer
"""
if properties.get_property(conn, 'tokenizer') is None:
with conn.cursor() as cur:
for table in ('placex', 'location_property_osmline'):
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
WHERE table_name = %s
and column_name = 'token_info'""",
(table, ))
if has_column == 0:
for table in ('placex', 'location_property_osmline'):
if not table_has_column(conn, table, 'token_info'):
with conn.cursor() as cur:
cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
.format(pysql.Identifier(table)))
tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
The inclusion is needed for efficient lookup of housenumbers in
full address searches.
"""
if conn.server_version_tuple() >= (11, 0, 0):
if server_version_tuple(conn) >= (11, 0, 0):
with conn.cursor() as cur:
cur.execute(""" CREATE INDEX IF NOT EXISTS
idx_location_property_tiger_housenumber_migrated
@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
Also converts the data into the stricter format which requires that
startnumbers comply with the odd/even requirements.
"""
if conn.table_has_column('location_property_osmline', 'step'):
if table_has_column(conn, 'location_property_osmline', 'step'):
return
with conn.cursor() as cur:
@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
""" Add a new column 'step' to the tiger data table.
"""
if conn.table_has_column('location_property_tiger', 'step'):
if table_has_column(conn, 'location_property_tiger', 'step'):
return
with conn.cursor() as cur:
@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
""" Add a new column 'derived_name' which in the future takes the
country names as imported from OSM data.
"""
if not conn.table_has_column('country_name', 'derived_name'):
if not table_has_column(conn, 'country_name', 'derived_name'):
with conn.cursor() as cur:
cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
""" Names from the country table should be marked as internal to prevent
them from being deleted. Only necessary for ICU tokenizer.
"""
import psycopg2.extras # pylint: disable=import-outside-toplevel
tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
with tokenizer.name_analyzer() as analyzer:
with conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("SELECT country_code, name FROM country_name")
for country_code, names in cur:
@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
The table is only necessary when updates are possible, i.e.
the database is not in freeze mode.
"""
if conn.table_exists('place'):
if table_exists(conn, 'place'):
with conn.cursor() as cur:
cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
osm_type CHAR(1),
@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
def split_pending_index(conn: Connection, **_: Any) -> None:
""" Reorganise indexes for pending updates.
"""
if conn.table_exists('place'):
if table_exists(conn, 'place'):
with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
ON placex USING BTREE (rank_address, geometry_sector)
@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
""" Create indexes for updates with forward dependency tracking (long-running).
"""
if conn.table_exists('planet_osm_ways'):
if table_exists(conn, 'planet_osm_ways'):
with conn.cursor() as cur:
cur.execute("""SELECT * FROM pg_indexes
WHERE tablename = 'planet_osm_ways'
@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
""" Create index needed for updating postcodes when a parent changes.
"""
if conn.table_exists('planet_osm_ways'):
if table_exists(conn, 'planet_osm_ways'):
with conn.cursor() as cur:
cur.execute("""CREATE INDEX IF NOT EXISTS
idx_location_postcode_parent_place_id

View File

@ -18,7 +18,7 @@ from math import isfinite
from psycopg2 import sql as pysql
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, table_exists
from ..utils.centroid import PointsCentroid
from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
postcodes can be computed.
"""
with connect(dsn) as conn:
return conn.table_exists('place')
return table_exists(conn, 'place')

View File

@ -17,7 +17,8 @@ from pathlib import Path
from psycopg2 import sql as pysql
from ..config import Configuration
from ..db.connection import Connection, connect
from ..db.connection import Connection, connect, postgis_version_tuple,\
drop_tables, table_exists
from ..db.utils import execute_file, CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor
from ..version import NOMINATIM_VERSION
@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
for entry in levels:
_add_address_level_rows_from_entry(rows, entry)
with conn.cursor() as cur:
cur.drop_table(table)
drop_tables(conn, table)
with conn.cursor() as cur:
cur.execute(pysql.SQL("""CREATE TABLE {} (
country_code varchar(2),
class TEXT,
@ -159,10 +160,8 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
wd_done = set()
with connect(dsn) as conn:
drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
with conn.cursor() as cur:
cur.drop_table('wikipedia_article')
cur.drop_table('wikipedia_redirect')
cur.drop_table('wikimedia_importance')
cur.execute("""CREATE TABLE wikimedia_importance (
language TEXT NOT NULL,
title TEXT NOT NULL,
@ -228,7 +227,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
return 1
with connect(dsn) as conn:
postgis_version = conn.postgis_version_tuple()
postgis_version = postgis_version_tuple(conn)
if postgis_version[0] < 3:
LOG.error('PostGIS version is too old for using OSM raster data.')
return 2
@ -309,7 +308,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
search_name_table_exists = bool(conn and conn.table_exists('search_name'))
search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
for script in WEBSITE_SCRIPTS:
if not search_name_table_exists and script == 'search.php':

View File

@ -20,7 +20,7 @@ import requests
from ..errors import UsageError
from ..db import status
from ..db.connection import Connection, connect
from ..db.connection import Connection, connect, server_version_tuple
from .exec_utils import run_osm2pgsql
try:
@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
# Consume updates with osm2pgsql.
options['append'] = True
options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
run_osm2pgsql(options)
# Handle deletions

View File

@ -21,7 +21,7 @@ from psycopg2.sql import Identifier, SQL
from ...typing import Protocol
from ...config import Configuration
from ...db.connection import Connection
from ...db.connection import Connection, drop_tables, index_exists
from .importer_statistics import SpecialPhrasesImporterStatistics
from .special_phrase import SpecialPhrase
from ...tokenizer.base import AbstractTokenizer
@ -233,7 +233,7 @@ class SPImporter():
index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
base_table = _classtype_table(phrase_class, phrase_type)
# Index on centroid
if not self.db_connection.index_exists(index_prefix + 'centroid'):
if not index_exists(self.db_connection, index_prefix + 'centroid'):
with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
.format(Identifier(index_prefix + 'centroid'),
@ -241,7 +241,7 @@ class SPImporter():
SQL(sql_tablespace)))
# Index on place_id
if not self.db_connection.index_exists(index_prefix + 'place_id'):
if not index_exists(self.db_connection, index_prefix + 'place_id'):
with self.db_connection.cursor() as db_cursor:
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
.format(Identifier(index_prefix + 'place_id'),
@ -259,6 +259,7 @@ class SPImporter():
.format(Identifier(table_name),
Identifier(self.config.DATABASE_WEBUSER)))
def _remove_non_existent_tables_from_db(self) -> None:
"""
Remove special phrases that no longer exist on the wiki.
@ -268,7 +269,6 @@ class SPImporter():
# Delete place_classtype tables corresponding to class/type which
# are not on the wiki anymore.
with self.db_connection.cursor() as db_cursor:
for table in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted()
db_cursor.drop_table(table)
drop_tables(self.db_connection, *self.table_phrases_to_delete)
for _ in self.table_phrases_to_delete:
self.statistics_handler.notify_one_table_deleted()

View File

@ -33,13 +33,13 @@ def simple_conns(temp_db):
conn2.close()
def test_simple_query(conn, temp_db_conn):
def test_simple_query(conn, temp_db_cursor):
conn.connect()
conn.perform('CREATE TABLE foo (id INT)')
conn.wait()
temp_db_conn.table_exists('foo')
assert temp_db_cursor.table_exists('foo')
def test_wait_for_query(conn):

View File

@ -10,61 +10,74 @@ Tests for specialised connection and cursor classes.
import pytest
import psycopg2
from nominatim_db.db.connection import connect, get_pg_env
import nominatim_db.db.connection as nc
@pytest.fixture
def db(dsn):
with connect(dsn) as conn:
with nc.connect(dsn) as conn:
yield conn
def test_connection_table_exists(db, table_factory):
assert not db.table_exists('foobar')
assert not nc.table_exists(db, 'foobar')
table_factory('foobar')
assert db.table_exists('foobar')
assert nc.table_exists(db, 'foobar')
def test_has_column_no_table(db):
assert not db.table_has_column('sometable', 'somecolumn')
assert not nc.table_has_column(db, 'sometable', 'somecolumn')
@pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
def test_has_column(db, table_factory, name, result):
table_factory('stuff', 'tram TEXT')
assert db.table_has_column('stuff', name) == result
assert nc.table_has_column(db, 'stuff', name) == result
def test_connection_index_exists(db, table_factory, temp_db_cursor):
assert not db.index_exists('some_index')
assert not nc.index_exists(db, 'some_index')
table_factory('foobar')
temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')
assert db.index_exists('some_index')
assert db.index_exists('some_index', table='foobar')
assert not db.index_exists('some_index', table='bar')
assert nc.index_exists(db, 'some_index')
assert nc.index_exists(db, 'some_index', table='foobar')
assert not nc.index_exists(db, 'some_index', table='bar')
def test_drop_table_existing(db, table_factory):
table_factory('dummy')
assert db.table_exists('dummy')
assert nc.table_exists(db, 'dummy')
db.drop_table('dummy')
assert not db.table_exists('dummy')
nc.drop_tables(db, 'dummy')
assert not nc.table_exists(db, 'dummy')
def test_drop_table_non_existsing(db):
db.drop_table('dfkjgjriogjigjgjrdghehtre')
def test_drop_table_non_existing(db):
nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')
def test_drop_many_tables(db, table_factory):
tables = [f'table{n}' for n in range(5)]
for t in tables:
table_factory(t)
assert nc.table_exists(db, t)
nc.drop_tables(db, *tables)
for t in tables:
assert not nc.table_exists(db, t)
def test_drop_table_non_existing_force(db):
with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False)
nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)
def test_connection_server_version_tuple(db):
ver = db.server_version_tuple()
ver = nc.server_version_tuple(db)
assert isinstance(ver, tuple)
assert len(ver) == 2
@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):
def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
ver = db.postgis_version_tuple()
ver = nc.postgis_version_tuple(db)
assert isinstance(ver, tuple)
assert len(ver) == 2
@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
def test_cursor_scalar(db, table_factory):
table_factory('dummy')
with db.cursor() as cur:
assert cur.scalar('SELECT count(*) FROM dummy') == 0
assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0
def test_cursor_scalar_many_rows(db):
with db.cursor() as cur:
with pytest.raises(RuntimeError):
cur.scalar('SELECT * FROM pg_tables')
with pytest.raises(RuntimeError, match='Query did not return a single row.'):
nc.execute_scalar(db, 'SELECT * FROM pg_tables')
def test_cursor_scalar_no_rows(db, table_factory):
table_factory('dummy')
with db.cursor() as cur:
with pytest.raises(RuntimeError):
cur.scalar('SELECT id FROM dummy')
with pytest.raises(RuntimeError, match='Query did not return a single row.'):
nc.execute_scalar(db, 'SELECT id FROM dummy')
def test_get_pg_env_add_variable(monkeypatch):
monkeypatch.delenv('PGPASSWORD', raising=False)
env = get_pg_env('user=fooF')
env = nc.get_pg_env('user=fooF')
assert env['PGUSER'] == 'fooF'
assert 'PGPASSWORD' not in env
@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):
def test_get_pg_env_overwrite_variable(monkeypatch):
monkeypatch.setenv('PGUSER', 'some default')
env = get_pg_env('user=overwriter')
env = nc.get_pg_env('user=overwriter')
assert env['PGUSER'] == 'overwriter'
def test_get_pg_env_ignore_unknown():
env = get_pg_env('client_encoding=stuff', base_env={})
env = nc.get_pg_env('client_encoding=stuff', base_env={})
assert env == {}

View File

@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefill and test contents
of the table.
"""
from nominatim_db.db.connection import execute_scalar
class MockIcuWordTable:
""" A word table for testing using legacy word table structure.
@ -77,18 +78,15 @@ class MockIcuWordTable:
def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word")
return execute_scalar(self.conn, "SELECT count(*) FROM word")
def count_special(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")
def count_housenumbers(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")
def get_special(self):

View File

@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefill and test contents
of the table.
"""
from nominatim_db.db.connection import execute_scalar
class MockLegacyWordTable:
""" A word table for testing using legacy word table structure.
@ -58,13 +59,11 @@ class MockLegacyWordTable:
def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word")
return execute_scalar(self.conn, "SELECT count(*) FROM word")
def count_special(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")
def get_special(self):

View File

@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
assert test_content == set((('1133', ), ))
def test_finalize_import(tokenizer_factory, temp_db_conn,
temp_db_cursor, test_config, sql_preprocessor_cfg):
def test_finalize_import(tokenizer_factory, temp_db_cursor,
test_config, sql_preprocessor_cfg):
tok = tokenizer_factory()
tok.init_new_db(test_config)
assert not temp_db_conn.index_exists('idx_word_word_id')
assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')
tok.finalize_import(test_config)
assert temp_db_conn.index_exists('idx_word_word_id')
assert temp_db_cursor.index_exists('word', 'idx_word_word_id')
def test_check_database(test_config, tokenizer_factory,

View File

@ -132,7 +132,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
ignore_errors=True)
def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options):
def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
table_factory('place', content=((1, ), ))
table_factory('planet_osm_nodes')
@ -144,7 +144,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)
assert not flatfile.exists()
assert not temp_db_conn.table_exists('planet_osm_nodes')
assert not temp_db_cursor.table_exists('planet_osm_nodes')
def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):

View File

@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
assert isinstance(black_list, dict) and isinstance(white_list, dict)
def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
def test_create_place_classtype_indexes(temp_db_with_extensions,
temp_db_conn, temp_db_cursor,
table_factory, sp_importer):
"""
Test that _create_place_classtype_indexes() creates the
@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')
sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
temp_db_conn.commit()
assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type)
assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)
def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
"""
Test that _create_place_classtype_table() creates
the right place_classtype table.
@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
phrase_class = 'class'
phrase_type = 'type'
sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
temp_db_conn.commit()
assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)
def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
def_config, sp_importer):
"""
Test that _grant_access_to_webuser() gives
right access to the web user.
@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
table_factory(table_name)
sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
temp_db_conn.commit()
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
def test_create_place_classtype_table_and_indexes(
temp_db_conn, def_config, placex_table,
sp_importer):
temp_db_cursor, def_config, placex_table,
sp_importer, temp_db_conn):
"""
Test that _create_place_classtype_table_and_indexes()
creates the right place_classtype tables and place_id indexes
@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
pairs = set([('class1', 'type1'), ('class2', 'type2')])
sp_importer._create_classtype_table_and_indexes(pairs)
temp_db_conn.commit()
for pair in pairs:
assert check_table_exist(temp_db_conn, pair[0], pair[1])
assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
assert check_table_exist(temp_db_cursor, pair[0], pair[1])
assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])
def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
temp_db_conn):
temp_db_conn, temp_db_cursor):
"""
Check for the remove_non_existent_phrases_from_db() method.
@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
"""
sp_importer._remove_non_existent_tables_from_db()
temp_db_conn.commit()
# Changes are not committed yet. Use temp_db_conn for checking results.
with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
assert cur.row_set(query_tables) \
assert temp_db_cursor.row_set(query_tables) \
== {('place_classtype_testclasstypetable_to_keep', )}
@pytest.mark.parametrize("should_replace", [(True), (False)])
def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
placex_table, table_factory, tokenizer_mock,
xml_wiki_content, should_replace):
"""
@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
class_test = 'aerialway'
type_test = 'zip_line'
assert check_table_exist(temp_db_conn, class_test, type_test)
assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
assert check_table_exist(temp_db_cursor, class_test, type_test)
assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
if should_replace:
assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')
assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
if should_replace:
assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')
def check_table_exist(temp_db_conn, phrase_class, phrase_type):
def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
"""
Verify that the place_classtype table exists for the given
phrase_class and phrase_type.
"""
return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
"""
Check that the web user has been granted right access to the
place_classtype table of the given phrase_class and phrase_type.
"""
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
with temp_db_conn.cursor() as temp_db_cursor:
temp_db_cursor.execute("""
SELECT * FROM information_schema.role_table_grants
WHERE table_name='{}'
AND grantee='{}'
AND privilege_type='SELECT'""".format(table_name, user))
return temp_db_cursor.fetchone()
temp_db_cursor.execute("""
SELECT * FROM information_schema.role_table_grants
WHERE table_name='{}'
AND grantee='{}'
AND privilege_type='SELECT'""".format(table_name, user))
return temp_db_cursor.fetchone()
def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type):
def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
"""
Check that the place_id index and centroid index exist for the
place_classtype table of the given phrase_class and phrase_type.
"""
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)
return (
temp_db_conn.index_exists(index_prefix + 'centroid')
temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
and
temp_db_conn.index_exists(index_prefix + 'place_id')
temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
)

View File

@ -12,6 +12,7 @@ import psycopg2.extras
from nominatim_db.tools import migration
from nominatim_db.errors import UsageError
from nominatim_db.db.connection import server_version_tuple
import nominatim_db.version
from mock_legacy_word_table import MockLegacyWordTable
@ -63,7 +64,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
WHERE property = 'database_version'""")
def test_already_at_version(def_config, property_table):
def test_already_at_version(temp_db_with_extensions, def_config, property_table):
property_table.set('database_version',
str(nominatim_db.version.NOMINATIM_VERSION))
@ -71,8 +72,8 @@ def test_already_at_version(def_config, property_table):
assert migration.migrate(def_config, {}) == 0
def test_run_single_migration(def_config, temp_db_cursor, property_table,
monkeypatch, postprocess_mock):
def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
property_table, monkeypatch, postprocess_mock):
oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
oldversion[0] -= 1
property_table.set('database_version',
@ -226,7 +227,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
migration.create_tiger_housenumber_index(temp_db_conn)
temp_db_conn.commit()
if temp_db_conn.server_version_tuple() >= (11, 0, 0):
if server_version_tuple(temp_db_conn) >= (11, 0, 0):
assert temp_db_cursor.index_exists('location_property_tiger',
'idx_location_property_tiger_housenumber_migrated')

View File

@ -12,6 +12,7 @@ from pathlib import Path
import pytest
from nominatim_db.tools import refresh
from nominatim_db.db.connection import postgis_version_tuple
def test_refresh_import_wikipedia_not_existing(dsn):
assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
temp_db_cursor.execute('CREATE EXTENSION postgis')
if temp_db_conn.postgis_version_tuple()[0] < 3:
if postgis_version_tuple(temp_db_conn)[0] < 3:
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
else:
temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0
assert temp_db_conn.table_exists('secondary_importance')
assert temp_db_cursor.table_exists('secondary_importance')
@pytest.mark.parametrize("replace", (True, False))

View File

@ -12,6 +12,7 @@ from textwrap import dedent
import pytest
from nominatim_db.db.connection import execute_scalar
from nominatim_db.tools import tiger_data, freeze
from nominatim_db.errors import UsageError
@ -31,8 +32,7 @@ class MockTigerTable:
cur.execute("CREATE TABLE place (number INTEGER)")
def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM tiger")
return execute_scalar(self.conn, "SELECT count(*) FROM tiger")
def row(self):
with self.conn.cursor() as cur: