Mirror of https://github.com/osm-search/Nominatim.git, synced 2024-11-22 12:06:27 +03:00
Merge pull request #3487 from lonvia/port-to-psycopg3
Move importer code to psycopg3
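The change ports all database access in the importer from psycopg2 to psycopg 3. As a rough orientation for the diff below, here is a minimal before/after sketch of the recurring pattern; the table name, DSN and row data are made-up placeholders, not Nominatim code:

```python
import psycopg
from psycopg.rows import namedtuple_row

def save_rows(dsn: str, rows: list) -> None:
    # psycopg2 used psycopg2.connect(dsn), DictCursor and
    # psycopg2.extras.execute_values(); psycopg 3 uses a row factory on the
    # connection and plain executemany().
    with psycopg.connect(dsn, row_factory=namedtuple_row) as conn:
        with conn.cursor() as cur:
            cur.executemany(
                "INSERT INTO example_table (code, name) VALUES (%s, %s)",
                rows)
        conn.commit()
```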
This commit is contained in: commit 0add25e335
.github/actions/build-nominatim/action.yml (vendored), 4 changed lines
@@ -27,9 +27,9 @@ runs:
      run: |
        sudo apt-get install -y -qq libboost-system-dev libboost-filesystem-dev libexpat1-dev zlib1g-dev libbz2-dev libpq-dev libproj-dev libicu-dev liblua${LUA_VERSION}-dev lua${LUA_VERSION} lua-dkjson nlohmann-json3-dev libspatialite7 libsqlite3-mod-spatialite
        if [ "$FLAVOUR" == "oldstuff" ]; then
            pip3 install MarkupSafe==2.0.1 python-dotenv psycopg2==2.7.7 jinja2==2.8 psutil==5.4.2 pyicu==2.9 osmium PyYAML==5.1 sqlalchemy==1.4.31 datrie asyncpg aiosqlite
            pip3 install MarkupSafe==2.0.1 python-dotenv jinja2==2.8 psutil==5.4.2 pyicu==2.9 osmium PyYAML==5.1 sqlalchemy==1.4.31 psycopg==3.1.7 datrie asyncpg aiosqlite
        else
            sudo apt-get install -y -qq python3-icu python3-datrie python3-pyosmium python3-jinja2 python3-psutil python3-psycopg2 python3-dotenv python3-yaml
            sudo apt-get install -y -qq python3-icu python3-datrie python3-pyosmium python3-jinja2 python3-psutil python3-dotenv python3-yaml
            pip3 install sqlalchemy psycopg aiosqlite
        fi
      shell: bash
@@ -36,19 +36,15 @@ For running Nominatim:

Furthermore the following Python libraries are required:

* [Psycopg2](https://www.psycopg.org) (2.7+)
* [Psycopg3](https://www.psycopg.org)
* [Python Dotenv](https://github.com/theskumar/python-dotenv)
* [psutil](https://github.com/giampaolo/psutil)
* [Jinja2](https://palletsprojects.com/p/jinja/)
* [SQLAlchemy](https://www.sqlalchemy.org/) (1.4.31+ with greenlet support)
* one of
  * [psycopg3](https://www.psycopg.org)
  * [asyncpg](https://magicstack.github.io/asyncpg) (0.8+)
* [PyICU](https://pypi.org/project/PyICU/)
* [PyYaml](https://pyyaml.org/) (5.1+)
* [datrie](https://github.com/pytries/datrie)

These will be installed automatically, when using pip installation.
These will be installed automatically when using pip installation.

When using legacy CMake-based installation:

@@ -69,6 +65,8 @@ For running continuous updates:

For running the Python frontend:

* [SQLAlchemy](https://www.sqlalchemy.org/) (1.4.31+ with greenlet support)
* [asyncpg](https://magicstack.github.io/asyncpg) (0.8+, only when using SQLAlchemy < 2.0)
* one of the following web frameworks:
  * [falcon](https://falconframework.org/) (3.0+)
  * [starlette](https://www.starlette.io/)
@@ -31,6 +31,7 @@ CREATE INDEX IF NOT EXISTS idx_placex_geometry ON placex
-- Index is needed during import but can be dropped as soon as a full
-- geometry index is in place. The partial index is almost as big as the full
-- index.
---
DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
---
CREATE INDEX IF NOT EXISTS idx_placex_geometry_reverse_lookupPolygon
@@ -60,7 +61,6 @@ CREATE INDEX IF NOT EXISTS idx_postcode_postcode
---
DROP INDEX IF EXISTS idx_placex_geometry_address_area_candidates;
DROP INDEX IF EXISTS idx_placex_geometry_buildings;
DROP INDEX IF EXISTS idx_placex_geometry_lower_rank_ways;
DROP INDEX IF EXISTS idx_placex_wikidata;
DROP INDEX IF EXISTS idx_placex_rank_address_sector;
DROP INDEX IF EXISTS idx_placex_rank_boundaries_sector;
@@ -15,7 +15,7 @@ classifiers = [
    "Operating System :: OS Independent",
]
dependencies = [
    "psycopg2-binary",
    "psycopg",
    "python-dotenv",
    "jinja2",
    "pyYAML>=5.1",
@@ -7,7 +7,7 @@
"""
Implementation of classes for API access via libraries.
"""
from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple
from typing import Mapping, Optional, Any, AsyncIterator, Dict, Sequence, List, Tuple, cast
import asyncio
import sys
import contextlib
@@ -107,16 +107,16 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
    raise UsageError(f"SQlite database '{params.get('dbname')}' does not exist.")
else:
    dsn = self.config.get_database_params()
    query = {k: v for k, v in dsn.items()
    query = {k: str(v) for k, v in dsn.items()
             if k not in ('user', 'password', 'dbname', 'host', 'port')}

    dburl = sa.engine.URL.create(
                f'postgresql+{PGCORE_LIB}',
                database=dsn.get('dbname'),
                username=dsn.get('user'),
                password=dsn.get('password'),
                host=dsn.get('host'),
                port=int(dsn['port']) if 'port' in dsn else None,
                database=cast(str, dsn.get('dbname')),
                username=cast(str, dsn.get('user')),
                password=cast(str, dsn.get('password')),
                host=cast(str, dsn.get('host')),
                port=int(cast(str, dsn['port'])) if 'port' in dsn else None,
                query=query)

engine = sa_asyncio.create_async_engine(dburl, **extra_args)
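The hunk above feeds the parsed DSN into an SQLAlchemy URL for the driver selected in PGCORE_LIB. A hedged sketch of what that amounts to for the psycopg 3 driver; the DSN values below are placeholders, not Nominatim defaults:

```python
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine

dsn = {'dbname': 'nominatim', 'user': 'osm', 'port': '5432', 'sslmode': 'prefer'}

dburl = URL.create(
    'postgresql+psycopg',          # the psycopg 3 dialect; async use needs SQLAlchemy 2.0
    database=dsn.get('dbname'),
    username=dsn.get('user'),
    port=int(dsn['port']) if 'port' in dsn else None,
    # everything that is not a standard connection field goes into the query
    # part, which must contain string values only
    query={k: str(v) for k, v in dsn.items()
           if k not in ('user', 'password', 'dbname', 'host', 'port')})

engine = create_async_engine(dburl)
```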
@ -14,6 +14,7 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from .config import Configuration
|
||||
@ -170,24 +171,32 @@ class AdminServe:
|
||||
raise UsageError("PHP frontend not configured.")
|
||||
run_php_server(args.server, args.project_dir / 'website')
|
||||
else:
|
||||
import uvicorn # pylint: disable=import-outside-toplevel
|
||||
server_info = args.server.split(':', 1)
|
||||
host = server_info[0]
|
||||
if len(server_info) > 1:
|
||||
if not server_info[1].isdigit():
|
||||
raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
|
||||
port = int(server_info[1])
|
||||
else:
|
||||
port = 8088
|
||||
|
||||
server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')
|
||||
|
||||
app = server_module.get_application(args.project_dir)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
asyncio.run(self.run_uvicorn(args))
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
async def run_uvicorn(self, args: NominatimArgs) -> None:
|
||||
import uvicorn # pylint: disable=import-outside-toplevel
|
||||
|
||||
server_info = args.server.split(':', 1)
|
||||
host = server_info[0]
|
||||
if len(server_info) > 1:
|
||||
if not server_info[1].isdigit():
|
||||
raise UsageError('Invalid format for --server parameter. Use <host>:<port>')
|
||||
port = int(server_info[1])
|
||||
else:
|
||||
port = 8088
|
||||
|
||||
server_module = importlib.import_module(f'nominatim_api.server.{args.engine}.server')
|
||||
|
||||
app = server_module.get_application(args.project_dir)
|
||||
|
||||
config = uvicorn.Config(app, host=host, port=port)
|
||||
server = uvicorn.Server(config)
|
||||
await server.serve()
|
||||
|
||||
|
||||
def get_set_parser() -> CommandlineParser:
|
||||
"""\
|
||||
Initializes the parser and adds various subcommands for
|
||||
|
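The restructuring shown above for the serve command recurs in the subcommands that follow (add-data, admin, index, refresh, replication, import): the synchronous run() keeps the argument handling and hands the database work to an async helper. A schematic sketch of the pattern, not the actual classes:

```python
import asyncio

class ExampleSubcommand:
    def run(self, args) -> int:
        # argument validation and parsing stay synchronous
        return asyncio.run(self._async_run(args))

    async def _async_run(self, args) -> int:
        # database work becomes awaitable so it can drive the new
        # psycopg AsyncConnection based helpers
        await asyncio.sleep(0)  # stand-in for awaited indexer/importer calls
        return 0
```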
@ -10,6 +10,7 @@ Implementation of the 'add-data' subcommand.
|
||||
from typing import cast
|
||||
import argparse
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import psutil
|
||||
|
||||
@ -64,15 +65,10 @@ class UpdateAddData:
|
||||
|
||||
|
||||
def run(self, args: NominatimArgs) -> int:
|
||||
from ..tokenizer import factory as tokenizer_factory
|
||||
from ..tools import tiger_data, add_osm_data
|
||||
from ..tools import add_osm_data
|
||||
|
||||
if args.tiger_data:
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
|
||||
return tiger_data.add_tiger_data(args.tiger_data,
|
||||
args.config,
|
||||
args.threads or psutil.cpu_count() or 1,
|
||||
tokenizer)
|
||||
return asyncio.run(self._add_tiger_data(args))
|
||||
|
||||
osm2pgsql_params = args.osm2pgsql_options(default_cache=1000, default_threads=1)
|
||||
if args.file or args.diff:
|
||||
@ -99,3 +95,16 @@ class UpdateAddData:
|
||||
osm2pgsql_params)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
async def _add_tiger_data(self, args: NominatimArgs) -> int:
|
||||
from ..tokenizer import factory as tokenizer_factory
|
||||
from ..tools import tiger_data
|
||||
|
||||
assert args.tiger_data
|
||||
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
|
||||
return await tiger_data.add_tiger_data(args.tiger_data,
|
||||
args.config,
|
||||
args.threads or psutil.cpu_count() or 1,
|
||||
tokenizer)
|
||||
|
@ -12,7 +12,7 @@ import argparse
|
||||
import random
|
||||
|
||||
from ..errors import UsageError
|
||||
from ..db.connection import connect
|
||||
from ..db.connection import connect, table_exists
|
||||
from .args import NominatimArgs
|
||||
|
||||
# Do not repeat documentation of subcommand classes.
|
||||
@ -115,7 +115,7 @@ class AdminFuncs:
|
||||
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
|
||||
with connect(args.config.get_libpq_dsn()) as conn:
|
||||
if conn.table_exists('search_name'):
|
||||
if table_exists(conn, 'search_name'):
|
||||
words = tokenizer.most_frequent_words(conn, 1000)
|
||||
else:
|
||||
words = []
|
||||
|
@ -8,6 +8,7 @@
|
||||
Implementation of the 'index' subcommand.
|
||||
"""
|
||||
import argparse
|
||||
import asyncio
|
||||
|
||||
import psutil
|
||||
|
||||
@ -44,19 +45,7 @@ class UpdateIndex:
|
||||
|
||||
|
||||
def run(self, args: NominatimArgs) -> int:
|
||||
from ..indexer.indexer import Indexer
|
||||
from ..tokenizer import factory as tokenizer_factory
|
||||
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
|
||||
|
||||
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
|
||||
args.threads or psutil.cpu_count() or 1)
|
||||
|
||||
if not args.no_boundaries:
|
||||
indexer.index_boundaries(args.minrank, args.maxrank)
|
||||
if not args.boundaries_only:
|
||||
indexer.index_by_rank(args.minrank, args.maxrank)
|
||||
indexer.index_postcodes()
|
||||
asyncio.run(self._do_index(args))
|
||||
|
||||
if not args.no_boundaries and not args.boundaries_only \
|
||||
and args.minrank == 0 and args.maxrank == 30:
|
||||
@ -64,3 +53,22 @@ class UpdateIndex:
|
||||
status.set_indexed(conn, True)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
async def _do_index(self, args: NominatimArgs) -> None:
|
||||
from ..tokenizer import factory as tokenizer_factory
|
||||
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(args.config)
|
||||
from ..indexer.indexer import Indexer
|
||||
|
||||
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
|
||||
args.threads or psutil.cpu_count() or 1)
|
||||
|
||||
has_pending = True # run at least once
|
||||
while has_pending:
|
||||
if not args.no_boundaries:
|
||||
await indexer.index_boundaries(args.minrank, args.maxrank)
|
||||
if not args.boundaries_only:
|
||||
await indexer.index_by_rank(args.minrank, args.maxrank)
|
||||
await indexer.index_postcodes()
|
||||
has_pending = indexer.has_pending()
|
||||
|
@ -11,9 +11,10 @@ from typing import Tuple, Optional
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
from ..config import Configuration
|
||||
from ..db.connection import connect
|
||||
from ..db.connection import connect, table_exists
|
||||
from ..tokenizer.base import AbstractTokenizer
|
||||
from .args import NominatimArgs
|
||||
|
||||
@ -99,7 +100,7 @@ class UpdateRefresh:
|
||||
args.project_dir, tokenizer)
|
||||
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer,
|
||||
args.threads or 1)
|
||||
indexer.index_postcodes()
|
||||
asyncio.run(indexer.index_postcodes())
|
||||
else:
|
||||
LOG.error("The place table doesn't exist. "
|
||||
"Postcode updates on a frozen database is not possible.")
|
||||
@ -124,7 +125,7 @@ class UpdateRefresh:
|
||||
with connect(args.config.get_libpq_dsn()) as conn:
|
||||
# If the table did not exist before, then the importance code
|
||||
# needs to be enabled.
|
||||
if not conn.table_exists('secondary_importance'):
|
||||
if not table_exists(conn, 'secondary_importance'):
|
||||
args.functions = True
|
||||
|
||||
LOG.warning('Import secondary importance raster data from %s', args.project_dir)
|
||||
|
@ -13,6 +13,7 @@ import datetime as dt
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
from ..db import status
|
||||
from ..db.connection import connect
|
||||
@ -123,7 +124,7 @@ class UpdateReplication:
|
||||
return update_interval
|
||||
|
||||
|
||||
def _update(self, args: NominatimArgs) -> None:
|
||||
async def _update(self, args: NominatimArgs) -> None:
|
||||
# pylint: disable=too-many-locals
|
||||
from ..tools import replication
|
||||
from ..indexer.indexer import Indexer
|
||||
@ -161,7 +162,7 @@ class UpdateReplication:
|
||||
|
||||
if state is not replication.UpdateState.NO_CHANGES and args.do_index:
|
||||
index_start = dt.datetime.now(dt.timezone.utc)
|
||||
indexer.index_full(analyse=False)
|
||||
await indexer.index_full(analyse=False)
|
||||
|
||||
with connect(dsn) as conn:
|
||||
status.set_indexed(conn, True)
|
||||
@ -172,8 +173,7 @@ class UpdateReplication:
|
||||
|
||||
if state is replication.UpdateState.NO_CHANGES and \
|
||||
args.catch_up or update_interval > 40*60:
|
||||
while indexer.has_pending():
|
||||
indexer.index_full(analyse=False)
|
||||
await indexer.index_full(analyse=False)
|
||||
|
||||
if LOG.isEnabledFor(logging.WARNING):
|
||||
assert batchdate is not None
|
||||
@ -196,5 +196,5 @@ class UpdateReplication:
|
||||
if args.check_for_updates:
|
||||
return self._check_for_updates(args)
|
||||
|
||||
self._update(args)
|
||||
asyncio.run(self._update(args))
|
||||
return 0
|
||||
|
@ -11,6 +11,7 @@ from typing import Optional
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
|
||||
import psutil
|
||||
|
||||
@ -71,14 +72,6 @@ class SetupAll:
|
||||
|
||||
|
||||
def run(self, args: NominatimArgs) -> int: # pylint: disable=too-many-statements, too-many-branches
|
||||
from ..data import country_info
|
||||
from ..tools import database_import, refresh, postcodes, freeze
|
||||
from ..indexer.indexer import Indexer
|
||||
|
||||
num_threads = args.threads or psutil.cpu_count() or 1
|
||||
|
||||
country_info.setup_country_config(args.config)
|
||||
|
||||
if args.osm_file is None and args.continue_at is None and not args.prepare_database:
|
||||
raise UsageError("No input files (use --osm-file).")
|
||||
|
||||
@ -90,6 +83,16 @@ class SetupAll:
|
||||
"Cannot use --continue and --prepare-database together."
|
||||
)
|
||||
|
||||
return asyncio.run(self.async_run(args))
|
||||
|
||||
|
||||
async def async_run(self, args: NominatimArgs) -> int:
|
||||
from ..data import country_info
|
||||
from ..tools import database_import, refresh, postcodes, freeze
|
||||
from ..indexer.indexer import Indexer
|
||||
|
||||
num_threads = args.threads or psutil.cpu_count() or 1
|
||||
country_info.setup_country_config(args.config)
|
||||
|
||||
if args.prepare_database or args.continue_at is None:
|
||||
LOG.warning('Creating database')
|
||||
@ -99,39 +102,7 @@ class SetupAll:
|
||||
return 0
|
||||
|
||||
if args.continue_at in (None, 'import-from-file'):
|
||||
files = args.get_osm_file_list()
|
||||
if not files:
|
||||
raise UsageError("No input files (use --osm-file).")
|
||||
|
||||
if args.continue_at in ('import-from-file', None):
|
||||
# Check if the correct plugins are installed
|
||||
database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
|
||||
LOG.warning('Setting up country tables')
|
||||
country_info.setup_country_tables(args.config.get_libpq_dsn(),
|
||||
args.config.lib_dir.data,
|
||||
args.no_partitions)
|
||||
|
||||
LOG.warning('Importing OSM data file')
|
||||
database_import.import_osm_data(files,
|
||||
args.osm2pgsql_options(0, 1),
|
||||
drop=args.no_updates,
|
||||
ignore_errors=args.ignore_errors)
|
||||
|
||||
LOG.warning('Importing wikipedia importance data')
|
||||
data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
|
||||
if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
|
||||
data_path) > 0:
|
||||
LOG.error('Wikipedia importance dump file not found. '
|
||||
'Calculating importance values of locations will not '
|
||||
'use Wikipedia importance data.')
|
||||
|
||||
LOG.warning('Importing secondary importance raster data')
|
||||
if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
|
||||
args.project_dir) != 0:
|
||||
LOG.error('Secondary importance file not imported. '
|
||||
'Falling back to default ranking.')
|
||||
|
||||
self._setup_tables(args.config, args.reverse_only)
|
||||
self._base_import(args)
|
||||
|
||||
if args.continue_at in ('import-from-file', 'load-data', None):
|
||||
LOG.warning('Initialise tables')
|
||||
@ -139,7 +110,7 @@ class SetupAll:
|
||||
database_import.truncate_data_tables(conn)
|
||||
|
||||
LOG.warning('Load data into placex table')
|
||||
database_import.load_data(args.config.get_libpq_dsn(), num_threads)
|
||||
await database_import.load_data(args.config.get_libpq_dsn(), num_threads)
|
||||
|
||||
LOG.warning("Setting up tokenizer")
|
||||
tokenizer = self._get_tokenizer(args.continue_at, args.config)
|
||||
@ -153,13 +124,13 @@ class SetupAll:
|
||||
('import-from-file', 'load-data', 'indexing', None):
|
||||
LOG.warning('Indexing places')
|
||||
indexer = Indexer(args.config.get_libpq_dsn(), tokenizer, num_threads)
|
||||
indexer.index_full(analyse=not args.index_noanalyse)
|
||||
await indexer.index_full(analyse=not args.index_noanalyse)
|
||||
|
||||
LOG.warning('Post-process tables')
|
||||
with connect(args.config.get_libpq_dsn()) as conn:
|
||||
database_import.create_search_indices(conn, args.config,
|
||||
drop=args.no_updates,
|
||||
threads=num_threads)
|
||||
await database_import.create_search_indices(conn, args.config,
|
||||
drop=args.no_updates,
|
||||
threads=num_threads)
|
||||
LOG.warning('Create search index for default country names.')
|
||||
country_info.create_country_names(conn, tokenizer,
|
||||
args.config.get_str_list('LANGUAGES'))
|
||||
@ -180,6 +151,45 @@ class SetupAll:
|
||||
return 0
|
||||
|
||||
|
||||
def _base_import(self, args: NominatimArgs) -> None:
|
||||
from ..tools import database_import, refresh
|
||||
from ..data import country_info
|
||||
|
||||
files = args.get_osm_file_list()
|
||||
if not files:
|
||||
raise UsageError("No input files (use --osm-file).")
|
||||
|
||||
if args.continue_at in ('import-from-file', None):
|
||||
# Check if the correct plugins are installed
|
||||
database_import.check_existing_database_plugins(args.config.get_libpq_dsn())
|
||||
LOG.warning('Setting up country tables')
|
||||
country_info.setup_country_tables(args.config.get_libpq_dsn(),
|
||||
args.config.lib_dir.data,
|
||||
args.no_partitions)
|
||||
|
||||
LOG.warning('Importing OSM data file')
|
||||
database_import.import_osm_data(files,
|
||||
args.osm2pgsql_options(0, 1),
|
||||
drop=args.no_updates,
|
||||
ignore_errors=args.ignore_errors)
|
||||
|
||||
LOG.warning('Importing wikipedia importance data')
|
||||
data_path = Path(args.config.WIKIPEDIA_DATA_PATH or args.project_dir)
|
||||
if refresh.import_wikipedia_articles(args.config.get_libpq_dsn(),
|
||||
data_path) > 0:
|
||||
LOG.error('Wikipedia importance dump file not found. '
|
||||
'Calculating importance values of locations will not '
|
||||
'use Wikipedia importance data.')
|
||||
|
||||
LOG.warning('Importing secondary importance raster data')
|
||||
if refresh.import_secondary_importance(args.config.get_libpq_dsn(),
|
||||
args.project_dir) != 0:
|
||||
LOG.error('Secondary importance file not imported. '
|
||||
'Falling back to default ranking.')
|
||||
|
||||
self._setup_tables(args.config, args.reverse_only)
|
||||
|
||||
|
||||
def _setup_tables(self, config: Configuration, reverse_only: bool) -> None:
|
||||
""" Set up the basic database layout: tables, indexes and functions.
|
||||
"""
|
||||
|
@ -7,7 +7,7 @@
|
||||
"""
|
||||
Nominatim configuration accessor.
|
||||
"""
|
||||
from typing import Dict, Any, List, Mapping, Optional
|
||||
from typing import Union, Dict, Any, List, Mapping, Optional
|
||||
import importlib.util
|
||||
import logging
|
||||
import os
|
||||
@ -18,10 +18,7 @@ import yaml
|
||||
|
||||
from dotenv import dotenv_values
|
||||
|
||||
try:
|
||||
from psycopg2.extensions import parse_dsn
|
||||
except ModuleNotFoundError:
|
||||
from psycopg.conninfo import conninfo_to_dict as parse_dsn # type: ignore[assignment]
|
||||
from psycopg.conninfo import conninfo_to_dict
|
||||
|
||||
from .typing import StrPath
|
||||
from .errors import UsageError
|
||||
@ -198,7 +195,7 @@ class Configuration:
|
||||
return dsn
|
||||
|
||||
|
||||
def get_database_params(self) -> Mapping[str, str]:
|
||||
def get_database_params(self) -> Mapping[str, Union[str, int, None]]:
|
||||
""" Get the configured parameters for the database connection
|
||||
as a mapping.
|
||||
"""
|
||||
@ -207,7 +204,7 @@ class Configuration:
|
||||
if dsn.startswith('pgsql:'):
|
||||
return dict((p.split('=', 1) for p in dsn[6:].split(';')))
|
||||
|
||||
return parse_dsn(dsn)
|
||||
return conninfo_to_dict(dsn)
|
||||
|
||||
|
||||
def get_import_style_file(self) -> Path:
|
||||
|
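With the try/except fallback removed above, the configuration accessor always uses psycopg 3's conninfo_to_dict(). A quick illustration of the call that replaces psycopg2's parse_dsn() (DSN values are placeholders); the returned values are no longer guaranteed to be plain strings, which is why get_database_params() now declares the wider Mapping[str, Union[str, int, None]]:

```python
from psycopg.conninfo import conninfo_to_dict

params = conninfo_to_dict('dbname=nominatim port=5433 user=osm')
print(params['dbname'], params['user'], params['port'])
```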
@ -9,10 +9,9 @@ Functions for importing and managing static country information.
|
||||
"""
|
||||
from typing import Dict, Any, Iterable, Tuple, Optional, Container, overload
|
||||
from pathlib import Path
|
||||
import psycopg2.extras
|
||||
|
||||
from ..db import utils as db_utils
|
||||
from ..db.connection import connect, Connection
|
||||
from ..db.connection import connect, Connection, register_hstore
|
||||
from ..errors import UsageError
|
||||
from ..config import Configuration
|
||||
from ..tokenizer.base import AbstractTokenizer
|
||||
@ -129,8 +128,8 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
|
||||
|
||||
params.append((ccode, props['names'], lang, partition))
|
||||
with connect(dsn) as conn:
|
||||
register_hstore(conn)
|
||||
with conn.cursor() as cur:
|
||||
psycopg2.extras.register_hstore(cur)
|
||||
cur.execute(
|
||||
""" CREATE TABLE public.country_name (
|
||||
country_code character varying(2),
|
||||
@ -139,9 +138,10 @@ def setup_country_tables(dsn: str, sql_dir: Path, ignore_partitions: bool = Fals
|
||||
country_default_language_code text,
|
||||
partition integer
|
||||
); """)
|
||||
cur.execute_values(
|
||||
cur.executemany(
|
||||
""" INSERT INTO public.country_name
|
||||
(country_code, name, country_default_language_code, partition) VALUES %s
|
||||
(country_code, name, country_default_language_code, partition)
|
||||
VALUES (%s, %s, %s, %s)
|
||||
""", params)
|
||||
conn.commit()
|
||||
|
||||
@ -157,8 +157,8 @@ def create_country_names(conn: Connection, tokenizer: AbstractTokenizer,
|
||||
return ':' not in key or not languages or \
|
||||
key[key.index(':') + 1:] in languages
|
||||
|
||||
register_hstore(conn)
|
||||
with conn.cursor() as cur:
|
||||
psycopg2.extras.register_hstore(cur)
|
||||
cur.execute("""SELECT country_code, name FROM country_name
|
||||
WHERE country_code is not null""")
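In psycopg 3 the hstore adapter is registered per connection through a TypeInfo lookup instead of psycopg2.extras.register_hstore() on a cursor, which is what the new register_hstore() helper wraps, and execute_values() gives way to a plain executemany(). A hedged usage sketch with placeholder data:

```python
import psycopg
from psycopg.types import TypeInfo
from psycopg.types.hstore import register_hstore

with psycopg.connect("dbname=nominatim") as conn:
    info = TypeInfo.fetch(conn, "hstore")
    if info is None:
        raise RuntimeError('hstore extension is requested but not installed.')
    register_hstore(info, conn)

    with conn.cursor() as cur:
        # after registration, Python dicts are adapted to hstore values
        cur.executemany(
            "INSERT INTO country_name (country_code, name) VALUES (%s, %s)",
            [('de', {'name': 'Deutschland'}), ('fr', {'name': 'France'})])
    conn.commit()
```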
|
||||
|
||||
|
@ -1,236 +0,0 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
""" Non-blocking database connections.
|
||||
"""
|
||||
from typing import Callable, Any, Optional, Iterator, Sequence
|
||||
import logging
|
||||
import select
|
||||
import time
|
||||
|
||||
import psycopg2
|
||||
from psycopg2.extras import wait_select
|
||||
|
||||
# psycopg2 emits different exceptions pre and post 2.8. Detect if the new error
|
||||
# module is available and adapt the error handling accordingly.
|
||||
try:
|
||||
import psycopg2.errors # pylint: disable=no-name-in-module,import-error
|
||||
__has_psycopg2_errors__ = True
|
||||
except ImportError:
|
||||
__has_psycopg2_errors__ = False
|
||||
|
||||
from ..typing import T_cursor, Query
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class DeadlockHandler:
|
||||
""" Context manager that catches deadlock exceptions and calls
|
||||
the given handler function. All other exceptions are passed on
|
||||
normally.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Callable[[], None], ignore_sql_errors: bool = False) -> None:
|
||||
self.handler = handler
|
||||
self.ignore_sql_errors = ignore_sql_errors
|
||||
|
||||
def __enter__(self) -> 'DeadlockHandler':
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> bool:
|
||||
if __has_psycopg2_errors__:
|
||||
if exc_type == psycopg2.errors.DeadlockDetected: # pylint: disable=E1101
|
||||
self.handler()
|
||||
return True
|
||||
elif exc_type == psycopg2.extensions.TransactionRollbackError \
|
||||
and exc_value.pgcode == '40P01':
|
||||
self.handler()
|
||||
return True
|
||||
|
||||
if self.ignore_sql_errors and isinstance(exc_value, psycopg2.Error):
|
||||
LOG.info("SQL error ignored: %s", exc_value)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class DBConnection:
|
||||
""" A single non-blocking database connection.
|
||||
"""
|
||||
|
||||
def __init__(self, dsn: str,
|
||||
cursor_factory: Optional[Callable[..., T_cursor]] = None,
|
||||
ignore_sql_errors: bool = False) -> None:
|
||||
self.dsn = dsn
|
||||
|
||||
self.current_query: Optional[Query] = None
|
||||
self.current_params: Optional[Sequence[Any]] = None
|
||||
self.ignore_sql_errors = ignore_sql_errors
|
||||
|
||||
self.conn: Optional['psycopg2._psycopg.connection'] = None
|
||||
self.cursor: Optional['psycopg2._psycopg.cursor'] = None
|
||||
self.connect(cursor_factory=cursor_factory)
|
||||
|
||||
def close(self) -> None:
|
||||
""" Close all open connections. Does not wait for pending requests.
|
||||
"""
|
||||
if self.conn is not None:
|
||||
if self.cursor is not None:
|
||||
self.cursor.close()
|
||||
self.cursor = None
|
||||
self.conn.close()
|
||||
|
||||
self.conn = None
|
||||
|
||||
def connect(self, cursor_factory: Optional[Callable[..., T_cursor]] = None) -> None:
|
||||
""" (Re)connect to the database. Creates an asynchronous connection
|
||||
with JIT and parallel processing disabled. If a connection was
|
||||
already open, it is closed and a new connection established.
|
||||
The caller must ensure that no query is pending before reconnecting.
|
||||
"""
|
||||
self.close()
|
||||
|
||||
# Use a dict to hand in the parameters because async is a reserved
|
||||
# word in Python3.
|
||||
self.conn = psycopg2.connect(**{'dsn': self.dsn, 'async': True}) # type: ignore
|
||||
assert self.conn
|
||||
self.wait()
|
||||
|
||||
if cursor_factory is not None:
|
||||
self.cursor = self.conn.cursor(cursor_factory=cursor_factory)
|
||||
else:
|
||||
self.cursor = self.conn.cursor()
|
||||
# Disable JIT and parallel workers as they are known to cause problems.
|
||||
# Update pg_settings instead of using SET because it does not yield
|
||||
# errors on older versions of Postgres where the settings are not
|
||||
# implemented.
|
||||
self.perform(
|
||||
""" UPDATE pg_settings SET setting = -1 WHERE name = 'jit_above_cost';
|
||||
UPDATE pg_settings SET setting = 0
|
||||
WHERE name = 'max_parallel_workers_per_gather';""")
|
||||
self.wait()
|
||||
|
||||
def _deadlock_handler(self) -> None:
|
||||
LOG.info("Deadlock detected (params = %s), retry.", str(self.current_params))
|
||||
assert self.cursor is not None
|
||||
assert self.current_query is not None
|
||||
assert self.current_params is not None
|
||||
|
||||
self.cursor.execute(self.current_query, self.current_params)
|
||||
|
||||
def wait(self) -> None:
|
||||
""" Block until any pending operation is done.
|
||||
"""
|
||||
while True:
|
||||
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
|
||||
wait_select(self.conn)
|
||||
self.current_query = None
|
||||
return
|
||||
|
||||
def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
|
||||
""" Send SQL query to the server. Returns immediately without
|
||||
blocking.
|
||||
"""
|
||||
assert self.cursor is not None
|
||||
self.current_query = sql
|
||||
self.current_params = args
|
||||
self.cursor.execute(sql, args)
|
||||
|
||||
def fileno(self) -> int:
|
||||
""" File descriptor to wait for. (Makes this class select()able.)
|
||||
"""
|
||||
assert self.conn is not None
|
||||
return self.conn.fileno()
|
||||
|
||||
def is_done(self) -> bool:
|
||||
""" Check if the connection is available for a new query.
|
||||
|
||||
Also checks if the previous query has run into a deadlock.
|
||||
If so, then the previous query is repeated.
|
||||
"""
|
||||
assert self.conn is not None
|
||||
|
||||
if self.current_query is None:
|
||||
return True
|
||||
|
||||
with DeadlockHandler(self._deadlock_handler, self.ignore_sql_errors):
|
||||
if self.conn.poll() == psycopg2.extensions.POLL_OK:
|
||||
self.current_query = None
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class WorkerPool:
|
||||
""" A pool of asynchronous database connections.
|
||||
|
||||
The pool may be used as a context manager.
|
||||
"""
|
||||
REOPEN_CONNECTIONS_AFTER = 100000
|
||||
|
||||
def __init__(self, dsn: str, pool_size: int, ignore_sql_errors: bool = False) -> None:
|
||||
self.threads = [DBConnection(dsn, ignore_sql_errors=ignore_sql_errors)
|
||||
for _ in range(pool_size)]
|
||||
self.free_workers = self._yield_free_worker()
|
||||
self.wait_time = 0.0
|
||||
|
||||
|
||||
def finish_all(self) -> None:
|
||||
""" Wait for all connection to finish.
|
||||
"""
|
||||
for thread in self.threads:
|
||||
while not thread.is_done():
|
||||
thread.wait()
|
||||
|
||||
self.free_workers = self._yield_free_worker()
|
||||
|
||||
def close(self) -> None:
|
||||
""" Close all connections and clear the pool.
|
||||
"""
|
||||
for thread in self.threads:
|
||||
thread.close()
|
||||
self.threads = []
|
||||
self.free_workers = iter([])
|
||||
|
||||
|
||||
def next_free_worker(self) -> DBConnection:
|
||||
""" Get the next free connection.
|
||||
"""
|
||||
return next(self.free_workers)
|
||||
|
||||
|
||||
def _yield_free_worker(self) -> Iterator[DBConnection]:
|
||||
ready = self.threads
|
||||
command_stat = 0
|
||||
while True:
|
||||
for thread in ready:
|
||||
if thread.is_done():
|
||||
command_stat += 1
|
||||
yield thread
|
||||
|
||||
if command_stat > self.REOPEN_CONNECTIONS_AFTER:
|
||||
self._reconnect_threads()
|
||||
ready = self.threads
|
||||
command_stat = 0
|
||||
else:
|
||||
tstart = time.time()
|
||||
_, ready, _ = select.select([], self.threads, [])
|
||||
self.wait_time += time.time() - tstart
|
||||
|
||||
|
||||
def _reconnect_threads(self) -> None:
|
||||
for thread in self.threads:
|
||||
while not thread.is_done():
|
||||
thread.wait()
|
||||
thread.connect()
|
||||
|
||||
|
||||
def __enter__(self) -> 'WorkerPool':
|
||||
return self
|
||||
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
self.finish_all()
|
||||
self.close()
|
@ -7,200 +7,136 @@
|
||||
"""
|
||||
Specialised connection and cursor functions.
|
||||
"""
|
||||
from typing import Optional, Any, Callable, ContextManager, Dict, cast, overload, Tuple, Iterable
|
||||
import contextlib
|
||||
from typing import Optional, Any, Dict, Tuple
|
||||
import logging
|
||||
import os
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
from psycopg2 import sql as pysql
|
||||
import psycopg
|
||||
import psycopg.types.hstore
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from ..typing import SysEnv, Query, T_cursor
|
||||
from ..typing import SysEnv
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
class Cursor(psycopg2.extras.DictCursor):
|
||||
""" A cursor returning dict-like objects and providing specialised
|
||||
execution functions.
|
||||
Cursor = psycopg.Cursor[Any]
|
||||
Connection = psycopg.Connection[Any]
|
||||
|
||||
def execute_scalar(conn: Connection, sql: psycopg.abc.Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
# pylint: disable=arguments-renamed,arguments-differ
|
||||
def execute(self, query: Query, args: Any = None) -> None:
|
||||
""" Query execution that logs the SQL query when debugging is enabled.
|
||||
"""
|
||||
if LOG.isEnabledFor(logging.DEBUG):
|
||||
LOG.debug(self.mogrify(query, args).decode('utf-8'))
|
||||
with conn.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute(sql, args)
|
||||
|
||||
super().execute(query, args)
|
||||
|
||||
|
||||
def execute_values(self, sql: Query, argslist: Iterable[Tuple[Any, ...]],
|
||||
template: Optional[Query] = None) -> None:
|
||||
""" Wrapper for the psycopg2 convenience function to execute
|
||||
SQL for a list of values.
|
||||
"""
|
||||
LOG.debug("SQL execute_values(%s, %s)", sql, argslist)
|
||||
|
||||
psycopg2.extras.execute_values(self, sql, argslist, template=template)
|
||||
|
||||
|
||||
def scalar(self, sql: Query, args: Any = None) -> Any:
|
||||
""" Execute query that returns a single value. The value is returned.
|
||||
If the query yields more than one row, a ValueError is raised.
|
||||
"""
|
||||
self.execute(sql, args)
|
||||
|
||||
if self.rowcount != 1:
|
||||
if cur.rowcount != 1:
|
||||
raise RuntimeError("Query did not return a single row.")
|
||||
|
||||
result = self.fetchone()
|
||||
assert result is not None
|
||||
result = cur.fetchone()
|
||||
|
||||
return result[0]
|
||||
assert result is not None
|
||||
return result[0]
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored. If 'cascade' is set
|
||||
to True then all dependent tables are deleted as well.
|
||||
"""
|
||||
sql = 'DROP TABLE '
|
||||
if if_exists:
|
||||
sql += 'IF EXISTS '
|
||||
sql += '{}'
|
||||
if cascade:
|
||||
sql += ' CASCADE'
|
||||
|
||||
self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
class Connection(psycopg2.extensions.connection):
|
||||
""" A connection that provides the specialised cursor by default and
|
||||
adds convenience functions for administrating the database.
|
||||
def table_exists(conn: Connection, table: str) -> bool:
|
||||
""" Check that a table with the given name exists in the database.
|
||||
"""
|
||||
@overload # type: ignore[override]
|
||||
def cursor(self) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, name: str) -> Cursor:
|
||||
...
|
||||
|
||||
@overload
|
||||
def cursor(self, cursor_factory: Callable[..., T_cursor]) -> T_cursor:
|
||||
...
|
||||
|
||||
def cursor(self, cursor_factory = Cursor, **kwargs): # type: ignore
|
||||
""" Return a new cursor. By default the specialised cursor is returned.
|
||||
"""
|
||||
return super().cursor(cursor_factory=cursor_factory, **kwargs)
|
||||
num = execute_scalar(conn,
|
||||
"""SELECT count(*) FROM pg_tables
|
||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||
return num == 1 if isinstance(num, int) else False
|
||||
|
||||
|
||||
def table_exists(self, table: str) -> bool:
|
||||
""" Check that a table with the given name exists in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
num = cur.scalar("""SELECT count(*) FROM pg_tables
|
||||
WHERE tablename = %s and schemaname = 'public'""", (table, ))
|
||||
return num == 1 if isinstance(num, int) else False
|
||||
def table_has_column(conn: Connection, table: str, column: str) -> bool:
|
||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||
"""
|
||||
has_column = execute_scalar(conn,
|
||||
"""SELECT count(*) FROM information_schema.columns
|
||||
WHERE table_name = %s and column_name = %s""",
|
||||
(table, column))
|
||||
return has_column > 0 if isinstance(has_column, int) else False
|
||||
|
||||
|
||||
def table_has_column(self, table: str, column: str) -> bool:
|
||||
""" Check if the table 'table' exists and has a column with name 'column'.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
|
||||
WHERE table_name = %s
|
||||
and column_name = %s""",
|
||||
(table, column))
|
||||
return has_column > 0 if isinstance(has_column, int) else False
|
||||
def index_exists(conn: Connection, index: str, table: Optional[str] = None) -> bool:
|
||||
""" Check that an index with the given name exists in the database.
|
||||
If table is not None then the index must relate to the given
|
||||
table.
|
||||
"""
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""SELECT tablename FROM pg_indexes
|
||||
WHERE indexname = %s and schemaname = 'public'""", (index, ))
|
||||
if cur.rowcount == 0:
|
||||
return False
|
||||
|
||||
|
||||
def index_exists(self, index: str, table: Optional[str] = None) -> bool:
|
||||
""" Check that an index with the given name exists in the database.
|
||||
If table is not None then the index must relate to the given
|
||||
table.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute("""SELECT tablename FROM pg_indexes
|
||||
WHERE indexname = %s and schemaname = 'public'""", (index, ))
|
||||
if cur.rowcount == 0:
|
||||
if table is not None:
|
||||
row = cur.fetchone()
|
||||
if row is None or not isinstance(row[0], str):
|
||||
return False
|
||||
return row[0] == table
|
||||
|
||||
if table is not None:
|
||||
row = cur.fetchone()
|
||||
if row is None or not isinstance(row[0], str):
|
||||
return False
|
||||
return row[0] == table
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def drop_table(self, name: str, if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop the table with the given name.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.drop_table(name, if_exists, cascade)
|
||||
self.commit()
|
||||
|
||||
|
||||
def server_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = self.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
return (int(version / 10000), version % 10000)
|
||||
|
||||
|
||||
def postgis_version_tuple(self) -> Tuple[int, int]:
|
||||
""" Return the postgis version installed in the database as a
|
||||
tuple of (major, minor). Assumes that the PostGIS extension
|
||||
has been installed already.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
version = cur.scalar('SELECT postgis_lib_version()')
|
||||
|
||||
version_parts = version.split('.')
|
||||
if len(version_parts) < 2:
|
||||
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
|
||||
def extension_loaded(self, extension_name: str) -> bool:
|
||||
""" Return True if the hstore extension is loaded in the database.
|
||||
"""
|
||||
with self.cursor() as cur:
|
||||
cur.execute('SELECT extname FROM pg_extension WHERE extname = %s', (extension_name, ))
|
||||
return cur.rowcount > 0
|
||||
|
||||
|
||||
class ConnectionContext(ContextManager[Connection]):
|
||||
""" Context manager of the connection that also provides direct access
|
||||
to the underlying connection.
|
||||
def drop_tables(conn: Connection, *names: str,
|
||||
if_exists: bool = True, cascade: bool = False) -> None:
|
||||
""" Drop one or more tables with the given names.
|
||||
Set `if_exists` to False if a non-existent table should raise
|
||||
an exception instead of just being ignored. `cascade` will cause
|
||||
depended objects to be dropped as well.
|
||||
The caller needs to take care of committing the change.
|
||||
"""
|
||||
connection: Connection
|
||||
sql = pysql.SQL('DROP TABLE%s{}%s' % (
|
||||
' IF EXISTS ' if if_exists else ' ',
|
||||
' CASCADE' if cascade else ''))
|
||||
|
||||
def connect(dsn: str) -> ConnectionContext:
|
||||
with conn.cursor() as cur:
|
||||
for name in names:
|
||||
cur.execute(sql.format(pysql.Identifier(name)))
|
||||
|
||||
|
||||
def server_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the server version as a tuple of (major, minor).
|
||||
Converts correctly for pre-10 and post-10 PostgreSQL versions.
|
||||
"""
|
||||
version = conn.info.server_version
|
||||
if version < 100000:
|
||||
return (int(version / 10000), int((version % 10000) / 100))
|
||||
|
||||
return (int(version / 10000), version % 10000)
|
||||
|
||||
|
||||
def postgis_version_tuple(conn: Connection) -> Tuple[int, int]:
|
||||
""" Return the postgis version installed in the database as a
|
||||
tuple of (major, minor). Assumes that the PostGIS extension
|
||||
has been installed already.
|
||||
"""
|
||||
version = execute_scalar(conn, 'SELECT postgis_lib_version()')
|
||||
|
||||
version_parts = version.split('.')
|
||||
if len(version_parts) < 2:
|
||||
raise UsageError(f"Error fetching Postgis version. Bad format: {version}")
|
||||
|
||||
return (int(version_parts[0]), int(version_parts[1]))
|
||||
|
||||
|
||||
def register_hstore(conn: Connection) -> None:
|
||||
""" Register the hstore type with psycopg for the connection.
|
||||
"""
|
||||
info = psycopg.types.TypeInfo.fetch(conn, "hstore")
|
||||
if info is None:
|
||||
raise RuntimeError('Hstore extension is requested but not installed.')
|
||||
psycopg.types.hstore.register_hstore(info, conn)
|
||||
|
||||
|
||||
def connect(dsn: str, **kwargs: Any) -> Connection:
|
||||
""" Open a connection to the database using the specialised connection
|
||||
factory. The returned object may be used in conjunction with 'with'.
|
||||
When used outside a context manager, use the `connection` attribute
|
||||
to get the connection.
|
||||
"""
|
||||
try:
|
||||
conn = psycopg2.connect(dsn, connection_factory=Connection)
|
||||
ctxmgr = cast(ConnectionContext, contextlib.closing(conn))
|
||||
ctxmgr.connection = conn
|
||||
return ctxmgr
|
||||
except psycopg2.OperationalError as err:
|
||||
return psycopg.connect(dsn, row_factory=psycopg.rows.namedtuple_row, **kwargs)
|
||||
except psycopg.OperationalError as err:
|
||||
raise UsageError(f"Cannot connect to database: {err}") from err
|
||||
|
||||
|
||||
@ -245,10 +181,18 @@ def get_pg_env(dsn: str,
|
||||
"""
|
||||
env = dict(base_env if base_env is not None else os.environ)
|
||||
|
||||
for param, value in psycopg2.extensions.parse_dsn(dsn).items():
|
||||
for param, value in psycopg.conninfo.conninfo_to_dict(dsn).items():
|
||||
if param in _PG_CONNECTION_STRINGS:
|
||||
env[_PG_CONNECTION_STRINGS[param]] = value
|
||||
env[_PG_CONNECTION_STRINGS[param]] = str(value)
|
||||
else:
|
||||
LOG.error("Unknown connection parameter '%s' ignored.", param)
|
||||
|
||||
return env
|
||||
|
||||
|
||||
async def run_async_query(dsn: str, query: psycopg.abc.Query) -> None:
|
||||
""" Open a connection to the database and run a single query
|
||||
asynchronously.
|
||||
"""
|
||||
async with await psycopg.AsyncConnection.connect(dsn) as aconn:
|
||||
await aconn.execute(query)
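The rewritten module replaces the psycopg2 connection subclass with a plain psycopg 3 connection plus free functions. A hedged sketch of how calling code changes from the old method style to the new function style:

```python
from nominatim_db.db.connection import (connect, execute_scalar,
                                        register_hstore, table_exists)

with connect('dbname=nominatim') as conn:   # plain psycopg 3 connection
    register_hstore(conn)
    # formerly conn.table_exists(...) and cur.scalar(...)
    if table_exists(conn, 'placex'):
        places = execute_scalar(conn, 'SELECT count(*) FROM placex')
```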
|
||||
|
@ -9,7 +9,7 @@ Query and access functions for the in-database property table.
|
||||
"""
|
||||
from typing import Optional, cast
|
||||
|
||||
from .connection import Connection
|
||||
from .connection import Connection, table_exists
|
||||
|
||||
def set_property(conn: Connection, name: str, value: str) -> None:
|
||||
""" Add or replace the property with the given name.
|
||||
@ -31,7 +31,7 @@ def get_property(conn: Connection, name: str) -> Optional[str]:
|
||||
""" Return the current value of the given property or None if the property
|
||||
is not set.
|
||||
"""
|
||||
if not conn.table_exists('nominatim_properties'):
|
||||
if not table_exists(conn, 'nominatim_properties'):
|
||||
return None
|
||||
|
||||
with conn.cursor() as cur:
|
||||
|
src/nominatim_db/db/query_pool.py (new file, 87 lines)
@ -0,0 +1,87 @@
|
||||
# SPDX-License-Identifier: GPL-3.0-or-later
|
||||
#
|
||||
# This file is part of Nominatim. (https://nominatim.org)
|
||||
#
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
A connection pool that executes incoming queries in parallel.
|
||||
"""
|
||||
from typing import Any, Tuple, Optional
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
|
||||
import psycopg
|
||||
|
||||
LOG = logging.getLogger()
|
||||
|
||||
QueueItem = Optional[Tuple[psycopg.abc.Query, Any]]
|
||||
|
||||
class QueryPool:
|
||||
""" Pool to run SQL queries in parallel asynchronous execution.
|
||||
|
||||
All queries are run in autocommit mode. If parallel execution leads
|
||||
to a deadlock, then the query is repeated.
|
||||
The results of the queries is discarded.
|
||||
"""
|
||||
def __init__(self, dsn: str, pool_size: int = 1, **conn_args: Any) -> None:
|
||||
self.wait_time = 0.0
|
||||
self.query_queue: 'asyncio.Queue[QueueItem]' = asyncio.Queue(maxsize=2 * pool_size)
|
||||
|
||||
self.pool = [asyncio.create_task(self._worker_loop(dsn, **conn_args))
|
||||
for _ in range(pool_size)]
|
||||
|
||||
|
||||
async def put_query(self, query: psycopg.abc.Query, params: Any) -> None:
|
||||
""" Schedule a query for execution.
|
||||
"""
|
||||
tstart = time.time()
|
||||
await self.query_queue.put((query, params))
|
||||
self.wait_time += time.time() - tstart
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
async def finish(self) -> None:
|
||||
""" Wait for all queries to finish and close the pool.
|
||||
"""
|
||||
for _ in self.pool:
|
||||
await self.query_queue.put(None)
|
||||
|
||||
tstart = time.time()
|
||||
await asyncio.wait(self.pool)
|
||||
self.wait_time += time.time() - tstart
|
||||
|
||||
for task in self.pool:
|
||||
excp = task.exception()
|
||||
if excp is not None:
|
||||
raise excp
|
||||
|
||||
|
||||
async def _worker_loop(self, dsn: str, **conn_args: Any) -> None:
|
||||
conn_args['autocommit'] = True
|
||||
aconn = await psycopg.AsyncConnection.connect(dsn, **conn_args)
|
||||
async with aconn:
|
||||
async with aconn.cursor() as cur:
|
||||
item = await self.query_queue.get()
|
||||
while item is not None:
|
||||
try:
|
||||
if item[1] is None:
|
||||
await cur.execute(item[0])
|
||||
else:
|
||||
await cur.execute(item[0], item[1])
|
||||
|
||||
item = await self.query_queue.get()
|
||||
except psycopg.errors.DeadlockDetected:
|
||||
assert item is not None
|
||||
LOG.info("Deadlock detected (sql = %s, params = %s), retry.",
|
||||
str(item[0]), str(item[1]))
|
||||
# item is still valid here, causing a retry
|
||||
|
||||
|
||||
async def __aenter__(self) -> 'QueryPool':
|
||||
return self
|
||||
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
await self.finish()
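The new QueryPool replaces the select()-based WorkerPool from the deleted async_connection module. A possible usage sketch (the SQL and DSN are illustrative only): queries are queued, executed by a fixed number of autocommit connections, and their results discarded.

```python
import asyncio

from nominatim_db.db.query_pool import QueryPool

async def mark_for_indexing(dsn: str, place_ids: list) -> None:
    async with QueryPool(dsn, pool_size=4) as pool:
        for place_id in place_ids:
            await pool.put_query(
                'UPDATE placex SET indexed_status = 1 WHERE place_id = %s',
                (place_id, ))
    # leaving the block awaits finish(), which drains the queue and
    # re-raises any exception from the worker tasks

asyncio.run(mark_for_indexing('dbname=nominatim', [1, 2, 3]))
```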
|
@ -8,11 +8,12 @@
|
||||
Preprocessing of SQL files.
|
||||
"""
|
||||
from typing import Set, Dict, Any, cast
|
||||
|
||||
import jinja2
|
||||
|
||||
from .connection import Connection
|
||||
from .async_connection import WorkerPool
|
||||
from .connection import Connection, server_version_tuple, postgis_version_tuple
|
||||
from ..config import Configuration
|
||||
from ..db.query_pool import QueryPool
|
||||
|
||||
def _get_partitions(conn: Connection) -> Set[int]:
|
||||
""" Get the set of partitions currently in use.
|
||||
@ -66,8 +67,8 @@ def _setup_postgresql_features(conn: Connection) -> Dict[str, Any]:
|
||||
""" Set up a dictionary with various optional Postgresql/Postgis features that
|
||||
depend on the database version.
|
||||
"""
|
||||
pg_version = conn.server_version_tuple()
|
||||
postgis_version = conn.postgis_version_tuple()
|
||||
pg_version = server_version_tuple(conn)
|
||||
postgis_version = postgis_version_tuple(conn)
|
||||
pg11plus = pg_version >= (11, 0, 0)
|
||||
ps3 = postgis_version >= (3, 0)
|
||||
return {
|
||||
@ -125,8 +126,8 @@ class SQLPreprocessor:
|
||||
conn.commit()
|
||||
|
||||
|
||||
def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
|
||||
**kwargs: Any) -> None:
|
||||
async def run_parallel_sql_file(self, dsn: str, name: str, num_threads: int = 1,
|
||||
**kwargs: Any) -> None:
|
||||
""" Execute the given SQL files using parallel asynchronous connections.
|
||||
The keyword arguments may supply additional parameters for
|
||||
preprocessing.
|
||||
@ -138,6 +139,6 @@ class SQLPreprocessor:
|
||||
|
||||
parts = sql.split('\n---\n')
|
||||
|
||||
with WorkerPool(dsn, num_threads) as pool:
|
||||
async with QueryPool(dsn, num_threads) as pool:
|
||||
for part in parts:
|
||||
pool.next_free_worker().perform(part)
|
||||
await pool.put_query(part, None)
|
||||
|
@ -7,34 +7,25 @@
|
||||
"""
|
||||
Access and helper functions for the status and status log table.
|
||||
"""
|
||||
from typing import Optional, Tuple, cast
|
||||
from typing import Optional, Tuple
|
||||
import datetime as dt
|
||||
import logging
|
||||
import re
|
||||
|
||||
from .connection import Connection
|
||||
from .connection import Connection, table_exists, execute_scalar
|
||||
from ..utils.url_utils import get_url
|
||||
from ..errors import UsageError
|
||||
from ..typing import TypedDict
|
||||
|
||||
LOG = logging.getLogger()
|
||||
ISODATE_FORMAT = '%Y-%m-%dT%H:%M:%S'
|
||||
|
||||
|
||||
class StatusRow(TypedDict):
|
||||
""" Dictionary of columns of the import_status table.
|
||||
"""
|
||||
lastimportdate: dt.datetime
|
||||
sequence_id: Optional[int]
|
||||
indexed: Optional[bool]
|
||||
|
||||
|
||||
def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetime:
|
||||
""" Determine the date of the database from the newest object in the
|
||||
data base.
|
||||
"""
|
||||
# If there is a date from osm2pgsql available, use that.
|
||||
if conn.table_exists('osm2pgsql_properties'):
|
||||
if table_exists(conn, 'osm2pgsql_properties'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(""" SELECT value FROM osm2pgsql_properties
|
||||
WHERE property = 'current_timestamp' """)
|
||||
@ -47,15 +38,14 @@ def compute_database_date(conn: Connection, offline: bool = False) -> dt.datetim
|
||||
raise UsageError("Cannot determine database date from data in offline mode.")
|
||||
|
||||
# Else, find the node with the highest ID in the database
|
||||
with conn.cursor() as cur:
|
||||
if conn.table_exists('place'):
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM place WHERE osm_type='N'")
|
||||
else:
|
||||
osmid = cur.scalar("SELECT max(osm_id) FROM placex WHERE osm_type='N'")
|
||||
if table_exists(conn, 'place'):
|
||||
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM place WHERE osm_type='N'")
|
||||
else:
|
||||
osmid = execute_scalar(conn, "SELECT max(osm_id) FROM placex WHERE osm_type='N'")
|
||||
|
||||
if osmid is None:
|
||||
LOG.fatal("No data found in the database.")
|
||||
raise UsageError("No data found in the database.")
|
||||
if osmid is None:
|
||||
LOG.fatal("No data found in the database.")
|
||||
raise UsageError("No data found in the database.")
|
||||
|
||||
LOG.info("Using node id %d for timestamp lookup", osmid)
|
||||
# Get the node from the API to find the timestamp when it was created.
|
||||
@ -103,8 +93,9 @@ def get_status(conn: Connection) -> Tuple[Optional[dt.datetime], Optional[int],
|
||||
if cur.rowcount < 1:
|
||||
return None, None, None
|
||||
|
||||
row = cast(StatusRow, cur.fetchone())
|
||||
return row['lastimportdate'], row['sequence_id'], row['indexed']
|
||||
row = cur.fetchone()
|
||||
assert row
|
||||
return row.lastimportdate, row.sequence_id, row.indexed
|
||||
|
||||
|
||||
def set_indexed(conn: Connection, state: bool) -> None:
|
||||
|
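get_status() can drop the cast above because connect() now sets row_factory=psycopg.rows.namedtuple_row, so fetched rows allow attribute access instead of psycopg2's DictRow lookups. A small hedged sketch:

```python
from nominatim_db.db.connection import connect

with connect('dbname=nominatim') as conn:
    with conn.cursor() as cur:
        cur.execute('SELECT lastimportdate, sequence_id, indexed'
                    ' FROM import_status LIMIT 1')
        row = cur.fetchone()
        if row is not None:
            # named-tuple rows instead of dict-like rows
            print(row.lastimportdate, row.sequence_id, row.indexed)
```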
@ -7,14 +7,13 @@
|
||||
"""
|
||||
Helper functions for handling DB accesses.
|
||||
"""
|
||||
from typing import IO, Optional, Union, Any, Iterable
|
||||
from typing import IO, Optional, Union
|
||||
import subprocess
|
||||
import logging
|
||||
import gzip
|
||||
import io
|
||||
from pathlib import Path
|
||||
|
||||
from .connection import get_pg_env, Cursor
|
||||
from .connection import get_pg_env
|
||||
from ..errors import UsageError
|
||||
|
||||
LOG = logging.getLogger()
|
||||
@ -72,58 +71,3 @@ def execute_file(dsn: str, fname: Path,
|
||||
|
||||
if ret != 0 or remain > 0:
|
||||
raise UsageError("Failed to execute SQL file.")
|
||||
|
||||
|
||||
# List of characters that need to be quoted for the copy command.
|
||||
_SQL_TRANSLATION = {ord('\\'): '\\\\',
|
||||
ord('\t'): '\\t',
|
||||
ord('\n'): '\\n'}
|
||||
|
||||
|
||||
class CopyBuffer:
|
||||
""" Data collector for the copy_from command.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.buffer = io.StringIO()
|
||||
|
||||
|
||||
def __enter__(self) -> 'CopyBuffer':
|
||||
return self
|
||||
|
||||
|
||||
def size(self) -> int:
|
||||
""" Return the number of bytes the buffer currently contains.
|
||||
"""
|
||||
return self.buffer.tell()
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
if self.buffer is not None:
|
||||
self.buffer.close()
|
||||
|
||||
|
||||
def add(self, *data: Any) -> None:
|
||||
""" Add another row of data to the copy buffer.
|
||||
"""
|
||||
first = True
|
||||
for column in data:
|
||||
if first:
|
||||
first = False
|
||||
else:
|
||||
self.buffer.write('\t')
|
||||
if column is None:
|
||||
self.buffer.write('\\N')
|
||||
else:
|
||||
self.buffer.write(str(column).translate(_SQL_TRANSLATION))
|
||||
self.buffer.write('\n')
|
||||
|
||||
|
||||
def copy_out(self, cur: Cursor, table: str, columns: Optional[Iterable[str]] = None) -> None:
|
||||
""" Copy all collected data into the given table.
|
||||
|
||||
The buffer is empty and reusable after this operation.
|
||||
"""
|
||||
if self.buffer.tell() > 0:
|
||||
self.buffer.seek(0)
|
||||
cur.copy_from(self.buffer, table, columns=columns)
|
||||
self.buffer = io.StringIO()
|
||||
|
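The CopyBuffer helper and its manual tab/newline escaping are removed because psycopg 3 exposes COPY directly on the cursor. A hedged sketch of how the same effect is achieved with cursor.copy(), using a placeholder table and data:

```python
import psycopg

with psycopg.connect('dbname=nominatim') as conn:
    with conn.cursor() as cur:
        with cur.copy('COPY example_table (id, name) FROM STDIN') as copy:
            copy.write_row((1, 'first'))
            copy.write_row((2, None))   # None is sent as SQL NULL automatically
    conn.commit()
```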
@ -7,92 +7,20 @@
"""
Main work horse for indexing (computing addresses) the database.
"""
from typing import Optional, Any, cast
from typing import cast, List, Any
import logging
import time

import psycopg2.extras
import psycopg

from ..typing import DictCursorResults
from ..db.async_connection import DBConnection, WorkerPool
from ..db.connection import connect, Connection, Cursor
from ..db.connection import connect, execute_scalar
from ..db.query_pool import QueryPool
from ..tokenizer.base import AbstractTokenizer
from .progress import ProgressLogger
from . import runners

LOG = logging.getLogger()


class PlaceFetcher:
    """ Asynchronous connection that fetches place details for processing.
    """
    def __init__(self, dsn: str, setup_conn: Connection) -> None:
        self.wait_time = 0.0
        self.current_ids: Optional[DictCursorResults] = None
        self.conn: Optional[DBConnection] = DBConnection(dsn,
                                                         cursor_factory=psycopg2.extras.DictCursor)

        with setup_conn.cursor() as cur:
            # need to fetch those manually because register_hstore cannot
            # fetch them on an asynchronous connection below.
            hstore_oid = cur.scalar("SELECT 'hstore'::regtype::oid")
            hstore_array_oid = cur.scalar("SELECT 'hstore[]'::regtype::oid")

            psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                            array_oid=hstore_array_oid)

    def close(self) -> None:
        """ Close the underlying asynchronous connection.
        """
        if self.conn:
            self.conn.close()
            self.conn = None


    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
        """ Send a request for the next batch of places.
            If details for the places are required, they will be fetched
            asynchronously.

            Returns true if there is still data available.
        """
        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))

        if not ids:
            self.current_ids = None
            return False

        assert self.conn is not None
        self.current_ids = runner.get_place_details(self.conn, ids)

        return True

    def get_batch(self) -> DictCursorResults:
        """ Get the next batch of data, previously requested with
            `fetch_next_batch`.
        """
        assert self.conn is not None
        assert self.conn.cursor is not None

        if self.current_ids is not None and not self.current_ids:
            tstart = time.time()
            self.conn.wait()
            self.wait_time += time.time() - tstart
            self.current_ids = cast(Optional[DictCursorResults],
                                    self.conn.cursor.fetchall())

        return self.current_ids if self.current_ids is not None else []

    def __enter__(self) -> 'PlaceFetcher':
        return self


    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        assert self.conn is not None
        self.conn.wait()
        self.close()


class Indexer:
    """ Main indexing routine.
    """
@ -114,7 +42,7 @@ class Indexer:
            return cur.rowcount > 0


    def index_full(self, analyse: bool = True) -> None:
    async def index_full(self, analyse: bool = True) -> None:
        """ Index the complete database. This will first index boundaries
            followed by all other objects. When `analyse` is True, then the
            database will be analysed at the appropriate places to
@ -128,23 +56,27 @@ class Indexer:
                with conn.cursor() as cur:
                    cur.execute('ANALYZE')

        if self.index_by_rank(0, 4) > 0:
            _analyze()
        while True:
            if await self.index_by_rank(0, 4) > 0:
                _analyze()

        if self.index_boundaries(0, 30) > 100:
            _analyze()
            if await self.index_boundaries(0, 30) > 100:
                _analyze()

        if self.index_by_rank(5, 25) > 100:
            _analyze()
            if await self.index_by_rank(5, 25) > 100:
                _analyze()

        if self.index_by_rank(26, 30) > 1000:
            _analyze()
            if await self.index_by_rank(26, 30) > 1000:
                _analyze()

        if self.index_postcodes() > 100:
            _analyze()
            if await self.index_postcodes() > 100:
                _analyze()

            if not self.has_pending():
                break


    def index_boundaries(self, minrank: int, maxrank: int) -> int:
    async def index_boundaries(self, minrank: int, maxrank: int) -> int:
        """ Index only administrative boundaries within the given rank range.
        """
        total = 0
@ -153,11 +85,11 @@ class Indexer:

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(minrank, 4), min(maxrank, 26)):
                total += self._index(runners.BoundaryRunner(rank, analyzer))
                total += await self._index(runners.BoundaryRunner(rank, analyzer))

        return total

    def index_by_rank(self, minrank: int, maxrank: int) -> int:
    async def index_by_rank(self, minrank: int, maxrank: int) -> int:
        """ Index all entries of placex in the given rank range (inclusive)
            in order of their address rank.

@ -171,21 +103,27 @@ class Indexer:

        with self.tokenizer.name_analyzer() as analyzer:
            for rank in range(max(1, minrank), maxrank + 1):
                total += self._index(runners.RankRunner(rank, analyzer), 20 if rank == 30 else 1)
                if rank >= 30:
                    batch = 20
                elif rank >= 26:
                    batch = 5
                else:
                    batch = 1
                total += await self._index(runners.RankRunner(rank, analyzer), batch)

            if maxrank == 30:
                total += self._index(runners.RankRunner(0, analyzer))
                total += self._index(runners.InterpolationRunner(analyzer), 20)
                total += await self._index(runners.RankRunner(0, analyzer))
                total += await self._index(runners.InterpolationRunner(analyzer), 20)

        return total


    def index_postcodes(self) -> int:
    async def index_postcodes(self) -> int:
        """Index the entries of the location_postcode table.
        """
        LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)

        return self._index(runners.PostcodeRunner(), 20)
        return await self._index(runners.PostcodeRunner(), 20)


    def update_status_table(self) -> None:
@ -197,46 +135,58 @@ class Indexer:

            conn.commit()

    def _index(self, runner: runners.Runner, batch: int = 1) -> int:
    async def _index(self, runner: runners.Runner, batch: int = 1) -> int:
        """ Index a single rank or table. `runner` describes the SQL to use
            for indexing. `batch` describes the number of objects that
            should be processed with a single SQL statement
        """
        LOG.warning("Starting %s (using batch size %s)", runner.name(), batch)

        with connect(self.dsn) as conn:
            psycopg2.extras.register_hstore(conn)
            with conn.cursor() as cur:
                total_tuples = cur.scalar(runner.sql_count_objects())
                LOG.debug("Total number of rows: %i", total_tuples)
        total_tuples = self._prepare_indexing(runner)

            conn.commit()
        progress = ProgressLogger(runner.name(), total_tuples)

            progress = ProgressLogger(runner.name(), total_tuples)
        if total_tuples > 0:
            async with await psycopg.AsyncConnection.connect(
                             self.dsn, row_factory=psycopg.rows.dict_row) as aconn,\
                       QueryPool(self.dsn, self.num_threads, autocommit=True) as pool:
                fetcher_time = 0.0
                tstart = time.time()
                async with aconn.cursor(name='places') as cur:
                    query = runner.index_places_query(batch)
                    params: List[Any] = []
                    num_places = 0
                    async for place in cur.stream(runner.sql_get_objects()):
                        fetcher_time += time.time() - tstart

            if total_tuples > 0:
                with conn.cursor(name='places') as cur:
                    cur.execute(runner.sql_get_objects())
                        params.extend(runner.index_places_params(place))
                        num_places += 1

                    with PlaceFetcher(self.dsn, conn) as fetcher:
                        with WorkerPool(self.dsn, self.num_threads) as pool:
                            has_more = fetcher.fetch_next_batch(cur, runner)
                            while has_more:
                                places = fetcher.get_batch()
                        if num_places >= batch:
                            LOG.debug("Processing places: %s", str(params))
                            await pool.put_query(query, params)
                            progress.add(num_places)
                            params = []
                            num_places = 0

                                # asynchronously get the next batch
                                has_more = fetcher.fetch_next_batch(cur, runner)
                        tstart = time.time()

                                # And insert the current batch
                                for idx in range(0, len(places), batch):
                                    part = places[idx:idx + batch]
                                    LOG.debug("Processing places: %s", str(part))
                                    runner.index_places(pool.next_free_worker(), part)
                                    progress.add(len(part))
                if num_places > 0:
                    await pool.put_query(runner.index_places_query(num_places), params)

                            LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
                                     fetcher.wait_time, pool.wait_time)

                conn.commit()
            LOG.info("Wait time: fetcher: %.2fs, pool: %.2fs",
                     fetcher_time, pool.wait_time)

        return progress.done()


    def _prepare_indexing(self, runner: runners.Runner) -> int:
        with connect(self.dsn) as conn:
            hstore_info = psycopg.types.TypeInfo.fetch(conn, "hstore")
            if hstore_info is None:
                raise RuntimeError('Hstore extension is requested but not installed.')
            psycopg.types.hstore.register_hstore(hstore_info)

            total_tuples = execute_scalar(conn, runner.sql_count_objects())
            LOG.debug("Total number of rows: %i", total_tuples)
        return cast(int, total_tuples)
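
The new _index() loop replaces the callback-driven DBConnection/WorkerPool machinery with plain asyncio: one connection streams candidate rows while batched UPDATEs are handed to Nominatim's QueryPool. The sketch below shows the same producer/consumer shape under stated assumptions: a second connection stands in for the pool, and fetch_sql/update_sql are placeholder names, not part of this changeset.

    import psycopg

    async def reindex(dsn: str, fetch_sql: str, update_sql: str, batch: int = 20) -> None:
        # One connection streams candidate rows, a second one applies the updates,
        # so fetching and updating overlap (the real code fans updates out over QueryPool).
        async with await psycopg.AsyncConnection.connect(
                dsn, row_factory=psycopg.rows.dict_row) as reader, \
            await psycopg.AsyncConnection.connect(dsn, autocommit=True) as writer:
            ids: list = []
            async with reader.cursor() as cur:
                # stream() fetches rows one by one instead of materialising the result
                async for row in cur.stream(fetch_sql):
                    ids.append(row['place_id'])
                    if len(ids) >= batch:
                        await writer.execute(update_sql, (ids,))
                        ids = []
            if ids:
                await writer.execute(update_sql, (ids,))

    # hypothetical call:
    # asyncio.run(reindex(dsn,
    #                     "SELECT place_id FROM placex WHERE indexed_status > 0",
    #                     "UPDATE placex SET indexed_status = 0 WHERE place_id = ANY(%s)"))
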
@ -8,14 +8,14 @@
Mix-ins that provide the actual commands for the indexer for various indexing
tasks.
"""
from typing import Any, List
import functools
from typing import Any, Sequence

from psycopg2 import sql as pysql
import psycopg2.extras
from psycopg import sql as pysql
from psycopg.abc import Query
from psycopg.rows import DictRow
from psycopg.types.json import Json

from ..typing import Query, DictCursorResult, DictCursorResults, Protocol
from ..db.async_connection import DBConnection
from ..typing import Protocol
from ..data.place_info import PlaceInfo
from ..tokenizer.base import AbstractAnalyzer

@ -24,58 +24,48 @@ from ..tokenizer.base import AbstractAnalyzer
def _mk_valuelist(template: str, num: int) -> pysql.Composed:
    return pysql.SQL(',').join([pysql.SQL(template)] * num)

def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
    return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))
def _analyze_place(place: DictRow, analyzer: AbstractAnalyzer) -> Json:
    return Json(analyzer.process_place(PlaceInfo(place)))


class Runner(Protocol):
    def name(self) -> str: ...
    def sql_count_objects(self) -> Query: ...
    def sql_get_objects(self) -> Query: ...
    def get_place_details(self, worker: DBConnection,
                          ids: DictCursorResults) -> DictCursorResults: ...
    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
    def index_places_query(self, batch_size: int) -> Query: ...
    def index_places_params(self, place: DictRow) -> Sequence[Any]: ...


SELECT_SQL = pysql.SQL("""SELECT place_id, extra.*
                          FROM (SELECT * FROM placex {}) as px,
                               LATERAL placex_indexing_prepare(px) as extra """)
UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"

class AbstractPlacexRunner:
    """ Returns SQL commands for indexing of the placex table.
    """
    SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
    UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"

    def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
        self.rank = rank
        self.analyzer = analyzer


    @functools.lru_cache(maxsize=1)
    def _index_sql(self, num_places: int) -> pysql.Composed:
    def index_places_query(self, batch_size: int) -> Query:
        return pysql.SQL(
            """ UPDATE placex
                SET indexed_status = 0, address = v.addr, token_info = v.ti,
                    name = v.name, linked_place_id = v.linked_place_id
                FROM (VALUES {}) as v(id, name, addr, linked_place_id, ti)
                WHERE place_id = v.id
            """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))
            """).format(_mk_valuelist(UPDATE_LINE, batch_size))


    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
        worker.perform("""SELECT place_id, extra.*
                          FROM placex, LATERAL placex_indexing_prepare(placex) as extra
                          WHERE place_id IN %s""",
                       (tuple((p[0] for p in ids)), ))

        return []


    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
        values: List[Any] = []
        for place in places:
            for field in ('place_id', 'name', 'address', 'linked_place_id'):
                values.append(place[field])
            values.append(_analyze_place(place, self.analyzer))

        worker.perform(self._index_sql(len(places)), values)
    def index_places_params(self, place: DictRow) -> Sequence[Any]:
        return (place['place_id'],
                place['name'],
                place['address'],
                place['linked_place_id'],
                _analyze_place(place, self.analyzer))


class RankRunner(AbstractPlacexRunner):
@ -91,10 +81,10 @@ class RankRunner(AbstractPlacexRunner):
                         """).format(pysql.Literal(self.rank))

    def sql_get_objects(self) -> pysql.Composed:
        return self.SELECT_SQL + pysql.SQL(
            """WHERE indexed_status > 0 and rank_address = {}
               ORDER BY geometry_sector
            """).format(pysql.Literal(self.rank))
        return SELECT_SQL.format(pysql.SQL(
            """WHERE placex.indexed_status > 0 and placex.rank_address = {}
               ORDER BY placex.geometry_sector
            """).format(pysql.Literal(self.rank)))


class BoundaryRunner(AbstractPlacexRunner):
@ -105,19 +95,19 @@ class BoundaryRunner(AbstractPlacexRunner):
    def name(self) -> str:
        return f"boundaries rank {self.rank}"

    def sql_count_objects(self) -> pysql.Composed:
    def sql_count_objects(self) -> Query:
        return pysql.SQL("""SELECT count(*) FROM placex
                            WHERE indexed_status > 0
                              AND rank_search = {}
                              AND class = 'boundary' and type = 'administrative'
                         """).format(pysql.Literal(self.rank))

    def sql_get_objects(self) -> pysql.Composed:
        return self.SELECT_SQL + pysql.SQL(
            """WHERE indexed_status > 0 and rank_search = {}
                     and class = 'boundary' and type = 'administrative'
               ORDER BY partition, admin_level
            """).format(pysql.Literal(self.rank))
    def sql_get_objects(self) -> Query:
        return SELECT_SQL.format(pysql.SQL(
            """WHERE placex.indexed_status > 0 and placex.rank_search = {}
                     and placex.class = 'boundary' and placex.type = 'administrative'
               ORDER BY placex.partition, placex.admin_level
            """).format(pysql.Literal(self.rank)))


class InterpolationRunner:
@ -132,40 +122,29 @@ class InterpolationRunner:
    def name(self) -> str:
        return "interpolation lines (location_property_osmline)"

    def sql_count_objects(self) -> str:
    def sql_count_objects(self) -> Query:
        return """SELECT count(*) FROM location_property_osmline
                  WHERE indexed_status > 0"""

    def sql_get_objects(self) -> str:
        return """SELECT place_id

    def sql_get_objects(self) -> Query:
        return """SELECT place_id, get_interpolation_address(address, osm_id) as address
                  FROM location_property_osmline
                  WHERE indexed_status > 0
                  ORDER BY geometry_sector"""


    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
        worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
                          FROM location_property_osmline WHERE place_id IN %s""",
                       (tuple((p[0] for p in ids)), ))
        return []


    @functools.lru_cache(maxsize=1)
    def _index_sql(self, num_places: int) -> pysql.Composed:
    def index_places_query(self, batch_size: int) -> Query:
        return pysql.SQL("""UPDATE location_property_osmline
                            SET indexed_status = 0, address = v.addr, token_info = v.ti
                            FROM (VALUES {}) as v(id, addr, ti)
                            WHERE place_id = v.id
                         """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))
                         """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", batch_size))


    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
        values: List[Any] = []
        for place in places:
            values.extend((place[x] for x in ('place_id', 'address')))
            values.append(_analyze_place(place, self.analyzer))

        worker.perform(self._index_sql(len(places)), values)
    def index_places_params(self, place: DictRow) -> Sequence[Any]:
        return (place['place_id'], place['address'],
                _analyze_place(place, self.analyzer))

@ -177,20 +156,21 @@ class PostcodeRunner(Runner):
        return "postcodes (location_postcode)"


    def sql_count_objects(self) -> str:
    def sql_count_objects(self) -> Query:
        return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'


    def sql_get_objects(self) -> str:
    def sql_get_objects(self) -> Query:
        return """SELECT place_id FROM location_postcode
                  WHERE indexed_status > 0
                  ORDER BY country_code, postcode"""


    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
        return ids
    def index_places_query(self, batch_size: int) -> Query:
        return pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
                            WHERE place_id IN ({})""")\
               .format(pysql.SQL(',').join((pysql.Placeholder() for _ in range(batch_size))))

    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
        worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
                                    WHERE place_id IN ({})""")
                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))

    def index_places_params(self, place: DictRow) -> Sequence[Any]:
        return (place['place_id'], )
@ -11,14 +11,16 @@ libICU instead of the PostgreSQL module.
from typing import Optional, Sequence, List, Tuple, Mapping, Any, cast, \
                   Dict, Set, Iterable
import itertools
import json
import logging
from pathlib import Path
from textwrap import dedent

from ..db.connection import connect, Connection, Cursor
from psycopg.types.json import Jsonb
from psycopg import sql as pysql

from ..db.connection import connect, Connection, Cursor, server_version_tuple,\
                            drop_tables, table_exists, execute_scalar
from ..config import Configuration
from ..db.utils import CopyBuffer
from ..db.sql_preprocessor import SQLPreprocessor
from ..data.place_info import PlaceInfo
from ..data.place_name import PlaceName
@ -108,19 +110,18 @@ class ICUTokenizer(AbstractTokenizer):
        """ Recompute frequencies for all name words.
        """
        with connect(self.dsn) as conn:
            if not conn.table_exists('search_name'):
            if not table_exists(conn, 'search_name'):
                return

            with conn.cursor() as cur:
                cur.execute('ANALYSE search_name')
                if threads > 1:
                    cur.execute('SET max_parallel_workers_per_gather TO %s',
                                (min(threads, 6),))
                    cur.execute(pysql.SQL('SET max_parallel_workers_per_gather TO {}')
                                   .format(pysql.Literal(min(threads, 6),)))

                if conn.server_version_tuple() < (12, 0):
                if server_version_tuple(conn) < (12, 0):
                    LOG.info('Computing word frequencies')
                    cur.drop_table('word_frequencies')
                    cur.drop_table('addressword_frequencies')
                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                    cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                     SELECT unnest(name_vector) as id, count(*)
                                     FROM search_name GROUP BY id""")
@ -152,17 +153,16 @@ class ICUTokenizer(AbstractTokenizer):
                                   $$ LANGUAGE plpgsql IMMUTABLE;
                                """)
                    LOG.info('Update word table with recomputed frequencies')
                    cur.drop_table('tmp_word')
                    drop_tables(conn, 'tmp_word')
                    cur.execute("""CREATE TABLE tmp_word AS
                                    SELECT word_id, word_token, type, word,
                                           word_freq_update(word_id, info) as info
                                    FROM word
                                """)
                    cur.drop_table('word_frequencies')
                    cur.drop_table('addressword_frequencies')
                    drop_tables(conn, 'word_frequencies', 'addressword_frequencies')
                else:
                    LOG.info('Computing word frequencies')
                    cur.drop_table('word_frequencies')
                    drop_tables(conn, 'word_frequencies')
                    cur.execute("""
                      CREATE TEMP TABLE word_frequencies AS
                      WITH word_freq AS MATERIALIZED (
@ -182,7 +182,7 @@ class ICUTokenizer(AbstractTokenizer):
                    cur.execute('CREATE UNIQUE INDEX ON word_frequencies(id) INCLUDE(info)')
                    cur.execute('ANALYSE word_frequencies')
                    LOG.info('Update word table with recomputed frequencies')
                    cur.drop_table('tmp_word')
                    drop_tables(conn, 'tmp_word')
                    cur.execute("""CREATE TABLE tmp_word AS
                                    SELECT word_id, word_token, type, word,
                                           (CASE WHEN wf.info is null THEN word.info
@ -191,7 +191,7 @@ class ICUTokenizer(AbstractTokenizer):
                                    FROM word LEFT JOIN word_frequencies wf
                                         ON word.word_id = wf.id
                                """)
                    cur.drop_table('word_frequencies')
                    drop_tables(conn, 'word_frequencies')

            with conn.cursor() as cur:
                cur.execute('SET max_parallel_workers_per_gather TO 0')
@ -210,7 +210,7 @@ class ICUTokenizer(AbstractTokenizer):
        """ Remove unused house numbers.
        """
        with connect(self.dsn) as conn:
            if not conn.table_exists('search_name'):
            if not table_exists(conn, 'search_name'):
                return
            with conn.cursor(name="hnr_counter") as cur:
                cur.execute("""SELECT DISTINCT word_id, coalesce(info->>'lookup', word_token)
@ -311,8 +311,7 @@ class ICUTokenizer(AbstractTokenizer):
            frequencies.
        """
        with connect(self.dsn) as conn:
            with conn.cursor() as cur:
                cur.drop_table('word')
            drop_tables(conn, 'word')
            sqlp = SQLPreprocessor(conn, config)
            sqlp.run_string(conn, """
                CREATE TABLE word (
@ -370,8 +369,8 @@ class ICUTokenizer(AbstractTokenizer):
        """ Rename all tables and indexes used by the tokenizer.
        """
        with connect(self.dsn) as conn:
            drop_tables(conn, 'word')
            with conn.cursor() as cur:
                cur.drop_table('word')
                cur.execute(f"ALTER TABLE {old} RENAME TO word")
                for idx in ('word_token', 'word_id'):
                    cur.execute(f"""ALTER INDEX idx_{old}_{idx}
@ -393,7 +392,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):

    def __init__(self, dsn: str, sanitizer: PlaceSanitizer,
                 token_analysis: ICUTokenAnalysis) -> None:
        self.conn: Optional[Connection] = connect(dsn).connection
        self.conn: Optional[Connection] = connect(dsn)
        self.conn.autocommit = True
        self.sanitizer = sanitizer
        self.token_analysis = token_analysis
@ -535,9 +534,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):

        if terms:
            with self.conn.cursor() as cur:
                cur.execute_values("""SELECT create_postcode_word(pc, var)
                                      FROM (VALUES %s) AS v(pc, var)""",
                                   terms)
                cur.executemany("""SELECT create_postcode_word(%s, %s)""", terms)

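
psycopg3 drops psycopg2's execute_values() helper, so this and the following hunks switch to executemany(), which sends one parameter tuple per row. A hedged sketch of the two equivalent calls (conn and terms are assumed to be a psycopg connection and a list of 2-tuples):

    # psycopg2: execute_values() expanded a single VALUES list client-side
    #   execute_values(cur, "SELECT create_postcode_word(pc, var) FROM (VALUES %s) AS v(pc, var)", terms)

    # psycopg 3: executemany() with one parameter tuple per row
    with conn.cursor() as cur:
        cur.executemany("SELECT create_postcode_word(%s, %s)", terms)
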
@ -580,18 +577,15 @@ class ICUNameAnalyzer(AbstractAnalyzer):
        to_add = new_phrases - existing_phrases

        added = 0
        with CopyBuffer() as copystr:
        with cursor.copy('COPY word(word_token, type, word, info) FROM STDIN') as copy:
            for word, cls, typ, oper in to_add:
                term = self._search_normalized(word)
                if term:
                    copystr.add(term, 'S', word,
                                json.dumps({'class': cls, 'type': typ,
                                            'op': oper if oper in ('in', 'near') else None}))
                    copy.write_row((term, 'S', word,
                                    Jsonb({'class': cls, 'type': typ,
                                           'op': oper if oper in ('in', 'near') else None})))
                    added += 1

            copystr.copy_out(cursor, 'word',
                             columns=['word_token', 'type', 'word', 'info'])

        return added
@ -604,11 +598,11 @@ class ICUNameAnalyzer(AbstractAnalyzer):
        to_delete = existing_phrases - new_phrases

        if to_delete:
            cursor.execute_values(
                """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op)
                    WHERE type = 'S' and word = name
                          and info->>'class' = in_class and info->>'type' = in_type
                          and ((op = '-' and info->>'op' is null) or op = info->>'op')
            cursor.executemany(
                """ DELETE FROM word
                    WHERE type = 'S' and word = %s
                          and info->>'class' = %s and info->>'type' = %s
                          and %s = coalesce(info->>'op', '-')
                """, to_delete)

        return len(to_delete)
@ -655,7 +649,7 @@ class ICUNameAnalyzer(AbstractAnalyzer):
            gone_tokens.update(existing_tokens[False] & word_tokens)
            if gone_tokens:
                cur.execute("""DELETE FROM word
                               USING unnest(%s) as token
                               USING unnest(%s::text[]) as token
                               WHERE type = 'C' and word = %s
                                     and word_token = token""",
                            (list(gone_tokens), country_code))
@ -668,12 +662,12 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                if internal:
                    sql = """INSERT INTO word (word_token, type, word, info)
                             (SELECT token, 'C', %s, '{"internal": "yes"}'
                              FROM unnest(%s) as token)
                              FROM unnest(%s::text[]) as token)
                          """
                else:
                    sql = """INSERT INTO word (word_token, type, word)
                             (SELECT token, 'C', %s
                              FROM unnest(%s) as token)
                              FROM unnest(%s::text[]) as token)
                          """
                cur.execute(sql, (country_code, list(new_tokens)))

@ -733,11 +727,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
            if norm_name:
                result = self._cache.housenumbers.get(norm_name, result)
                if result[0] is None:
                    with self.conn.cursor() as cur:
                        hid = cur.scalar("SELECT getorcreate_hnr_id(%s)", (norm_name, ))
                    hid = execute_scalar(self.conn, "SELECT getorcreate_hnr_id(%s)", (norm_name, ))

                        result = hid, norm_name
                        self._cache.housenumbers[norm_name] = result
                    result = hid, norm_name
                    self._cache.housenumbers[norm_name] = result
        else:
            # Otherwise use the analyzer to determine the canonical name.
            # Per convention we use the first variant as the 'lookup name', the
@ -748,11 +741,10 @@ class ICUNameAnalyzer(AbstractAnalyzer):
                if result[0] is None:
                    variants = analyzer.compute_variants(word_id)
                    if variants:
                        with self.conn.cursor() as cur:
                            hid = cur.scalar("SELECT create_analyzed_hnr_id(%s, %s)",
                        hid = execute_scalar(self.conn, "SELECT create_analyzed_hnr_id(%s, %s)",
                                             (word_id, list(variants)))
                            result = hid, variants[0]
                            self._cache.housenumbers[word_id] = result
                        result = hid, variants[0]
                        self._cache.housenumbers[word_id] = result

        return result

@ -17,11 +17,12 @@ import shutil
from textwrap import dedent

from icu import Transliterator
import psycopg2
import psycopg2.extras
import psycopg
from psycopg import sql as pysql

from ..errors import UsageError
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, drop_tables, table_exists,\
                            execute_scalar, register_hstore
from ..config import Configuration
from ..db import properties
from ..db import utils as db_utils
@ -78,12 +79,12 @@ def _check_module(module_dir: str, conn: Connection) -> None:
    """
    with conn.cursor() as cur:
        try:
            cur.execute("""CREATE FUNCTION nominatim_test_import_func(text)
                           RETURNS text AS %s, 'transliteration'
                           LANGUAGE c IMMUTABLE STRICT;
                           DROP FUNCTION nominatim_test_import_func(text)
                        """, (f'{module_dir}/nominatim.so', ))
        except psycopg2.DatabaseError as err:
            cur.execute(pysql.SQL("""CREATE FUNCTION nominatim_test_import_func(text)
                                     RETURNS text AS {}, 'transliteration'
                                     LANGUAGE c IMMUTABLE STRICT;
                                     DROP FUNCTION nominatim_test_import_func(text)
                                  """).format(pysql.Literal(f'{module_dir}/nominatim.so')))
        except psycopg.DatabaseError as err:
            LOG.fatal("Error accessing database module: %s", err)
            raise UsageError("Database module cannot be accessed.") from err

@ -179,11 +180,10 @@ class LegacyTokenizer(AbstractTokenizer):
              * Can nominatim.so be accessed by the database user?
             """
        with connect(self.dsn) as conn:
            with conn.cursor() as cur:
                try:
                    out = cur.scalar("SELECT make_standard_name('a')")
                except psycopg2.Error as err:
                    return hint.format(error=str(err))
            try:
                out = execute_scalar(conn, "SELECT make_standard_name('a')")
            except psycopg.Error as err:
                return hint.format(error=str(err))

            if out != 'a':
                return hint.format(error='Unexpected result for make_standard_name()')
@ -214,9 +214,9 @@ class LegacyTokenizer(AbstractTokenizer):
        """ Recompute the frequency of full words.
        """
        with connect(self.dsn) as conn:
            if conn.table_exists('search_name'):
            if table_exists(conn, 'search_name'):
                drop_tables(conn, "word_frequencies")
                with conn.cursor() as cur:
                    cur.drop_table("word_frequencies")
                    LOG.info("Computing word frequencies")
                    cur.execute("""CREATE TEMP TABLE word_frequencies AS
                                     SELECT unnest(name_vector) as id, count(*)
@ -226,7 +226,7 @@ class LegacyTokenizer(AbstractTokenizer):
                    cur.execute("""UPDATE word SET search_name_count = count
                                   FROM word_frequencies
                                   WHERE word_token like ' %' and word_id = id""")
                    cur.drop_table("word_frequencies")
                    drop_tables(conn, "word_frequencies")
            conn.commit()

@ -313,10 +313,10 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
    """

    def __init__(self, dsn: str, normalizer: Any):
        self.conn: Optional[Connection] = connect(dsn).connection
        self.conn: Optional[Connection] = connect(dsn)
        self.conn.autocommit = True
        self.normalizer = normalizer
        psycopg2.extras.register_hstore(self.conn)
        register_hstore(self.conn)

        self._cache = _TokenCache(self.conn)

@ -406,7 +406,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                            """, (to_delete, ))
                if to_add:
                    cur.execute("""SELECT count(create_postcode_id(pc))
                                   FROM unnest(%s) as pc
                                   FROM unnest(%s::text[]) as pc
                                """, (to_add, ))


@ -423,7 +423,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
        with self.conn.cursor() as cur:
            # Get the old phrases.
            existing_phrases = set()
            cur.execute("""SELECT word, class, type, operator FROM word
            cur.execute("""SELECT word, class as cls, type, operator FROM word
                           WHERE class != 'place'
                                 OR (type != 'house' AND type != 'postcode')""")
            for label, cls, typ, oper in cur:
@ -433,18 +433,19 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
            to_delete = existing_phrases - norm_phrases

            if to_add:
                cur.execute_values(
                cur.executemany(
                    """ INSERT INTO word (word_id, word_token, word, class, type,
                                          search_name_count, operator)
                        (SELECT nextval('seq_word'), ' ' || make_standard_name(name), name,
                                class, type, 0,
                                CASE WHEN op in ('in', 'near') THEN op ELSE null END
                           FROM (VALUES %s) as v(name, class, type, op))""",
                           FROM (VALUES (%s, %s, %s, %s)) as v(name, class, type, op))""",
                    to_add)

            if to_delete and should_replace:
                cur.execute_values(
                    """ DELETE FROM word USING (VALUES %s) as v(name, in_class, in_type, op)
                cur.executemany(
                    """ DELETE FROM word
                          USING (VALUES (%s, %s, %s, %s)) as v(name, in_class, in_type, op)
                        WHERE word = name and class = in_class and type = in_type
                              and ((op = '-' and operator is null) or op = operator)""",
                    to_delete)
@ -463,7 +464,7 @@ class LegacyNameAnalyzer(AbstractAnalyzer):
                """INSERT INTO word (word_id, word_token, country_code)
                   (SELECT nextval('seq_word'), lookup_token, %s
                      FROM (SELECT DISTINCT ' ' || make_standard_name(n) as lookup_token
                            FROM unnest(%s)n) y
                            FROM unnest(%s::TEXT[])n) y
                      WHERE NOT EXISTS(SELECT * FROM word
                                       WHERE word_token = lookup_token and country_code = %s))
                """, (country_code, list(names.values()), country_code))
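
Several of the statements above also gain explicit ::text[] casts on array parameters. My reading, not stated in the diff itself, is that psycopg 3 binds parameters server-side, so with polymorphic calls such as unnest() an explicit cast keeps the element type unambiguous, whereas psycopg2 interpolated an ARRAY literal client-side and the server could infer the type from it. A hedged sketch (conn assumed to be a psycopg connection):

    with conn.cursor() as cur:
        # psycopg2: client-side interpolation produced ARRAY['a','b'], type inferred
        # cur.execute("SELECT unnest(%s)", (['a', 'b'],))

        # psycopg 3: server-side binding, so spell out the intended array type
        cur.execute("SELECT unnest(%s::text[])", (['a', 'b'],))
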
@ -536,9 +537,8 @@ class _TokenInfo:
    def add_names(self, conn: Connection, names: Mapping[str, str]) -> None:
        """ Add token information for the names of the place.
        """
        with conn.cursor() as cur:
            # Create the token IDs for all names.
            self.data['names'] = cur.scalar("SELECT make_keywords(%s)::text",
        # Create the token IDs for all names.
        self.data['names'] = execute_scalar(conn, "SELECT make_keywords(%s)::text",
                                            (names, ))


@ -576,9 +576,8 @@ class _TokenInfo:
        """ Add addr:street match terms.
        """
        def _get_street(name: str) -> Optional[str]:
            with conn.cursor() as cur:
                return cast(Optional[str],
                            cur.scalar("SELECT word_ids_from_name(%s)::text", (name, )))
            return cast(Optional[str],
                        execute_scalar(conn, "SELECT word_ids_from_name(%s)::text", (name, )))

        tokens = self.cache.streets.get(street, _get_street)
        self.data['street'] = tokens or '{}'
@ -10,12 +10,12 @@ Functions for database analysis and maintenance.
from typing import Optional, Tuple, Any, cast
import logging

from psycopg2.extras import Json, register_hstore
from psycopg2 import DataError
import psycopg
from psycopg.types.json import Json

from ..typing import DictCursorResult
from ..config import Configuration
from ..db.connection import connect, Cursor
from ..db.connection import connect, Cursor, register_hstore
from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory
from ..data.place_info import PlaceInfo
@ -59,7 +59,7 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
    """
    with connect(config.get_libpq_dsn()) as conn:
        register_hstore(conn)
        with conn.cursor() as cur:
        with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
            place = _get_place_info(cur, osm_id, place_id)

            cur.execute("update placex set indexed_status = 2 where place_id = %s",
@ -74,6 +74,9 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,

            tokenizer = tokenizer_factory.get_tokenizer_for_db(config)

            # Enable printing of messages.
            conn.add_notice_handler(lambda diag: print(diag.message_primary))

            with tokenizer.name_analyzer() as analyzer:
                cur.execute("""UPDATE placex
                               SET indexed_status = 0, address = %s, token_info = %s,
@ -86,9 +89,6 @@ def analyse_indexing(config: Configuration, osm_id: Optional[str] = None,
            # we do not want to keep the results
            conn.rollback()

            for msg in conn.notices:
                print(msg)


def clean_deleted_relations(config: Configuration, age: str) -> None:
    """ Clean deleted relations older than a given age
@ -101,6 +101,6 @@ def clean_deleted_relations(config: Configuration, age: str) -> None:
                           WHERE p.osm_type = d.osm_type AND p.osm_id = d.osm_id
                           AND age(p.indexed_date) > %s::interval""",
                        (age, ))
        except DataError as exc:
        except psycopg.DataError as exc:
            raise UsageError('Invalid PostgreSQL time interval format') from exc
    conn.commit()
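
Two psycopg 3 idioms replace their psycopg2 counterparts in this hunk: DictCursor becomes a dict_row row factory, and the conn.notices list becomes a notice-handler callback. A minimal sketch of both, with dsn and the placex query used purely as illustration:

    import psycopg
    from psycopg.rows import dict_row

    with psycopg.connect(dsn) as conn:            # dsn is assumed
        # RAISE NOTICE messages now arrive through a callback instead of conn.notices
        conn.add_notice_handler(lambda diag: print(diag.message_primary))

        with conn.cursor(row_factory=dict_row) as cur:
            cur.execute("SELECT place_id, indexed_status FROM placex LIMIT 1")
            row = cur.fetchone()                  # dict-like: {'place_id': ..., 'indexed_status': ...}
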
@ -12,7 +12,8 @@ from enum import Enum
from textwrap import dedent

from ..config import Configuration
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, server_version_tuple,\
                            index_exists, table_exists, execute_scalar
from ..db import properties
from ..errors import UsageError
from ..tokenizer import factory as tokenizer_factory
@ -80,7 +81,7 @@ def check_database(config: Configuration) -> int:
    """ Run a number of checks on the database and return the status.
    """
    try:
        conn = connect(config.get_libpq_dsn()).connection
        conn = connect(config.get_libpq_dsn())
    except UsageError as err:
        conn = _BadConnection(str(err))  # type: ignore[assignment]

@ -109,14 +110,14 @@ def _get_indexes(conn: Connection) -> List[str]:
               'idx_postcode_id',
               'idx_postcode_postcode'
              ]
    if conn.table_exists('search_name'):
    if table_exists(conn, 'search_name'):
        indexes.extend(('idx_search_name_nameaddress_vector',
                        'idx_search_name_name_vector',
                        'idx_search_name_centroid'))
        if conn.server_version_tuple() >= (11, 0, 0):
        if server_version_tuple(conn) >= (11, 0, 0):
            indexes.extend(('idx_placex_housenumber',
                            'idx_osmline_parent_osm_id_with_hnr'))
    if conn.table_exists('place'):
    if table_exists(conn, 'place'):
        indexes.extend(('idx_location_area_country_place_id',
                        'idx_place_osm_unique',
                        'idx_placex_rank_address_sector',
@ -153,7 +154,7 @@ def check_connection(conn: Any, config: Configuration) -> CheckResult:

             Hints:
             * Are you connecting to the correct database?

             {instruction}

             Check the Migration chapter of the Administration Guide.
@ -165,7 +166,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
    """ Checking database_version matches Nominatim software version
    """

    if conn.table_exists('nominatim_properties'):
    if table_exists(conn, 'nominatim_properties'):
        db_version_str = properties.get_property(conn, 'database_version')
    else:
        db_version_str = None
@ -202,7 +203,7 @@ def check_database_version(conn: Connection, config: Configuration) -> CheckResu
def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
    """ Checking for placex table
    """
    if conn.table_exists('placex'):
    if table_exists(conn, 'placex'):
        return CheckState.OK

    return CheckState.FATAL, dict(config=config)
@ -212,8 +213,7 @@ def check_placex_table(conn: Connection, config: Configuration) -> CheckResult:
def check_placex_size(conn: Connection, _: Configuration) -> CheckResult:
    """ Checking for placex content
    """
    with conn.cursor() as cur:
        cnt = cur.scalar('SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')
    cnt = execute_scalar(conn, 'SELECT count(*) FROM (SELECT * FROM placex LIMIT 100) x')

    return CheckState.OK if cnt > 0 else CheckState.FATAL

@ -244,16 +244,15 @@ def check_tokenizer(_: Connection, config: Configuration) -> CheckResult:
def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult:
    """ Checking for wikipedia/wikidata data
    """
    if not conn.table_exists('search_name') or not conn.table_exists('place'):
    if not table_exists(conn, 'search_name') or not table_exists(conn, 'place'):
        return CheckState.NOT_APPLICABLE

    with conn.cursor() as cur:
        if conn.table_exists('wikimedia_importance'):
            cnt = cur.scalar('SELECT count(*) FROM wikimedia_importance')
        else:
            cnt = cur.scalar('SELECT count(*) FROM wikipedia_article')
    if table_exists(conn, 'wikimedia_importance'):
        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikimedia_importance')
    else:
        cnt = execute_scalar(conn, 'SELECT count(*) FROM wikipedia_article')

        return CheckState.WARN if cnt == 0 else CheckState.OK
    return CheckState.WARN if cnt == 0 else CheckState.OK


@_check(hint="""\
@ -264,8 +263,7 @@ def check_existance_wikipedia(conn: Connection, _: Configuration) -> CheckResult
def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
    """ Checking indexing status
    """
    with conn.cursor() as cur:
        cnt = cur.scalar('SELECT count(*) FROM placex WHERE indexed_status > 0')
    cnt = execute_scalar(conn, 'SELECT count(*) FROM placex WHERE indexed_status > 0')

    if cnt == 0:
        return CheckState.OK
@ -276,7 +274,7 @@ def check_indexing(conn: Connection, _: Configuration) -> CheckResult:
                      Low counts of unindexed places are fine."""
        return CheckState.WARN, dict(count=cnt, index_cmd=index_cmd)

    if conn.index_exists('idx_placex_rank_search'):
    if index_exists(conn, 'idx_placex_rank_search'):
        # Likely just an interrupted update.
        index_cmd = 'nominatim index'
    else:
@ -297,7 +295,7 @@ def check_database_indexes(conn: Connection, _: Configuration) -> CheckResult:
    """
    missing = []
    for index in _get_indexes(conn):
        if not conn.index_exists(index):
        if not index_exists(conn, index):
            missing.append(index)

    if missing:
@ -340,11 +338,10 @@ def check_tiger_table(conn: Connection, config: Configuration) -> CheckResult:
    if not config.get_bool('USE_US_TIGER_DATA'):
        return CheckState.NOT_APPLICABLE

    if not conn.table_exists('location_property_tiger'):
    if not table_exists(conn, 'location_property_tiger'):
        return CheckState.FAIL, dict(error='TIGER data table not found.')

    with conn.cursor() as cur:
        if cur.scalar('SELECT count(*) FROM location_property_tiger') == 0:
            return CheckState.FAIL, dict(error='TIGER data table is empty.')
    if execute_scalar(conn, 'SELECT count(*) FROM location_property_tiger') == 0:
        return CheckState.FAIL, dict(error='TIGER data table is empty.')

    return CheckState.OK
@ -12,21 +12,15 @@ import os
import subprocess
import sys
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Union

import psutil
from psycopg2.extensions import make_dsn, parse_dsn

from ..config import Configuration
from ..db.connection import connect
from ..db.connection import connect, server_version_tuple, execute_scalar
from ..version import NOMINATIM_VERSION


def convert_version(ver_tup: Tuple[int, int]) -> str:
    """converts tuple version (ver_tup) to a string representation"""
    return ".".join(map(str, ver_tup))


def friendly_memory_string(mem: float) -> str:
    """Create a user friendly string for the amount of memory specified as mem"""
    mem_magnitude = ("bytes", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
@ -102,17 +96,17 @@ def report_system_information(config: Configuration) -> None:
    """Generate a report about the host system including software versions, memory,
    storage, and database configuration."""

    with connect(make_dsn(config.get_libpq_dsn(), dbname='postgres')) as conn:
        postgresql_ver: str = convert_version(conn.server_version_tuple())
    with connect(config.get_libpq_dsn(), dbname='postgres') as conn:
        postgresql_ver: str = '.'.join(map(str, server_version_tuple(conn)))

        with conn.cursor() as cur:
            num = cur.scalar("SELECT count(*) FROM pg_catalog.pg_database WHERE datname=%s",
                             (parse_dsn(config.get_libpq_dsn())['dbname'], ))
            nominatim_db_exists = num == 1 if isinstance(num, int) else False
            cur.execute("SELECT datname FROM pg_catalog.pg_database WHERE datname=%s",
                        (config.get_database_params()['dbname'], ))
            nominatim_db_exists = cur.rowcount > 0

    if nominatim_db_exists:
        with connect(config.get_libpq_dsn()) as conn:
            postgis_ver: str = convert_version(conn.postgis_version_tuple())
            postgis_ver: str = execute_scalar(conn, 'SELECT postgis_lib_version()')
    else:
        postgis_ver = "Unable to connect to database"

@ -10,18 +10,20 @@ Functions for setting up and importing a new Nominatim database.
from typing import Tuple, Optional, Union, Sequence, MutableMapping, Any
import logging
import os
import selectors
import subprocess
import asyncio
from pathlib import Path

import psutil
from psycopg2 import sql as pysql
import psycopg
from psycopg import sql as pysql

from ..errors import UsageError
from ..config import Configuration
from ..db.connection import connect, get_pg_env, Connection
from ..db.async_connection import DBConnection
from ..db.connection import connect, get_pg_env, Connection, server_version_tuple,\
                            postgis_version_tuple, drop_tables, table_exists, execute_scalar
from ..db.sql_preprocessor import SQLPreprocessor
from ..db.query_pool import QueryPool
from .exec_utils import run_osm2pgsql
from ..version import POSTGRESQL_REQUIRED_VERSION, POSTGIS_REQUIRED_VERSION

@ -40,19 +42,21 @@ def _require_version(module: str, actual: Tuple[int, int], expected: Tuple[int,

def _require_loaded(extension_name: str, conn: Connection) -> None:
    """ Check that the given extension is loaded. """
    if not conn.extension_loaded(extension_name):
        LOG.fatal('Required module %s is not loaded.', extension_name)
        raise UsageError(f'{extension_name} is not loaded.')
    with conn.cursor() as cur:
        cur.execute('SELECT * FROM pg_extension WHERE extname = %s', (extension_name, ))
        if cur.rowcount <= 0:
            LOG.fatal('Required module %s is not loaded.', extension_name)
            raise UsageError(f'{extension_name} is not loaded.')


def check_existing_database_plugins(dsn: str) -> None:
    """ Check that the database has the required plugins installed."""
    with connect(dsn) as conn:
        _require_version('PostgreSQL server',
                         conn.server_version_tuple(),
                         server_version_tuple(conn),
                         POSTGRESQL_REQUIRED_VERSION)
        _require_version('PostGIS',
                         conn.postgis_version_tuple(),
                         postgis_version_tuple(conn),
                         POSTGIS_REQUIRED_VERSION)
        _require_loaded('hstore', conn)

@ -78,31 +82,30 @@ def setup_database_skeleton(dsn: str, rouser: Optional[str] = None) -> None:

    with connect(dsn) as conn:
        _require_version('PostgreSQL server',
                         conn.server_version_tuple(),
                         server_version_tuple(conn),
                         POSTGRESQL_REQUIRED_VERSION)

        if rouser is not None:
            with conn.cursor() as cur:
                cnt = cur.scalar('SELECT count(*) FROM pg_user where usename = %s',
            cnt = execute_scalar(conn, 'SELECT count(*) FROM pg_user where usename = %s',
                                 (rouser, ))
                if cnt == 0:
                    LOG.fatal("Web user '%s' does not exist. Create it with:\n"
                              "\n      createuser %s", rouser, rouser)
                    raise UsageError('Missing read-only user.')
            if cnt == 0:
                LOG.fatal("Web user '%s' does not exist. Create it with:\n"
                          "\n      createuser %s", rouser, rouser)
                raise UsageError('Missing read-only user.')

        # Create extensions.
        with conn.cursor() as cur:
            cur.execute('CREATE EXTENSION IF NOT EXISTS hstore')
            cur.execute('CREATE EXTENSION IF NOT EXISTS postgis')

            postgis_version = conn.postgis_version_tuple()
            postgis_version = postgis_version_tuple(conn)
            if postgis_version[0] >= 3:
                cur.execute('CREATE EXTENSION IF NOT EXISTS postgis_raster')

        conn.commit()

        _require_version('PostGIS',
                         conn.postgis_version_tuple(),
                         postgis_version_tuple(conn),
                         POSTGIS_REQUIRED_VERSION)

@ -134,12 +137,13 @@ def import_osm_data(osm_files: Union[Path, Sequence[Path]],
        with connect(options['dsn']) as conn:
            if not ignore_errors:
                with conn.cursor() as cur:
                    cur.execute('SELECT * FROM place LIMIT 1')
                    cur.execute('SELECT true FROM place LIMIT 1')
                    if cur.rowcount == 0:
                        raise UsageError('No data imported by osm2pgsql.')

            if drop:
                conn.drop_table('planet_osm_nodes')
                drop_tables(conn, 'planet_osm_nodes')
                conn.commit()

    if drop and options['flatnode_file']:
        Path(options['flatnode_file']).unlink()
@ -182,7 +186,7 @@ def truncate_data_tables(conn: Connection) -> None:
        cur.execute('TRUNCATE location_property_tiger')
        cur.execute('TRUNCATE location_property_osmline')
        cur.execute('TRUNCATE location_postcode')
        if conn.table_exists('search_name'):
        if table_exists(conn, 'search_name'):
            cur.execute('TRUNCATE search_name')
        cur.execute('DROP SEQUENCE IF EXISTS seq_place')
        cur.execute('CREATE SEQUENCE seq_place start 100000')
@ -202,54 +206,51 @@ _COPY_COLUMNS = pysql.SQL(',').join(map(pysql.Identifier,
                                         'extratags', 'geometry')))


def load_data(dsn: str, threads: int) -> None:
async def load_data(dsn: str, threads: int) -> None:
    """ Copy data into the word and placex table.
    """
    sel = selectors.DefaultSelector()
    # Then copy data from place to placex in <threads - 1> chunks.
    place_threads = max(1, threads - 1)
    for imod in range(place_threads):
        conn = DBConnection(dsn)
        conn.connect()
        conn.perform(
            pysql.SQL("""INSERT INTO placex ({columns})
                           SELECT {columns} FROM place
                           WHERE osm_id % {total} = {mod}
                             AND NOT (class='place' and (type='houses' or type='postcode'))
                             AND ST_IsValid(geometry)
                      """).format(columns=_COPY_COLUMNS,
                                  total=pysql.Literal(place_threads),
                                  mod=pysql.Literal(imod)))
        sel.register(conn, selectors.EVENT_READ, conn)
    placex_threads = max(1, threads - 1)

    # Address interpolations go into another table.
    conn = DBConnection(dsn)
    conn.connect()
    conn.perform("""INSERT INTO location_property_osmline (osm_id, address, linegeo)
                      SELECT osm_id, address, geometry FROM place
                      WHERE class='place' and type='houses' and osm_type='W'
                            and ST_GeometryType(geometry) = 'ST_LineString'
                 """)
    sel.register(conn, selectors.EVENT_READ, conn)
    progress = asyncio.create_task(_progress_print())

    # Now wait for all of them to finish.
    todo = place_threads + 1
    while todo > 0:
        for key, _ in sel.select(1):
            conn = key.data
            sel.unregister(conn)
            conn.wait()
            conn.close()
            todo -= 1
    async with QueryPool(dsn, placex_threads + 1) as pool:
        # Copy data from place to placex in <threads - 1> chunks.
        for imod in range(placex_threads):
            await pool.put_query(
                pysql.SQL("""INSERT INTO placex ({columns})
                               SELECT {columns} FROM place
                               WHERE osm_id % {total} = {mod}
                                 AND NOT (class='place'
                                          and (type='houses' or type='postcode'))
                                 AND ST_IsValid(geometry)
                          """).format(columns=_COPY_COLUMNS,
                                      total=pysql.Literal(placex_threads),
                                      mod=pysql.Literal(imod)), None)

        # Interpolations need to be copied seperately
        await pool.put_query("""
            INSERT INTO location_property_osmline (osm_id, address, linegeo)
              SELECT osm_id, address, geometry FROM place
              WHERE class='place' and type='houses' and osm_type='W'
                    and ST_GeometryType(geometry) = 'ST_LineString' """, None)

    progress.cancel()

    async with await psycopg.AsyncConnection.connect(dsn) as aconn:
        await aconn.execute('ANALYSE')


async def _progress_print() -> None:
    while True:
        try:
            await asyncio.sleep(1)
        except asyncio.CancelledError:
            print('', flush=True)
            break
        print('.', end='', flush=True)
    print('\n')

    with connect(dsn) as syn_conn:
        with syn_conn.cursor() as cur:
            cur.execute('ANALYSE')

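
load_data() and the other import steps touched here are now coroutines, so callers have to drive them from an event loop; the Nominatim CLI presumably does this itself, but a standalone invocation would look roughly like the following (DSN and thread count are example values, not taken from the changeset):

    import asyncio

    # hypothetical standalone call; normally the CLI wraps this for you
    asyncio.run(load_data('dbname=nominatim', threads=4))
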
def create_search_indices(conn: Connection, config: Configuration,
async def create_search_indices(conn: Connection, config: Configuration,
                          drop: bool = False, threads: int = 1) -> None:
    """ Create tables that have explicit partitioning.
    """
@ -268,5 +269,5 @@ def create_search_indices(conn: Connection, config: Configuration,

    sql = SQLPreprocessor(conn, config)

    sql.run_parallel_sql_file(config.get_libpq_dsn(),
                              'indices.sql', min(8, threads), drop=drop)
    await sql.run_parallel_sql_file(config.get_libpq_dsn(),
                                    'indices.sql', min(8, threads), drop=drop)
@ -10,9 +10,9 @@ Functions for removing unnecessary data from the database.
from typing import Optional
from pathlib import Path

from psycopg2 import sql as pysql
from psycopg import sql as pysql

from ..db.connection import Connection
from ..db.connection import Connection, drop_tables, table_exists

UPDATE_TABLES = [
    'address_levels',
@ -39,9 +39,7 @@ def drop_update_tables(conn: Connection) -> None:
                    + pysql.SQL(' or ').join(parts))
        tables = [r[0] for r in cur]

        for table in tables:
            cur.drop_table(table, cascade=True)

    drop_tables(conn, *tables, cascade=True)
    conn.commit()


@ -55,4 +53,4 @@ def is_frozen(conn: Connection) -> bool:
    """ Returns true if database is in a frozen state
    """

    return conn.table_exists('place') is False
    return table_exists(conn, 'place') is False
@ -10,12 +10,13 @@ Functions for database migration to newer software versions.
from typing import List, Tuple, Callable, Any
import logging

from psycopg2 import sql as pysql
from psycopg import sql as pysql

from ..errors import UsageError
from ..config import Configuration
from ..db import properties
from ..db.connection import connect, Connection
from ..db.connection import connect, Connection, server_version_tuple,\
                            table_has_column, table_exists, execute_scalar, register_hstore
from ..version import NominatimVersion, NOMINATIM_VERSION, parse_version
from ..tokenizer import factory as tokenizer_factory
from . import refresh
@ -29,7 +30,8 @@ def migrate(config: Configuration, paths: Any) -> int:
        if necesssary.
    """
    with connect(config.get_libpq_dsn()) as conn:
        if conn.table_exists('nominatim_properties'):
        register_hstore(conn)
        if table_exists(conn, 'nominatim_properties'):
            db_version_str = properties.get_property(conn, 'database_version')
        else:
            db_version_str = None
@ -72,16 +74,15 @@ def _guess_version(conn: Connection) -> NominatimVersion:
        Only migrations for 3.6 and later are supported, so bail out
        when the version seems older.
    """
    with conn.cursor() as cur:
        # In version 3.6, the country_name table was updated. Check for that.
        cnt = cur.scalar("""SELECT count(*) FROM
                            (SELECT svals(name) FROM country_name
                             WHERE country_code = 'gb')x;
                         """)
        if cnt < 100:
            LOG.fatal('It looks like your database was imported with a version '
                      'prior to 3.6.0. Automatic migration not possible.')
            raise UsageError('Migration not possible.')
    # In version 3.6, the country_name table was updated. Check for that.
    cnt = execute_scalar(conn, """SELECT count(*) FROM
                                  (SELECT svals(name) FROM country_name
                                   WHERE country_code = 'gb')x;
                               """)
    if cnt < 100:
        LOG.fatal('It looks like your database was imported with a version '
                  'prior to 3.6.0. Automatic migration not possible.')
        raise UsageError('Migration not possible.')

    return NominatimVersion(3, 5, 0, 99)

@ -125,7 +126,7 @@ def import_status_timestamp_change(conn: Connection, **_: Any) -> None:
def add_nominatim_property_table(conn: Connection, config: Configuration, **_: Any) -> None:
    """ Add nominatim_property table.
    """
    if not conn.table_exists('nominatim_properties'):
    if not table_exists(conn, 'nominatim_properties'):
        with conn.cursor() as cur:
            cur.execute(pysql.SQL("""CREATE TABLE nominatim_properties (
                                        property TEXT,
@ -189,13 +190,9 @@ def install_legacy_tokenizer(conn: Connection, config: Configuration, **_: Any)
        configuration for the backwards-compatible legacy tokenizer
    """
    if properties.get_property(conn, 'tokenizer') is None:
        with conn.cursor() as cur:
            for table in ('placex', 'location_property_osmline'):
                has_column = cur.scalar("""SELECT count(*) FROM information_schema.columns
                                           WHERE table_name = %s
                                           and column_name = 'token_info'""",
                                        (table, ))
                if has_column == 0:
        for table in ('placex', 'location_property_osmline'):
            if not table_has_column(conn, table, 'token_info'):
                with conn.cursor() as cur:
                    cur.execute(pysql.SQL('ALTER TABLE {} ADD COLUMN token_info JSONB')
                                .format(pysql.Identifier(table)))
        tokenizer = tokenizer_factory.create_tokenizer(config, init_db=False,
@ -212,7 +209,7 @@ def create_tiger_housenumber_index(conn: Connection, **_: Any) -> None:
|
||||
The inclusion is needed for efficient lookup of housenumbers in
|
||||
full address searches.
|
||||
"""
|
||||
if conn.server_version_tuple() >= (11, 0, 0):
|
||||
if server_version_tuple(conn) >= (11, 0, 0):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(""" CREATE INDEX IF NOT EXISTS
|
||||
idx_location_property_tiger_housenumber_migrated
|
||||
@ -239,7 +236,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
|
||||
Also converts the data into the stricter format which requires that
|
||||
startnumbers comply with the odd/even requirements.
|
||||
"""
|
||||
if conn.table_has_column('location_property_osmline', 'step'):
|
||||
if table_has_column(conn, 'location_property_osmline', 'step'):
|
||||
return
|
||||
|
||||
with conn.cursor() as cur:
|
||||
@ -271,7 +268,7 @@ def add_step_column_for_interpolation(conn: Connection, **_: Any) -> None:
|
||||
def add_step_column_for_tiger(conn: Connection, **_: Any) -> None:
|
||||
""" Add a new column 'step' to the tiger data table.
|
||||
"""
|
||||
if conn.table_has_column('location_property_tiger', 'step'):
|
||||
if table_has_column(conn, 'location_property_tiger', 'step'):
|
||||
return
|
||||
|
||||
with conn.cursor() as cur:
|
||||
@ -287,7 +284,7 @@ def add_derived_name_column_for_country_names(conn: Connection, **_: Any) -> Non
|
||||
""" Add a new column 'derived_name' which in the future takes the
|
||||
country names as imported from OSM data.
|
||||
"""
|
||||
if not conn.table_has_column('country_name', 'derived_name'):
|
||||
if not table_has_column(conn, 'country_name', 'derived_name'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("ALTER TABLE country_name ADD COLUMN derived_name public.HSTORE")
|
||||
|
||||
@ -297,12 +294,9 @@ def mark_internal_country_names(conn: Connection, config: Configuration, **_: An
|
||||
""" Names from the country table should be marked as internal to prevent
|
||||
them from being deleted. Only necessary for ICU tokenizer.
|
||||
"""
|
||||
import psycopg2.extras # pylint: disable=import-outside-toplevel
|
||||
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(config)
|
||||
with tokenizer.name_analyzer() as analyzer:
|
||||
with conn.cursor() as cur:
|
||||
psycopg2.extras.register_hstore(cur)
|
||||
cur.execute("SELECT country_code, name FROM country_name")
|
||||
|
||||
for country_code, names in cur:
|
||||
@ -319,7 +313,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
|
||||
The table is only necessary when updates are possible, i.e.
|
||||
the database is not in freeze mode.
|
||||
"""
|
||||
if conn.table_exists('place'):
|
||||
if table_exists(conn, 'place'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""CREATE TABLE IF NOT EXISTS place_to_be_deleted (
|
||||
osm_type CHAR(1),
|
||||
@ -333,7 +327,7 @@ def add_place_deletion_todo_table(conn: Connection, **_: Any) -> None:
|
||||
def split_pending_index(conn: Connection, **_: Any) -> None:
|
||||
""" Reorganise indexes for pending updates.
|
||||
"""
|
||||
if conn.table_exists('place'):
|
||||
if table_exists(conn, 'place'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""CREATE INDEX IF NOT EXISTS idx_placex_rank_address_sector
|
||||
ON placex USING BTREE (rank_address, geometry_sector)
|
||||
@ -349,7 +343,7 @@ def split_pending_index(conn: Connection, **_: Any) -> None:
|
||||
def enable_forward_dependencies(conn: Connection, **_: Any) -> None:
|
||||
""" Create indexes for updates with forward dependency tracking (long-running).
|
||||
"""
|
||||
if conn.table_exists('planet_osm_ways'):
|
||||
if table_exists(conn, 'planet_osm_ways'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""SELECT * FROM pg_indexes
|
||||
WHERE tablename = 'planet_osm_ways'
|
||||
@ -398,7 +392,7 @@ def create_postcode_area_lookup_index(conn: Connection, **_: Any) -> None:
|
||||
def create_postcode_parent_index(conn: Connection, **_: Any) -> None:
|
||||
""" Create index needed for updating postcodes when a parent changes.
|
||||
"""
|
||||
if conn.table_exists('planet_osm_ways'):
|
||||
if table_exists(conn, 'planet_osm_ways'):
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""CREATE INDEX IF NOT EXISTS
|
||||
idx_location_postcode_parent_place_id
|
||||
|
@ -16,9 +16,9 @@ import gzip
|
||||
import logging
|
||||
from math import isfinite
|
||||
|
||||
from psycopg2 import sql as pysql
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from ..db.connection import connect, Connection
|
||||
from ..db.connection import connect, Connection, table_exists
|
||||
from ..utils.centroid import PointsCentroid
|
||||
from ..data.postcode_format import PostcodeFormatter, CountryPostcodeMatcher
|
||||
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
|
||||
@ -76,30 +76,30 @@ class _PostcodeCollector:
|
||||
|
||||
with conn.cursor() as cur:
|
||||
if to_add:
|
||||
cur.execute_values(
|
||||
cur.executemany(pysql.SQL(
|
||||
"""INSERT INTO location_postcode
|
||||
(place_id, indexed_status, country_code,
|
||||
postcode, geometry) VALUES %s""",
|
||||
to_add,
|
||||
template=pysql.SQL("""(nextval('seq_place'), 1, {},
|
||||
%s, 'SRID=4326;POINT(%s %s)')
|
||||
""").format(pysql.Literal(self.country)))
|
||||
postcode, geometry)
|
||||
VALUES (nextval('seq_place'), 1, {}, %s,
|
||||
ST_SetSRID(ST_MakePoint(%s, %s), 4326))
|
||||
""").format(pysql.Literal(self.country)),
|
||||
to_add)
|
||||
if to_delete:
|
||||
cur.execute("""DELETE FROM location_postcode
|
||||
WHERE country_code = %s and postcode = any(%s)
|
||||
""", (self.country, to_delete))
|
||||
if to_update:
|
||||
cur.execute_values(
|
||||
cur.executemany(
|
||||
pysql.SQL("""UPDATE location_postcode
|
||||
SET indexed_status = 2,
|
||||
geometry = ST_SetSRID(ST_Point(v.x, v.y), 4326)
|
||||
FROM (VALUES %s) AS v (pc, x, y)
|
||||
WHERE country_code = {} and postcode = pc
|
||||
""").format(pysql.Literal(self.country)), to_update)
|
||||
geometry = ST_SetSRID(ST_Point(%s, %s), 4326)
|
||||
WHERE country_code = {} and postcode = %s
|
||||
""").format(pysql.Literal(self.country)),
|
||||
to_update)
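psycopg 3 drops psycopg2.extras.execute_values(), so the port spells out the full VALUES template and passes the row list to cursor.executemany(). A hedged, self-contained sketch of the INSERT variant above (connection string and country code are illustrative only):

    import psycopg
    from psycopg import sql

    rows = [('12345', 8.10, 52.30), ('54321', 9.00, 48.70)]   # (postcode, x, y)

    with psycopg.connect('dbname=nominatim') as conn, conn.cursor() as cur:
        cur.executemany(
            sql.SQL("""INSERT INTO location_postcode
                           (place_id, indexed_status, country_code, postcode, geometry)
                       VALUES (nextval('seq_place'), 1, {}, %s,
                               ST_SetSRID(ST_MakePoint(%s, %s), 4326))"""
                   ).format(sql.Literal('de')),
            rows)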
|
||||
|
||||
|
||||
def _compute_changes(self, conn: Connection) \
|
||||
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[str, float, float]]]:
|
||||
-> Tuple[List[Tuple[str, float, float]], List[str], List[Tuple[float, float, str]]]:
|
||||
""" Compute which postcodes from the collected postcodes have to be
|
||||
added or modified and which from the location_postcode table
|
||||
have to be deleted.
|
||||
@ -116,7 +116,7 @@ class _PostcodeCollector:
|
||||
if pcobj:
|
||||
newx, newy = pcobj.centroid()
|
||||
if (x - newx) > 0.0000001 or (y - newy) > 0.0000001:
|
||||
to_update.append((postcode, newx, newy))
|
||||
to_update.append((newx, newy, postcode))
|
||||
else:
|
||||
to_delete.append(postcode)
|
||||
|
||||
@ -231,4 +231,4 @@ def can_compute(dsn: str) -> bool:
|
||||
postcodes can be computed.
|
||||
"""
|
||||
with connect(dsn) as conn:
|
||||
return conn.table_exists('place')
|
||||
return table_exists(conn, 'place')
|
||||
|
@ -14,11 +14,12 @@ import logging
|
||||
from textwrap import dedent
|
||||
from pathlib import Path
|
||||
|
||||
from psycopg2 import sql as pysql
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from ..config import Configuration
|
||||
from ..db.connection import Connection, connect
|
||||
from ..db.utils import execute_file, CopyBuffer
|
||||
from ..db.connection import Connection, connect, postgis_version_tuple,\
|
||||
drop_tables, table_exists
|
||||
from ..db.utils import execute_file
|
||||
from ..db.sql_preprocessor import SQLPreprocessor
|
||||
from ..version import NOMINATIM_VERSION
|
||||
|
||||
@ -56,9 +57,9 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
|
||||
for entry in levels:
|
||||
_add_address_level_rows_from_entry(rows, entry)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.drop_table(table)
|
||||
drop_tables(conn, table)
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(pysql.SQL("""CREATE TABLE {} (
|
||||
country_code varchar(2),
|
||||
class TEXT,
|
||||
@ -67,8 +68,8 @@ def load_address_levels(conn: Connection, table: str, levels: Sequence[Mapping[s
|
||||
rank_address SMALLINT)
|
||||
""").format(pysql.Identifier(table)))
|
||||
|
||||
cur.execute_values(pysql.SQL("INSERT INTO {} VALUES %s")
|
||||
.format(pysql.Identifier(table)), rows)
|
||||
cur.executemany(pysql.SQL("INSERT INTO {} VALUES (%s, %s, %s, %s, %s)")
|
||||
.format(pysql.Identifier(table)), rows)
|
||||
|
||||
cur.execute(pysql.SQL('CREATE UNIQUE INDEX ON {} (country_code, class, type)')
|
||||
.format(pysql.Identifier(table)))
|
||||
@ -154,15 +155,13 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
|
||||
if not data_file.exists():
|
||||
return 1
|
||||
|
||||
# Only import the first occurence of a wikidata ID.
|
||||
# Only import the first occurrence of a wikidata ID.
|
||||
# This keeps indexes and table small.
|
||||
wd_done = set()
|
||||
|
||||
with connect(dsn) as conn:
|
||||
drop_tables(conn, 'wikipedia_article', 'wikipedia_redirect', 'wikimedia_importance')
|
||||
with conn.cursor() as cur:
|
||||
cur.drop_table('wikipedia_article')
|
||||
cur.drop_table('wikipedia_redirect')
|
||||
cur.drop_table('wikimedia_importance')
|
||||
cur.execute("""CREATE TABLE wikimedia_importance (
|
||||
language TEXT NOT NULL,
|
||||
title TEXT NOT NULL,
|
||||
@ -170,24 +169,17 @@ def import_importance_csv(dsn: str, data_file: Path) -> int:
|
||||
wikidata TEXT
|
||||
) """)
|
||||
|
||||
with gzip.open(str(data_file), 'rt') as fd, CopyBuffer() as buf:
|
||||
for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
|
||||
wd_id = int(row['wikidata_id'][1:])
|
||||
buf.add(row['language'], row['title'], row['importance'],
|
||||
None if wd_id in wd_done else row['wikidata_id'])
|
||||
wd_done.add(wd_id)
|
||||
copy_cmd = """COPY wikimedia_importance(language, title, importance, wikidata)
|
||||
FROM STDIN"""
|
||||
with gzip.open(str(data_file), 'rt') as fd, cur.copy(copy_cmd) as copy:
|
||||
for row in csv.DictReader(fd, delimiter='\t', quotechar='|'):
|
||||
wd_id = int(row['wikidata_id'][1:])
|
||||
copy.write_row((row['language'],
|
||||
row['title'],
|
||||
row['importance'],
|
||||
None if wd_id in wd_done else row['wikidata_id']))
|
||||
wd_done.add(wd_id)
|
||||
|
||||
if buf.size() > 10000000:
|
||||
with conn.cursor() as cur:
|
||||
buf.copy_out(cur, 'wikimedia_importance',
|
||||
columns=['language', 'title', 'importance',
|
||||
'wikidata'])
|
||||
|
||||
with conn.cursor() as cur:
|
||||
buf.copy_out(cur, 'wikimedia_importance',
|
||||
columns=['language', 'title', 'importance', 'wikidata'])
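The CopyBuffer batching above is replaced by psycopg 3's native COPY support: cursor.copy() opens a COPY ... FROM STDIN stream and write_row() takes care of escaping and NULLs. A minimal sketch, assuming an open psycopg connection conn and the table created in this hunk:

    with conn.cursor() as cur:
        with cur.copy("""COPY wikimedia_importance (language, title, importance, wikidata)
                         FROM STDIN""") as copy:
            copy.write_row(('en', 'Example title', 0.42, None))   # None is written as NULL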
|
||||
|
||||
with conn.cursor() as cur:
|
||||
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_title
|
||||
ON wikimedia_importance (title)""")
|
||||
cur.execute("""CREATE INDEX IF NOT EXISTS idx_wikimedia_importance_wikidata
|
||||
@ -228,7 +220,7 @@ def import_secondary_importance(dsn: str, data_path: Path, ignore_errors: bool =
|
||||
return 1
|
||||
|
||||
with connect(dsn) as conn:
|
||||
postgis_version = conn.postgis_version_tuple()
|
||||
postgis_version = postgis_version_tuple(conn)
|
||||
if postgis_version[0] < 3:
|
||||
LOG.error('PostGIS version is too old for using OSM raster data.')
|
||||
return 2
|
||||
@ -309,7 +301,7 @@ def setup_website(basedir: Path, config: Configuration, conn: Connection) -> Non
|
||||
|
||||
template = "\nrequire_once(CONST_LibDir.'/website/{}');\n"
|
||||
|
||||
search_name_table_exists = bool(conn and conn.table_exists('search_name'))
|
||||
search_name_table_exists = bool(conn and table_exists(conn, 'search_name'))
|
||||
|
||||
for script in WEBSITE_SCRIPTS:
|
||||
if not search_name_table_exists and script == 'search.php':
|
||||
|
@ -20,7 +20,7 @@ import requests
|
||||
|
||||
from ..errors import UsageError
|
||||
from ..db import status
|
||||
from ..db.connection import Connection, connect
|
||||
from ..db.connection import Connection, connect, server_version_tuple
|
||||
from .exec_utils import run_osm2pgsql
|
||||
|
||||
try:
|
||||
@ -155,7 +155,7 @@ def run_osm2pgsql_updates(conn: Connection, options: MutableMapping[str, Any]) -
|
||||
|
||||
# Consume updates with osm2pgsql.
|
||||
options['append'] = True
|
||||
options['disable_jit'] = conn.server_version_tuple() >= (11, 0)
|
||||
options['disable_jit'] = server_version_tuple(conn) >= (11, 0)
|
||||
run_osm2pgsql(options)
|
||||
|
||||
# Handle deletions
|
||||
|
@ -17,11 +17,11 @@ from typing import Iterable, Tuple, Mapping, Sequence, Optional, Set
|
||||
import logging
|
||||
import re
|
||||
|
||||
from psycopg2.sql import Identifier, SQL
|
||||
from psycopg.sql import Identifier, SQL
|
||||
|
||||
from ...typing import Protocol
|
||||
from ...config import Configuration
|
||||
from ...db.connection import Connection
|
||||
from ...db.connection import Connection, drop_tables, index_exists
|
||||
from .importer_statistics import SpecialPhrasesImporterStatistics
|
||||
from .special_phrase import SpecialPhrase
|
||||
from ...tokenizer.base import AbstractTokenizer
|
||||
@ -233,7 +233,7 @@ class SPImporter():
|
||||
index_prefix = f'idx_place_classtype_{phrase_class}_{phrase_type}_'
|
||||
base_table = _classtype_table(phrase_class, phrase_type)
|
||||
# Index on centroid
|
||||
if not self.db_connection.index_exists(index_prefix + 'centroid'):
|
||||
if not index_exists(self.db_connection, index_prefix + 'centroid'):
|
||||
with self.db_connection.cursor() as db_cursor:
|
||||
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING GIST (centroid) {}")
|
||||
.format(Identifier(index_prefix + 'centroid'),
|
||||
@ -241,7 +241,7 @@ class SPImporter():
|
||||
SQL(sql_tablespace)))
|
||||
|
||||
# Index on place_id
|
||||
if not self.db_connection.index_exists(index_prefix + 'place_id'):
|
||||
if not index_exists(self.db_connection, index_prefix + 'place_id'):
|
||||
with self.db_connection.cursor() as db_cursor:
|
||||
db_cursor.execute(SQL("CREATE INDEX {} ON {} USING btree(place_id) {}")
|
||||
.format(Identifier(index_prefix + 'place_id'),
|
||||
@ -259,6 +259,7 @@ class SPImporter():
|
||||
.format(Identifier(table_name),
|
||||
Identifier(self.config.DATABASE_WEBUSER)))
|
||||
|
||||
|
||||
def _remove_non_existent_tables_from_db(self) -> None:
|
||||
"""
|
||||
Remove special phrases which no longer exist on the wiki.
|
||||
@ -268,7 +269,6 @@ class SPImporter():
|
||||
|
||||
# Delete place_classtype tables corresponding to class/type which
|
||||
# are not on the wiki anymore.
|
||||
with self.db_connection.cursor() as db_cursor:
|
||||
for table in self.table_phrases_to_delete:
|
||||
self.statistics_handler.notify_one_table_deleted()
|
||||
db_cursor.drop_table(table)
|
||||
drop_tables(self.db_connection, *self.table_phrases_to_delete)
|
||||
for _ in self.table_phrases_to_delete:
|
||||
self.statistics_handler.notify_one_table_deleted()
|
||||
|
@ -7,22 +7,22 @@
|
||||
"""
|
||||
Functions for importing tiger data and handling tarball and directory files
|
||||
"""
|
||||
from typing import Any, TextIO, List, Union, cast
|
||||
from typing import Any, TextIO, List, Union, cast, Iterator, Dict
|
||||
import csv
|
||||
import io
|
||||
import logging
|
||||
import os
|
||||
import tarfile
|
||||
|
||||
from psycopg2.extras import Json
|
||||
from psycopg.types.json import Json
|
||||
|
||||
from ..config import Configuration
|
||||
from ..db.connection import connect
|
||||
from ..db.async_connection import WorkerPool
|
||||
from ..db.sql_preprocessor import SQLPreprocessor
|
||||
from ..errors import UsageError
|
||||
from ..db.query_pool import QueryPool
|
||||
from ..data.place_info import PlaceInfo
|
||||
from ..tokenizer.base import AbstractAnalyzer, AbstractTokenizer
|
||||
from ..tokenizer.base import AbstractTokenizer
|
||||
from . import freeze
|
||||
|
||||
LOG = logging.getLogger()
|
||||
@ -63,13 +63,13 @@ class TigerInput:
|
||||
self.tar_handle.close()
|
||||
self.tar_handle = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.files)
|
||||
|
||||
def next_file(self) -> TextIO:
|
||||
def get_file(self, fname: Union[str, tarfile.TarInfo]) -> TextIO:
|
||||
""" Return a file handle to the next file to be processed.
|
||||
Raises an IndexError if there is no file left.
|
||||
"""
|
||||
fname = self.files.pop(0)
|
||||
|
||||
if self.tar_handle is not None:
|
||||
extracted = self.tar_handle.extractfile(fname)
|
||||
assert extracted is not None
|
||||
@ -78,47 +78,22 @@ class TigerInput:
|
||||
return open(cast(str, fname), encoding='utf-8')
|
||||
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.files)
|
||||
def __iter__(self) -> Iterator[Dict[str, Any]]:
|
||||
""" Iterate over the lines in each file.
|
||||
"""
|
||||
for fname in self.files:
|
||||
fd = self.get_file(fname)
|
||||
yield from csv.DictReader(fd, delimiter=';')
|
||||
|
||||
|
||||
def handle_threaded_sql_statements(pool: WorkerPool, fd: TextIO,
|
||||
analyzer: AbstractAnalyzer) -> None:
|
||||
""" Handles sql statement with multiplexing
|
||||
"""
|
||||
lines = 0
|
||||
# Using pool of database connections to execute sql statements
|
||||
|
||||
sql = "SELECT tiger_line_import(%s, %s, %s, %s, %s, %s)"
|
||||
|
||||
for row in csv.DictReader(fd, delimiter=';'):
|
||||
try:
|
||||
address = dict(street=row['street'], postcode=row['postcode'])
|
||||
args = ('SRID=4326;' + row['geometry'],
|
||||
int(row['from']), int(row['to']), row['interpolation'],
|
||||
Json(analyzer.process_place(PlaceInfo({'address': address}))),
|
||||
analyzer.normalize_postcode(row['postcode']))
|
||||
except ValueError:
|
||||
continue
|
||||
pool.next_free_worker().perform(sql, args=args)
|
||||
|
||||
lines += 1
|
||||
if lines == 1000:
|
||||
print('.', end='', flush=True)
|
||||
lines = 0
|
||||
|
||||
|
||||
def add_tiger_data(data_dir: str, config: Configuration, threads: int,
|
||||
async def add_tiger_data(data_dir: str, config: Configuration, threads: int,
|
||||
tokenizer: AbstractTokenizer) -> int:
|
||||
""" Import tiger data from directory or tar file `data dir`.
|
||||
"""
|
||||
dsn = config.get_libpq_dsn()
|
||||
|
||||
with connect(dsn) as conn:
|
||||
is_frozen = freeze.is_frozen(conn)
|
||||
conn.close()
|
||||
|
||||
if is_frozen:
|
||||
if freeze.is_frozen(conn):
|
||||
raise UsageError("Tiger cannot be imported when database frozen (Github issue #3048)")
|
||||
|
||||
with TigerInput(data_dir) as tar:
|
||||
@ -133,13 +108,30 @@ def add_tiger_data(data_dir: str, config: Configuration, threads: int,
|
||||
# sql_query in <threads - 1> chunks.
|
||||
place_threads = max(1, threads - 1)
|
||||
|
||||
with WorkerPool(dsn, place_threads, ignore_sql_errors=True) as pool:
|
||||
async with QueryPool(dsn, place_threads, autocommit=True) as pool:
|
||||
with tokenizer.name_analyzer() as analyzer:
|
||||
while tar:
|
||||
with tar.next_file() as fd:
|
||||
handle_threaded_sql_statements(pool, fd, analyzer)
|
||||
lines = 0
|
||||
for row in tar:
|
||||
try:
|
||||
address = dict(street=row['street'], postcode=row['postcode'])
|
||||
args = ('SRID=4326;' + row['geometry'],
|
||||
int(row['from']), int(row['to']), row['interpolation'],
|
||||
Json(analyzer.process_place(PlaceInfo({'address': address}))),
|
||||
analyzer.normalize_postcode(row['postcode']))
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
print('\n')
|
||||
await pool.put_query(
|
||||
"""SELECT tiger_line_import(%s::GEOMETRY, %s::INT,
|
||||
%s::INT, %s::TEXT, %s::JSONB, %s::TEXT)""",
|
||||
args)
|
||||
|
||||
lines += 1
|
||||
if lines == 1000:
|
||||
print('.', end='', flush=True)
|
||||
lines = 0
|
||||
|
||||
print('', flush=True)
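add_tiger_data() is now a coroutine, so its caller has to drive it from an event loop; presumably the CLI wraps it roughly like the sketch below (the data path is illustrative, config and tokenizer come from the surrounding command code):

    import asyncio

    from nominatim_db.tools.tiger_data import add_tiger_data

    # config: Configuration, tokenizer: the database's tokenizer, both supplied by the CLI.
    asyncio.run(add_tiger_data('tiger-data.tar.gz', config, 2, tokenizer))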
|
||||
|
||||
LOG.warning("Creating indexes on Tiger data")
|
||||
with connect(dsn) as conn:
|
||||
|
@ -16,18 +16,13 @@ from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING
|
||||
# pylint: disable=missing-class-docstring,useless-import-alias
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import psycopg2.sql
|
||||
import psycopg2.extensions
|
||||
import psycopg2.extras
|
||||
import os
|
||||
|
||||
StrPath = Union[str, 'os.PathLike[str]']
|
||||
|
||||
SysEnv = Mapping[str, str]
|
||||
|
||||
# psycopg2-related types
|
||||
|
||||
Query = Union[str, bytes, 'psycopg2.sql.Composable']
|
||||
# psycopg-related types
|
||||
|
||||
T_ResultKey = TypeVar('T_ResultKey', int, str)
|
||||
|
||||
@ -36,8 +31,6 @@ class DictCursorResult(Mapping[str, Any]):
|
||||
|
||||
DictCursorResults = Sequence[DictCursorResult]
|
||||
|
||||
T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
|
||||
|
||||
# The following typing features require typing_extensions to work
|
||||
# on all supported Python versions.
|
||||
# Only require this for type checking but not for normal operations.
|
||||
|
@ -31,7 +31,7 @@ class NominatimVersion(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
patch_level: int
|
||||
db_patch_level: Optional[int]
|
||||
db_patch_level: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.db_patch_level is None:
|
||||
@ -47,6 +47,7 @@ class NominatimVersion(NamedTuple):
|
||||
return f"{self.major}.{self.minor}.{self.patch_level}"
|
||||
|
||||
|
||||
|
||||
def parse_version(version: str) -> NominatimVersion:
|
||||
""" Parse a version string into a version consisting of a tuple of
|
||||
four ints: major, minor, patch level, database patch level
|
||||
|
@ -9,14 +9,14 @@ import importlib
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import psycopg
|
||||
from psycopg import sql as pysql
|
||||
|
||||
sys.path.insert(1, str((Path(__file__) / '..' / '..' / '..' / '..'/ 'src').resolve()))
|
||||
|
||||
from nominatim_db import cli
|
||||
from nominatim_db.config import Configuration
|
||||
from nominatim_db.db.connection import Connection
|
||||
from nominatim_db.db.connection import Connection, register_hstore, execute_scalar
|
||||
from nominatim_db.tools import refresh
|
||||
from nominatim_db.tokenizer import factory as tokenizer_factory
|
||||
from steps.utils import run_script
|
||||
@ -60,7 +60,7 @@ class NominatimEnvironment:
|
||||
""" Return a connection to the database with the given name.
|
||||
Uses configured host, user and port.
|
||||
"""
|
||||
dbargs = {'database': dbname}
|
||||
dbargs = {'dbname': dbname, 'row_factory': psycopg.rows.dict_row}
|
||||
if self.db_host:
|
||||
dbargs['host'] = self.db_host
|
||||
if self.db_port:
|
||||
@ -69,8 +69,7 @@ class NominatimEnvironment:
|
||||
dbargs['user'] = self.db_user
|
||||
if self.db_pass:
|
||||
dbargs['password'] = self.db_pass
|
||||
conn = psycopg2.connect(connection_factory=Connection, **dbargs)
|
||||
return conn
|
||||
return psycopg.connect(**dbargs)
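With psycopg 3 the BDD helper connects directly and picks a row factory, so query results arrive as dictionaries without a special cursor class. A short illustration of the connection keywords used here (database name is illustrative):

    import psycopg

    conn = psycopg.connect(dbname='test_api_nominatim', row_factory=psycopg.rows.dict_row)
    with conn.cursor() as cur:
        cur.execute('SELECT place_id, osm_type FROM placex LIMIT 1')
        row = cur.fetchone()   # e.g. {'place_id': 1, 'osm_type': 'N'}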
|
||||
|
||||
def next_code_coverage_file(self):
|
||||
""" Generate the next name for a coverage file.
|
||||
@ -132,6 +131,8 @@ class NominatimEnvironment:
|
||||
conn = False
|
||||
refresh.setup_website(Path(self.website_dir.name) / 'website',
|
||||
self.get_test_config(), conn)
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
|
||||
def get_test_config(self):
|
||||
@ -160,11 +161,10 @@ class NominatimEnvironment:
|
||||
def db_drop_database(self, name):
|
||||
""" Drop the database with the given name.
|
||||
"""
|
||||
conn = self.connect_database('postgres')
|
||||
conn.set_isolation_level(0)
|
||||
cur = conn.cursor()
|
||||
cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
|
||||
conn.close()
|
||||
with self.connect_database('postgres') as conn:
|
||||
conn.autocommit = True
|
||||
conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
|
||||
+ pysql.Identifier(name))
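DROP DATABASE and CREATE DATABASE cannot take the database name as a bound parameter, so the statements are composed with psycopg.sql, keeping the identifier safely quoted. The same idea written with SQL.format(), as a minimal sketch assuming an autocommit connection conn:

    from psycopg import sql

    def drop_database(conn, name):
        # DROP DATABASE must not run inside a transaction block.
        conn.execute(sql.SQL('DROP DATABASE IF EXISTS {}').format(sql.Identifier(name)))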
|
||||
|
||||
def setup_template_db(self):
|
||||
""" Setup a template database that already contains common test data.
|
||||
@ -249,16 +249,18 @@ class NominatimEnvironment:
|
||||
""" Setup a test against a fresh, empty test database.
|
||||
"""
|
||||
self.setup_template_db()
|
||||
conn = self.connect_database(self.template_db)
|
||||
conn.set_isolation_level(0)
|
||||
cur = conn.cursor()
|
||||
cur.execute('DROP DATABASE IF EXISTS {}'.format(self.test_db))
|
||||
cur.execute('CREATE DATABASE {} TEMPLATE = {}'.format(self.test_db, self.template_db))
|
||||
conn.close()
|
||||
with self.connect_database(self.template_db) as conn:
|
||||
conn.autocommit = True
|
||||
conn.execute(pysql.SQL('DROP DATABASE IF EXISTS')
|
||||
+ pysql.Identifier(self.test_db))
|
||||
conn.execute(pysql.SQL('CREATE DATABASE {} TEMPLATE = {}').format(
|
||||
pysql.Identifier(self.test_db),
|
||||
pysql.Identifier(self.template_db)))
|
||||
|
||||
self.write_nominatim_config(self.test_db)
|
||||
context.db = self.connect_database(self.test_db)
|
||||
context.db.autocommit = True
|
||||
psycopg2.extras.register_hstore(context.db, globally=False)
|
||||
register_hstore(context.db)
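psycopg 3 handles hstore through its adapter registry; Nominatim wraps that in its own register_hstore() helper. Under the hood the registration presumably looks roughly like the stock psycopg 3 API below (a sketch, not Nominatim's code):

    from psycopg.types import TypeInfo
    from psycopg.types.hstore import register_hstore

    info = TypeInfo.fetch(conn, 'hstore')    # requires the hstore extension in the database
    if info is not None:
        register_hstore(info, conn)          # dicts now adapt to and from hstore values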
|
||||
|
||||
def teardown_db(self, context, force_drop=False):
|
||||
""" Remove the test database, if it exists.
|
||||
@ -276,31 +278,26 @@ class NominatimEnvironment:
|
||||
dropped and always false returned.
|
||||
"""
|
||||
if self.reuse_template:
|
||||
conn = self.connect_database('postgres')
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('select count(*) from pg_database where datname = %s',
|
||||
(name,))
|
||||
if cur.fetchone()[0] == 1:
|
||||
with self.connect_database('postgres') as conn:
|
||||
num = execute_scalar(conn,
|
||||
'select count(*) from pg_database where datname = %s',
|
||||
(name,))
|
||||
if num == 1:
|
||||
return True
|
||||
conn.close()
|
||||
else:
|
||||
self.db_drop_database(name)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def reindex_placex(self, db):
|
||||
""" Run the indexing step until all data in the placex has
|
||||
been processed. Indexing during updates can produce more data
|
||||
to index under some circumstances. That is why indexing may have
|
||||
to be run multiple times.
|
||||
"""
|
||||
with db.cursor() as cur:
|
||||
while True:
|
||||
self.run_nominatim('index')
|
||||
self.run_nominatim('index')
|
||||
|
||||
cur.execute("SELECT 'a' FROM placex WHERE indexed_status != 0 LIMIT 1")
|
||||
if cur.rowcount == 0:
|
||||
return
|
||||
|
||||
def run_nominatim(self, *cmdline):
|
||||
""" Run the nominatim command-line tool via the library.
|
||||
|
@ -7,7 +7,8 @@
|
||||
import logging
|
||||
from itertools import chain
|
||||
|
||||
import psycopg2.extras
|
||||
import psycopg
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from place_inserter import PlaceColumn
|
||||
from table_compare import NominatimID, DBRow
|
||||
@ -18,7 +19,7 @@ from nominatim_db.tokenizer import factory as tokenizer_factory
|
||||
def check_database_integrity(context):
|
||||
""" Check some generic constraints on the tables.
|
||||
"""
|
||||
with context.db.cursor() as cur:
|
||||
with context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
# place_addressline should not have duplicate (place_id, address_place_id)
|
||||
cur.execute("""SELECT count(*) FROM
|
||||
(SELECT place_id, address_place_id, count(*) as c
|
||||
@ -54,7 +55,7 @@ def add_data_to_planet_relations(context):
|
||||
with context.db.cursor() as cur:
|
||||
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
|
||||
row = cur.fetchone()
|
||||
if row is None or row[0] == '1':
|
||||
if row is None or row['value'] == '1':
|
||||
for r in context.table:
|
||||
last_node = 0
|
||||
last_way = 0
|
||||
@ -96,8 +97,8 @@ def add_data_to_planet_relations(context):
|
||||
|
||||
cur.execute("""INSERT INTO planet_osm_rels (id, tags, members)
|
||||
VALUES (%s, %s, %s)""",
|
||||
(r['id'], psycopg2.extras.Json(tags),
|
||||
psycopg2.extras.Json(members)))
|
||||
(r['id'], psycopg.types.json.Json(tags),
|
||||
psycopg.types.json.Json(members)))
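JSON parameters move from psycopg2.extras.Json to psycopg.types.json.Json; wrapping a Python object marks it for JSON/JSONB adaptation when bound. A minimal sketch against the table used above, assuming an open connection conn:

    from psycopg.types.json import Json

    with conn.cursor() as cur:
        cur.execute("INSERT INTO planet_osm_rels (id, tags, members) VALUES (%s, %s, %s)",
                    (1, Json({'type': 'multipolygon'}),
                     Json([{'ref': 23, 'role': 'outer'}])))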
|
||||
|
||||
@given("the ways")
|
||||
def add_data_to_planet_ways(context):
|
||||
@ -107,10 +108,10 @@ def add_data_to_planet_ways(context):
|
||||
with context.db.cursor() as cur:
|
||||
cur.execute("SELECT value FROM osm2pgsql_properties WHERE property = 'db_format'")
|
||||
row = cur.fetchone()
|
||||
json_tags = row is not None and row[0] != '1'
|
||||
json_tags = row is not None and row['value'] != '1'
|
||||
for r in context.table:
|
||||
if json_tags:
|
||||
tags = psycopg2.extras.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
|
||||
tags = psycopg.types.json.Json({h[5:]: r[h] for h in r.headings if h.startswith("tags+")})
|
||||
else:
|
||||
tags = list(chain.from_iterable([(h[5:], r[h])
|
||||
for h in r.headings if h.startswith("tags+")]))
|
||||
@ -197,7 +198,7 @@ def check_place_contents(context, table, exact):
|
||||
expected rows are expected to be present with at least one database row.
|
||||
When 'exactly' is given, there must not be additional rows in the database.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
expected_content = set()
|
||||
for row in context.table:
|
||||
nid = NominatimID(row['object'])
|
||||
@ -215,8 +216,9 @@ def check_place_contents(context, table, exact):
|
||||
DBRow(nid, res, context).assert_row(row, ['object'])
|
||||
|
||||
if exact:
|
||||
cur.execute('SELECT osm_type, osm_id, class from {}'.format(table))
|
||||
actual = set([(r[0], r[1], r[2]) for r in cur])
|
||||
cur.execute(pysql.SQL('SELECT osm_type, osm_id, class from')
|
||||
+ pysql.Identifier(table))
|
||||
actual = set([(r['osm_type'], r['osm_id'], r['class']) for r in cur])
|
||||
assert expected_content == actual, \
|
||||
f"Missing entries: {expected_content - actual}\n" \
|
||||
f"Not expected in table: {actual - expected_content}"
|
||||
@ -227,7 +229,7 @@ def check_place_has_entry(context, table, oid):
|
||||
""" Ensure that no database row for the given object exists. The ID
|
||||
must be of the form '<NRW><osm id>[:<class>]'.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
NominatimID(oid).query_osm_id(cur, "SELECT * FROM %s where {}" % table)
|
||||
assert cur.rowcount == 0, \
|
||||
"Found {} entries for ID {}".format(cur.rowcount, oid)
|
||||
@ -244,7 +246,7 @@ def check_search_name_contents(context, exclude):
|
||||
tokenizer = tokenizer_factory.get_tokenizer_for_db(context.nominatim.get_test_config())
|
||||
|
||||
with tokenizer.name_analyzer() as analyzer:
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
for row in context.table:
|
||||
nid = NominatimID(row['object'])
|
||||
nid.row_by_place_id(cur, 'search_name',
|
||||
@ -276,7 +278,7 @@ def check_search_name_has_entry(context, oid):
|
||||
""" Check that there is noentry in the search_name table for the given
|
||||
objects. IDs are in format '<NRW><osm id>[:<class>]'.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
NominatimID(oid).row_by_place_id(cur, 'search_name')
|
||||
|
||||
assert cur.rowcount == 0, \
|
||||
@ -290,7 +292,7 @@ def check_location_postcode(context):
|
||||
All rows must be present as expected and there must not be additional
|
||||
rows.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
cur.execute("SELECT *, ST_AsText(geometry) as geomtxt FROM location_postcode")
|
||||
assert cur.rowcount == len(list(context.table)), \
|
||||
"Postcode table has {} rows, expected {}.".format(cur.rowcount, len(list(context.table)))
|
||||
@ -321,7 +323,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
|
||||
|
||||
plist.sort()
|
||||
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
if nctx.tokenizer != 'legacy':
|
||||
cur.execute("SELECT word FROM word WHERE type = 'P' and word = any(%s)",
|
||||
(plist,))
|
||||
@ -330,7 +332,7 @@ def check_word_table_for_postcodes(context, exclude, postcodes):
|
||||
and class = 'place' and type = 'postcode'""",
|
||||
(plist,))
|
||||
|
||||
found = [row[0] for row in cur]
|
||||
found = [row['word'] for row in cur]
|
||||
assert len(found) == len(set(found)), f"Duplicate rows for postcodes: {found}"
|
||||
|
||||
if exclude:
|
||||
@ -347,7 +349,7 @@ def check_place_addressline(context):
|
||||
representing the addressee and the 'address' column, representing the
|
||||
address item.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
for row in context.table:
|
||||
nid = NominatimID(row['object'])
|
||||
pid = nid.get_place_id(cur)
|
||||
@ -366,7 +368,7 @@ def check_place_addressline_exclude(context):
|
||||
""" Check that the place_addressline doesn't contain any entries for the
|
||||
given addressee/address item pairs.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
for row in context.table:
|
||||
pid = NominatimID(row['object']).get_place_id(cur)
|
||||
apid = NominatimID(row['address']).get_place_id(cur, allow_empty=True)
|
||||
@ -381,7 +383,7 @@ def check_place_addressline_exclude(context):
|
||||
def check_location_property_osmline(context, oid, neg):
|
||||
""" Check that the given way is present in the interpolation table.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
cur.execute("""SELECT *, ST_AsText(linegeo) as geomtxt
|
||||
FROM location_property_osmline
|
||||
WHERE osm_id = %s AND startnumber IS NOT NULL""",
|
||||
@ -417,7 +419,7 @@ def check_place_contents(context, exact):
|
||||
expected rows are expected to be present with at least one database row.
|
||||
When 'exactly' is given, there must not be additional rows in the database.
|
||||
"""
|
||||
with context.db.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
|
||||
with context.db.cursor() as cur:
|
||||
expected_content = set()
|
||||
for row in context.table:
|
||||
if ':' in row['object']:
|
||||
@ -447,7 +449,7 @@ def check_place_contents(context, exact):
|
||||
|
||||
if exact:
|
||||
cur.execute('SELECT osm_id, startnumber from location_property_osmline')
|
||||
actual = set([(r[0], r[1]) for r in cur])
|
||||
actual = set([(r['osm_id'], r['startnumber']) for r in cur])
|
||||
assert expected_content == actual, \
|
||||
f"Missing entries: {expected_content - actual}\n" \
|
||||
f"Not expected in table: {actual - expected_content}"
|
||||
|
@ -10,6 +10,9 @@ Functions to facilitate accessing and comparing the content of DB tables.
|
||||
import re
|
||||
import json
|
||||
|
||||
import psycopg
|
||||
from psycopg import sql as pysql
|
||||
|
||||
from steps.check_functions import Almost
|
||||
|
||||
ID_REGEX = re.compile(r"(?P<typ>[NRW])(?P<oid>\d+)(:(?P<cls>\w+))?")
|
||||
@ -73,7 +76,7 @@ class NominatimID:
|
||||
assert cur.rowcount == 1, \
|
||||
"Place ID {!s} not unique. Found {} entries.".format(self, cur.rowcount)
|
||||
|
||||
return cur.fetchone()[0]
|
||||
return cur.fetchone()['place_id']
|
||||
|
||||
|
||||
class DBRow:
|
||||
@ -152,9 +155,10 @@ class DBRow:
|
||||
|
||||
def _has_centroid(self, expected):
|
||||
if expected == 'in geometry':
|
||||
with self.context.db.cursor() as cur:
|
||||
cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point({cx}, {cy}), 4326),
|
||||
ST_SetSRID('{geomtxt}'::geometry, 4326))""".format(**self.db_row))
|
||||
with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute("""SELECT ST_Within(ST_SetSRID(ST_Point(%(cx)s, %(cy)s), 4326),
|
||||
ST_SetSRID(%(geomtxt)s::geometry, 4326))""",
|
||||
(self.db_row))
|
||||
return cur.fetchone()[0]
|
||||
|
||||
if ' ' in expected:
|
||||
@ -166,10 +170,11 @@ class DBRow:
|
||||
|
||||
def _has_geometry(self, expected):
|
||||
geom = self.context.osm.parse_geometry(expected)
|
||||
with self.context.db.cursor() as cur:
|
||||
cur.execute("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
|
||||
ST_SnapToGrid(ST_SetSRID('{}'::geometry, 4326), 0.00001, 0.00001))""".format(
|
||||
geom, self.db_row['geomtxt']))
|
||||
with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute(pysql.SQL("""SELECT ST_Equals(ST_SnapToGrid({}, 0.00001, 0.00001),
|
||||
ST_SnapToGrid(ST_SetSRID({}::geometry, 4326), 0.00001, 0.00001))""")
|
||||
.format(pysql.SQL(geom),
|
||||
pysql.Literal(self.db_row['geomtxt'])))
|
||||
return cur.fetchone()[0]
|
||||
|
||||
def assert_msg(self, name, value):
|
||||
@ -209,7 +214,7 @@ class DBRow:
|
||||
if actual == 0:
|
||||
return "place ID 0"
|
||||
|
||||
with self.context.db.cursor() as cur:
|
||||
with self.context.db.cursor(row_factory=psycopg.rows.tuple_row) as cur:
|
||||
cur.execute("""SELECT osm_type, osm_id, class
|
||||
FROM placex WHERE place_id = %s""",
|
||||
(actual, ))
|
||||
|
@ -13,8 +13,6 @@ from pathlib import Path
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import psycopg2.extras
|
||||
|
||||
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
|
||||
|
||||
import nominatim_api.v1.server_glue as glue
|
||||
@ -31,7 +29,6 @@ class TestDeletableEndPoint:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
|
||||
psycopg2.extras.register_hstore(temp_db_cursor)
|
||||
table_factory('import_polygon_delete',
|
||||
definition='osm_id bigint, osm_type char(1), class text, type text',
|
||||
content=[(345, 'N', 'boundary', 'administrative'),
|
||||
|
@ -14,8 +14,6 @@ from pathlib import Path
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
import psycopg2.extras
|
||||
|
||||
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
|
||||
|
||||
import nominatim_api.v1.server_glue as glue
|
||||
@ -32,8 +30,6 @@ class TestPolygonsEndPoint:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_deletable_table(self, temp_db_cursor, table_factory, temp_db_with_extensions):
|
||||
psycopg2.extras.register_hstore(temp_db_cursor)
|
||||
|
||||
self.now = dt.datetime.now()
|
||||
self.recent = dt.datetime.now() - dt.timedelta(days=3)
|
||||
|
||||
|
@ -25,6 +25,23 @@ class MockParamCapture:
|
||||
return self.return_value
|
||||
|
||||
|
||||
class AsyncMockParamCapture:
|
||||
""" Mock that records the parameters with which a function was called
|
||||
as well as the number of calls.
|
||||
"""
|
||||
def __init__(self, retval=0):
|
||||
self.called = 0
|
||||
self.return_value = retval
|
||||
self.last_args = None
|
||||
self.last_kwargs = None
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
self.called += 1
|
||||
self.last_args = args
|
||||
self.last_kwargs = kwargs
|
||||
return self.return_value
|
||||
|
||||
|
||||
class DummyTokenizer:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.update_sql_functions_called = False
|
||||
@ -69,6 +86,17 @@ def mock_func_factory(monkeypatch):
|
||||
return get_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def async_mock_func_factory(monkeypatch):
|
||||
def get_mock(module, func):
|
||||
mock = AsyncMockParamCapture()
|
||||
mock.func_name = func
|
||||
monkeypatch.setattr(module, func, mock)
|
||||
return mock
|
||||
|
||||
return get_mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cli_tokenizer_mock(monkeypatch):
|
||||
tok = DummyTokenizer()
|
||||
|
@ -17,6 +17,7 @@ import pytest
|
||||
import nominatim_db.indexer.indexer
|
||||
import nominatim_db.tools.add_osm_data
|
||||
import nominatim_db.tools.freeze
|
||||
import nominatim_db.tools.tiger_data
|
||||
|
||||
|
||||
def test_cli_help(cli_call, capsys):
|
||||
@ -52,8 +53,8 @@ def test_cli_add_data_object_command(cli_call, mock_func_factory, name, oid):
|
||||
|
||||
|
||||
|
||||
def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, mock_func_factory):
|
||||
mock = mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data')
|
||||
def test_cli_add_data_tiger_data(cli_call, cli_tokenizer_mock, async_mock_func_factory):
|
||||
mock = async_mock_func_factory(nominatim_db.tools.tiger_data, 'add_tiger_data')
|
||||
|
||||
assert cli_call('add-data', '--tiger-data', 'somewhere') == 0
|
||||
|
||||
@ -68,38 +69,6 @@ def test_cli_serve_php(cli_call, mock_func_factory):
|
||||
assert func.called == 1
|
||||
|
||||
|
||||
def test_cli_serve_starlette_custom_server(cli_call, mock_func_factory):
|
||||
pytest.importorskip("starlette")
|
||||
mod = pytest.importorskip("uvicorn")
|
||||
func = mock_func_factory(mod, "run")
|
||||
|
||||
cli_call('serve', '--engine', 'starlette', '--server', 'foobar:4545') == 0
|
||||
|
||||
assert func.called == 1
|
||||
assert func.last_kwargs['host'] == 'foobar'
|
||||
assert func.last_kwargs['port'] == 4545
|
||||
|
||||
|
||||
def test_cli_serve_starlette_custom_server_bad_port(cli_call, mock_func_factory):
|
||||
pytest.importorskip("starlette")
|
||||
mod = pytest.importorskip("uvicorn")
|
||||
func = mock_func_factory(mod, "run")
|
||||
|
||||
cli_call('serve', '--engine', 'starlette', '--server', 'foobar:45:45') == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("engine", ['falcon', 'starlette'])
|
||||
def test_cli_serve_uvicorn_based(cli_call, engine, mock_func_factory):
|
||||
pytest.importorskip(engine)
|
||||
mod = pytest.importorskip("uvicorn")
|
||||
func = mock_func_factory(mod, "run")
|
||||
|
||||
cli_call('serve', '--engine', engine) == 0
|
||||
|
||||
assert func.called == 1
|
||||
assert func.last_kwargs['host'] == '127.0.0.1'
|
||||
assert func.last_kwargs['port'] == 8088
|
||||
|
||||
|
||||
class TestCliWithDb:
|
||||
|
||||
@ -120,16 +89,19 @@ class TestCliWithDb:
|
||||
|
||||
|
||||
@pytest.mark.parametrize("params,do_bnds,do_ranks", [
|
||||
([], 1, 1),
|
||||
(['--boundaries-only'], 1, 0),
|
||||
(['--no-boundaries'], 0, 1),
|
||||
([], 2, 2),
|
||||
(['--boundaries-only'], 2, 0),
|
||||
(['--no-boundaries'], 0, 2),
|
||||
(['--boundaries-only', '--no-boundaries'], 0, 0)])
|
||||
def test_index_command(self, mock_func_factory, table_factory,
|
||||
def test_index_command(self, monkeypatch, async_mock_func_factory, table_factory,
|
||||
params, do_bnds, do_ranks):
|
||||
table_factory('import_status', 'indexed bool')
|
||||
bnd_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries')
|
||||
rank_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank')
|
||||
postcode_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
|
||||
bnd_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_boundaries')
|
||||
rank_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_by_rank')
|
||||
postcode_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
|
||||
|
||||
monkeypatch.setattr(nominatim_db.indexer.indexer.Indexer, 'has_pending',
|
||||
[False, True].pop)
|
||||
|
||||
assert self.call_nominatim('index', *params) == 0
|
||||
|
||||
|
@ -34,7 +34,8 @@ class TestCliImportWithDb:
|
||||
|
||||
|
||||
@pytest.mark.parametrize('with_updates', [True, False])
|
||||
def test_import_full(self, mock_func_factory, with_updates, place_table, property_table):
|
||||
def test_import_full(self, mock_func_factory, async_mock_func_factory,
|
||||
with_updates, place_table, property_table):
|
||||
mocks = [
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'setup_database_skeleton'),
|
||||
mock_func_factory(nominatim_db.data.country_info, 'setup_country_tables'),
|
||||
@ -42,15 +43,15 @@ class TestCliImportWithDb:
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'import_wikipedia_articles'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'import_secondary_importance'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_tables'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_table_triggers'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_partition_tables'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'load_address_levels_from_config'),
|
||||
mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
|
||||
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
|
||||
]
|
||||
|
||||
@ -72,14 +73,14 @@ class TestCliImportWithDb:
|
||||
assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
|
||||
|
||||
|
||||
def test_import_continue_load_data(self, mock_func_factory):
|
||||
def test_import_continue_load_data(self, mock_func_factory, async_mock_func_factory):
|
||||
mocks = [
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'truncate_data_tables'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'load_data'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
|
||||
mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes'),
|
||||
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
|
||||
mock_func_factory(nominatim_db.db.properties, 'set_property')
|
||||
]
|
||||
@ -91,12 +92,12 @@ class TestCliImportWithDb:
|
||||
assert mock.called == 1, "Mock '{}' not called".format(mock.func_name)
|
||||
|
||||
|
||||
def test_import_continue_indexing(self, mock_func_factory, placex_table,
|
||||
temp_db_conn):
|
||||
def test_import_continue_indexing(self, mock_func_factory, async_mock_func_factory,
|
||||
placex_table, temp_db_conn):
|
||||
mocks = [
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
|
||||
mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
|
||||
mock_func_factory(nominatim_db.db.properties, 'set_property')
|
||||
]
|
||||
@ -110,9 +111,9 @@ class TestCliImportWithDb:
|
||||
assert self.call_nominatim('import', '--continue', 'indexing') == 0
|
||||
|
||||
|
||||
def test_import_continue_postprocess(self, mock_func_factory):
|
||||
def test_import_continue_postprocess(self, mock_func_factory, async_mock_func_factory):
|
||||
mocks = [
|
||||
mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
async_mock_func_factory(nominatim_db.tools.database_import, 'create_search_indices'),
|
||||
mock_func_factory(nominatim_db.data.country_info, 'create_country_names'),
|
||||
mock_func_factory(nominatim_db.tools.refresh, 'setup_website'),
|
||||
mock_func_factory(nominatim_db.db.properties, 'set_property')
|
||||
|
@ -45,9 +45,9 @@ class TestRefresh:
|
||||
assert self.tokenizer_mock.update_word_tokens_called
|
||||
|
||||
|
||||
def test_refresh_postcodes(self, mock_func_factory, place_table):
|
||||
def test_refresh_postcodes(self, async_mock_func_factory, mock_func_factory, place_table):
|
||||
func_mock = mock_func_factory(nominatim_db.tools.postcodes, 'update_postcodes')
|
||||
idx_mock = mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
|
||||
idx_mock = async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_postcodes')
|
||||
|
||||
assert self.call_nominatim('refresh', '--postcodes') == 0
|
||||
assert func_mock.called == 1
|
||||
|
@ -47,8 +47,8 @@ def init_status(temp_db_conn, status_table):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def index_mock(mock_func_factory, tokenizer_mock, init_status):
|
||||
return mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full')
|
||||
def index_mock(async_mock_func_factory, tokenizer_mock, init_status):
|
||||
return async_mock_func_factory(nominatim_db.indexer.indexer.Indexer, 'index_full')
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -8,7 +8,8 @@ import itertools
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import psycopg2
|
||||
import psycopg
|
||||
from psycopg import sql as pysql
|
||||
import pytest
|
||||
|
||||
# always test against the source
|
||||
@ -36,26 +37,23 @@ def temp_db(monkeypatch):
|
||||
exported into NOMINATIM_DATABASE_DSN.
|
||||
"""
|
||||
name = 'test_nominatim_python_unittest'
|
||||
conn = psycopg2.connect(database='postgres')
|
||||
|
||||
conn.set_isolation_level(0)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
|
||||
cur.execute('CREATE DATABASE {}'.format(name))
|
||||
|
||||
conn.close()
|
||||
with psycopg.connect(dbname='postgres', autocommit=True) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(pysql.SQL('DROP DATABASE IF EXISTS') + pysql.Identifier(name))
|
||||
cur.execute(pysql.SQL('CREATE DATABASE') + pysql.Identifier(name))
|
||||
|
||||
monkeypatch.setenv('NOMINATIM_DATABASE_DSN', 'dbname=' + name)
|
||||
|
||||
with psycopg.connect(dbname=name) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('CREATE EXTENSION hstore')
|
||||
|
||||
yield name
|
||||
|
||||
conn = psycopg2.connect(database='postgres')
|
||||
|
||||
conn.set_isolation_level(0)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
|
||||
|
||||
conn.close()
|
||||
with psycopg.connect(dbname='postgres', autocommit=True) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('DROP DATABASE IF EXISTS {}'.format(name))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -65,11 +63,9 @@ def dsn(temp_db):
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_with_extensions(temp_db):
|
||||
conn = psycopg2.connect(database=temp_db)
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('CREATE EXTENSION hstore; CREATE EXTENSION postgis;')
|
||||
conn.commit()
|
||||
conn.close()
|
||||
with psycopg.connect(dbname=temp_db) as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute('CREATE EXTENSION postgis')
|
||||
|
||||
return temp_db
|
||||
|
||||
@ -77,7 +73,8 @@ def temp_db_with_extensions(temp_db):
|
||||
def temp_db_conn(temp_db):
|
||||
""" Connection to the test database.
|
||||
"""
|
||||
with connection.connect('dbname=' + temp_db) as conn:
|
||||
with connection.connect('', autocommit=True, dbname=temp_db) as conn:
|
||||
connection.register_hstore(conn)
|
||||
yield conn
|
||||
|
||||
|
||||
@ -86,22 +83,25 @@ def temp_db_cursor(temp_db):
|
||||
""" Connection and cursor towards the test database. The connection will
|
||||
be in auto-commit mode.
|
||||
"""
|
||||
conn = psycopg2.connect('dbname=' + temp_db)
|
||||
conn.set_isolation_level(0)
|
||||
with conn.cursor(cursor_factory=CursorForTesting) as cur:
|
||||
yield cur
|
||||
conn.close()
|
||||
with psycopg.connect(dbname=temp_db, autocommit=True, cursor_factory=CursorForTesting) as conn:
|
||||
connection.register_hstore(conn)
|
||||
with conn.cursor() as cur:
|
||||
yield cur
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def table_factory(temp_db_cursor):
|
||||
def table_factory(temp_db_conn):
|
||||
""" A fixture that creates new SQL tables, potentially filled with
|
||||
content.
|
||||
"""
|
||||
def mk_table(name, definition='id INT', content=None):
|
||||
temp_db_cursor.execute('CREATE TABLE {} ({})'.format(name, definition))
|
||||
if content is not None:
|
||||
temp_db_cursor.execute_values("INSERT INTO {} VALUES %s".format(name), content)
|
||||
with psycopg.ClientCursor(temp_db_conn) as cur:
|
||||
cur.execute('CREATE TABLE {} ({})'.format(name, definition))
|
||||
if content:
|
||||
sql = pysql.SQL("INSERT INTO {} VALUES ({})")\
|
||||
.format(pysql.Identifier(name),
|
||||
pysql.SQL(',').join([pysql.Placeholder() for _ in range(len(content[0]))]))
|
||||
cur.executemany(sql , content)
|
||||
|
||||
return mk_table
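Because executemany() needs one placeholder per column, the fixture builds the VALUES list dynamically with psycopg.sql.Placeholder, sized to the first content row. The same pattern in isolation (table name and rows are illustrative):

    from psycopg import sql

    rows = [(1, 'a'), (2, 'b')]
    stmt = sql.SQL('INSERT INTO {} VALUES ({})').format(
        sql.Identifier('foobar'),
        sql.SQL(',').join([sql.Placeholder() for _ in rows[0]]))
    cur.executemany(stmt, rows)    # cur: any open psycopg cursor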
|
||||
|
||||
@ -168,7 +168,6 @@ def place_row(place_table, temp_db_cursor):
|
||||
""" A factory for rows in the place table. The table is created as a
|
||||
prerequisite to the fixture.
|
||||
"""
|
||||
psycopg2.extras.register_hstore(temp_db_cursor)
|
||||
idseq = itertools.count(1001)
|
||||
def _insert(osm_type='N', osm_id=None, cls='amenity', typ='cafe', names=None,
|
||||
admin_level=None, address=None, extratags=None, geom=None):
|
||||
|
@ -5,11 +5,11 @@
|
||||
# Copyright (C) 2024 by the Nominatim developer community.
|
||||
# For a full list of authors see the git log.
|
||||
"""
|
||||
Specialised psycopg2 cursor with shortcut functions useful for testing.
|
||||
Specialised psycopg cursor with shortcut functions useful for testing.
|
||||
"""
|
||||
import psycopg2.extras
|
||||
import psycopg
|
||||
|
||||
class CursorForTesting(psycopg2.extras.DictCursor):
|
||||
class CursorForTesting(psycopg.Cursor):
|
||||
""" Extension to the DictCursor class that provides execution
|
||||
short-cuts that simplify writing assertions.
|
||||
"""
|
||||
@ -59,9 +59,3 @@ class CursorForTesting(psycopg2.extras.DictCursor):
|
||||
return self.scalar('SELECT count(*) FROM ' + table)
|
||||
|
||||
return self.scalar('SELECT count(*) FROM {} WHERE {}'.format(table, where))
|
||||
|
||||
|
||||
def execute_values(self, *args, **kwargs):
|
||||
""" Execute the execute_values() function on the cursor.
|
||||
"""
|
||||
psycopg2.extras.execute_values(self, *args, **kwargs)
|
||||
|
@ -1,113 +0,0 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2024 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for function providing a non-blocking query interface towards PostgreSQL.
"""
from contextlib import closing
import concurrent.futures

import pytest
import psycopg2

from nominatim_db.db.async_connection import DBConnection, DeadlockHandler


@pytest.fixture
def conn(temp_db):
with closing(DBConnection('dbname=' + temp_db)) as connection:
yield connection


@pytest.fixture
def simple_conns(temp_db):
conn1 = psycopg2.connect('dbname=' + temp_db)
conn2 = psycopg2.connect('dbname=' + temp_db)

yield conn1.cursor(), conn2.cursor()

conn1.close()
conn2.close()


def test_simple_query(conn, temp_db_conn):
conn.connect()

conn.perform('CREATE TABLE foo (id INT)')
conn.wait()

temp_db_conn.table_exists('foo')


def test_wait_for_query(conn):
conn.connect()

conn.perform('SELECT pg_sleep(1)')

assert not conn.is_done()

conn.wait()


def test_bad_query(conn):
conn.connect()

conn.perform('SELECT efasfjsea')

with pytest.raises(psycopg2.ProgrammingError):
conn.wait()


def test_bad_query_ignore(temp_db):
with closing(DBConnection('dbname=' + temp_db, ignore_sql_errors=True)) as conn:
conn.connect()

conn.perform('SELECT efasfjsea')

conn.wait()


def exec_with_deadlock(cur, sql, detector):
with DeadlockHandler(lambda *args: detector.append(1)):
cur.execute(sql)


def test_deadlock(simple_conns):
cur1, cur2 = simple_conns

cur1.execute("""CREATE TABLE t1 (id INT PRIMARY KEY, t TEXT);
INSERT into t1 VALUES (1, 'a'), (2, 'b')""")
cur1.connection.commit()

cur1.execute("UPDATE t1 SET t = 'x' WHERE id = 1")
cur2.execute("UPDATE t1 SET t = 'x' WHERE id = 2")

# This is the tricky part of the test. The first SQL command runs into
# a lock and blocks, so we have to run it in a separate thread. When the
# second deadlocking SQL statement is issued, Postgresql will abort one of
# the two transactions that cause the deadlock. There is no way to tell
# which one of the two. Therefore wrap both in a DeadlockHandler and
# expect that exactly one of the two triggers.
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
deadlock_check = []
try:
future = executor.submit(exec_with_deadlock, cur2,
"UPDATE t1 SET t = 'y' WHERE id = 1",
deadlock_check)

while not future.running():
pass

exec_with_deadlock(cur1, "UPDATE t1 SET t = 'y' WHERE id = 2",
deadlock_check)
finally:
# Whatever happens, make sure the deadlock gets resolved.
cur1.connection.rollback()

future.result()

assert len(deadlock_check) == 1
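The whole test module for the old psycopg2-based non-blocking DBConnection wrapper is deleted here; after the port, asynchronous database access can lean on psycopg3's native asyncio support. A rough sketch of the equivalent pattern, with a made-up DSN:

```python
# Sketch only: native async query execution with psycopg3, replacing the
# hand-rolled non-blocking DBConnection wrapper built on psycopg2.
import asyncio
import psycopg

async def run() -> None:
    async with await psycopg.AsyncConnection.connect('dbname=test') as conn:
        async with conn.cursor() as cur:
            await cur.execute('SELECT pg_sleep(1)')   # does not block the event loop

asyncio.run(run())
```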
@ -8,63 +8,76 @@
Tests for specialised connection and cursor classes.
"""
import pytest
import psycopg2
import psycopg

from nominatim_db.db.connection import connect, get_pg_env
import nominatim_db.db.connection as nc

@pytest.fixture
def db(dsn):
with connect(dsn) as conn:
with nc.connect(dsn) as conn:
yield conn


def test_connection_table_exists(db, table_factory):
assert not db.table_exists('foobar')
assert not nc.table_exists(db, 'foobar')

table_factory('foobar')

assert db.table_exists('foobar')
assert nc.table_exists(db, 'foobar')


def test_has_column_no_table(db):
assert not db.table_has_column('sometable', 'somecolumn')
assert not nc.table_has_column(db, 'sometable', 'somecolumn')


@pytest.mark.parametrize('name,result', [('tram', True), ('car', False)])
def test_has_column(db, table_factory, name, result):
table_factory('stuff', 'tram TEXT')

assert db.table_has_column('stuff', name) == result
assert nc.table_has_column(db, 'stuff', name) == result

def test_connection_index_exists(db, table_factory, temp_db_cursor):
assert not db.index_exists('some_index')
assert not nc.index_exists(db, 'some_index')

table_factory('foobar')
temp_db_cursor.execute('CREATE INDEX some_index ON foobar(id)')

assert db.index_exists('some_index')
assert db.index_exists('some_index', table='foobar')
assert not db.index_exists('some_index', table='bar')
assert nc.index_exists(db, 'some_index')
assert nc.index_exists(db, 'some_index', table='foobar')
assert not nc.index_exists(db, 'some_index', table='bar')


def test_drop_table_existing(db, table_factory):
table_factory('dummy')
assert db.table_exists('dummy')
assert nc.table_exists(db, 'dummy')

db.drop_table('dummy')
assert not db.table_exists('dummy')
nc.drop_tables(db, 'dummy')
assert not nc.table_exists(db, 'dummy')


def test_drop_table_non_existsing(db):
db.drop_table('dfkjgjriogjigjgjrdghehtre')
def test_drop_table_non_existing(db):
nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre')


def test_drop_many_tables(db, table_factory):
tables = [f'table{n}' for n in range(5)]

for t in tables:
table_factory(t)
assert nc.table_exists(db, t)

nc.drop_tables(db, *tables)

for t in tables:
assert not nc.table_exists(db, t)


def test_drop_table_non_existing_force(db):
with pytest.raises(psycopg2.ProgrammingError, match='.*does not exist.*'):
db.drop_table('dfkjgjriogjigjgjrdghehtre', if_exists=False)
with pytest.raises(psycopg.ProgrammingError, match='.*does not exist.*'):
nc.drop_tables(db, 'dfkjgjriogjigjgjrdghehtre', if_exists=False)

def test_connection_server_version_tuple(db):
ver = db.server_version_tuple()
ver = nc.server_version_tuple(db)

assert isinstance(ver, tuple)
assert len(ver) == 2
@ -72,7 +85,7 @@ def test_connection_server_version_tuple(db):


def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
ver = db.postgis_version_tuple()
ver = nc.postgis_version_tuple(db)

assert isinstance(ver, tuple)
assert len(ver) == 2
@ -82,27 +95,24 @@ def test_connection_postgis_version_tuple(db, temp_db_with_extensions):
def test_cursor_scalar(db, table_factory):
table_factory('dummy')

with db.cursor() as cur:
assert cur.scalar('SELECT count(*) FROM dummy') == 0
assert nc.execute_scalar(db, 'SELECT count(*) FROM dummy') == 0


def test_cursor_scalar_many_rows(db):
with db.cursor() as cur:
with pytest.raises(RuntimeError):
cur.scalar('SELECT * FROM pg_tables')
with pytest.raises(RuntimeError, match='Query did not return a single row.'):
nc.execute_scalar(db, 'SELECT * FROM pg_tables')


def test_cursor_scalar_no_rows(db, table_factory):
table_factory('dummy')

with db.cursor() as cur:
with pytest.raises(RuntimeError):
cur.scalar('SELECT id FROM dummy')
with pytest.raises(RuntimeError, match='Query did not return a single row.'):
nc.execute_scalar(db, 'SELECT id FROM dummy')


def test_get_pg_env_add_variable(monkeypatch):
monkeypatch.delenv('PGPASSWORD', raising=False)
env = get_pg_env('user=fooF')
env = nc.get_pg_env('user=fooF')

assert env['PGUSER'] == 'fooF'
assert 'PGPASSWORD' not in env
@ -110,12 +120,12 @@ def test_get_pg_env_add_variable(monkeypatch):

def test_get_pg_env_overwrite_variable(monkeypatch):
monkeypatch.setenv('PGUSER', 'some default')
env = get_pg_env('user=overwriter')
env = nc.get_pg_env('user=overwriter')

assert env['PGUSER'] == 'overwriter'


def test_get_pg_env_ignore_unknown():
env = get_pg_env('client_encoding=stuff', base_env={})
env = nc.get_pg_env('client_encoding=stuff', base_env={})

assert env == {}
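These tests exercise the new module-level helpers (nc.table_exists(), nc.execute_scalar(), nc.drop_tables(), ...) that take the connection as their first argument instead of being methods on a connection subclass. A hedged usage sketch; the helper names come from the diff above, everything else is illustrative:

```python
# Sketch only: calling the function-style connection helpers shown above.
import nominatim_db.db.connection as nc

with nc.connect('dbname=test') as conn:   # assumed to yield a psycopg connection
    if not nc.table_exists(conn, 'placex'):
        print('placex is missing')
    count = nc.execute_scalar(conn, 'SELECT count(*) FROM pg_tables')
    print('tables:', count)
```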
@ -8,6 +8,7 @@
Tests for SQL preprocessing.
"""
import pytest
import pytest_asyncio

from nominatim_db.db.sql_preprocessor import SQLPreprocessor

@ -54,3 +55,17 @@ def test_load_file_with_params(sql_preprocessor, sql_factory, temp_db_conn, temp
sql_preprocessor.run_sql_file(temp_db_conn, sqlfile, bar='XX', foo='ZZ')

assert temp_db_cursor.scalar('SELECT test()') == 'ZZ XX'


@pytest.mark.asyncio
async def test_load_parallel_file(dsn, sql_preprocessor, tmp_path, temp_db_cursor):
(tmp_path / 'test.sql').write_text("""
CREATE TABLE foo (a TEXT);
CREATE TABLE foo2(a TEXT);""" +
"\n---\nCREATE TABLE bar (b INT);")

await sql_preprocessor.run_parallel_sql_file(dsn, 'test.sql', num_threads=4)

assert temp_db_cursor.table_exists('foo')
assert temp_db_cursor.table_exists('foo2')
assert temp_db_cursor.table_exists('bar')
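The new test drives run_parallel_sql_file(), which splits the file on `---` separators and runs the parts on several connections. The function body below is an assumption made for illustration, not the real implementation; it only shows the general idea of fanning the chunks out over async psycopg3 connections:

```python
# Sketch only: splitting an SQL file on '---' separators and running the parts
# concurrently on independent psycopg3 async connections.
import asyncio
import psycopg

async def run_parallel(dsn: str, sql_text: str, num_workers: int = 4) -> None:
    chunks = [part.strip() for part in sql_text.split('\n---\n') if part.strip()]
    sem = asyncio.Semaphore(num_workers)   # cap concurrency at num_workers

    async def run_chunk(chunk: str) -> None:
        async with sem:
            async with await psycopg.AsyncConnection.connect(dsn, autocommit=True) as conn:
                await conn.execute(chunk)

    await asyncio.gather(*(run_chunk(c) for c in chunks))
```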
@ -58,103 +58,3 @@ def test_execute_file_with_post_code(dsn, tmp_path, temp_db_cursor):
db_utils.execute_file(dsn, tmpfile, post_code='INSERT INTO test VALUES(23)')

assert temp_db_cursor.row_set('SELECT * FROM test') == {(23, )}


class TestCopyBuffer:
TABLE_NAME = 'copytable'

@pytest.fixture(autouse=True)
def setup_test_table(self, table_factory):
table_factory(self.TABLE_NAME, 'col_a INT, col_b TEXT')


def table_rows(self, cursor):
return cursor.row_set('SELECT * FROM ' + self.TABLE_NAME)


def test_copybuffer_empty(self):
with db_utils.CopyBuffer() as buf:
buf.copy_out(None, "dummy")


def test_all_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(3, 'hum')
buf.add(None, 'f\\t')

buf.copy_out(temp_db_cursor, self.TABLE_NAME)

assert self.table_rows(temp_db_cursor) == {(3, 'hum'), (None, 'f\\t')}


def test_selected_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('foo')

buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b'])

assert self.table_rows(temp_db_cursor) == {(None, 'foo')}


def test_reordered_columns(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('one', 1)
buf.add(' two ', 2)

buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b', 'col_a'])

assert self.table_rows(temp_db_cursor) == {(1, 'one'), (2, ' two ')}


def test_special_characters(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add('foo\tbar')
buf.add('sun\nson')
buf.add('\\N')

buf.copy_out(temp_db_cursor, self.TABLE_NAME,
columns=['col_b'])

assert self.table_rows(temp_db_cursor) == {(None, 'foo\tbar'),
(None, 'sun\nson'),
(None, '\\N')}



class TestCopyBufferJson:
TABLE_NAME = 'copytable'

@pytest.fixture(autouse=True)
def setup_test_table(self, table_factory):
table_factory(self.TABLE_NAME, 'col_a INT, col_b JSONB')


def table_rows(self, cursor):
cursor.execute('SELECT * FROM ' + self.TABLE_NAME)
results = {k: v for k,v in cursor}

assert len(results) == cursor.rowcount

return results


def test_json_object(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(1, json.dumps({'test': 'value', 'number': 1}))

buf.copy_out(temp_db_cursor, self.TABLE_NAME)

assert self.table_rows(temp_db_cursor) == \
{1: {'test': 'value', 'number': 1}}


def test_json_object_special_chras(self, temp_db_cursor):
with db_utils.CopyBuffer() as buf:
buf.add(1, json.dumps({'te\tst': 'va\nlue', 'nu"mber': None}))

buf.copy_out(temp_db_cursor, self.TABLE_NAME)

assert self.table_rows(temp_db_cursor) == \
{1: {'te\tst': 'va\nlue', 'nu"mber': None}}
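The CopyBuffer tests go away together with the helper itself: psycopg3 ships a first-class COPY interface that handles the escaping of tabs, newlines and NULL markers which CopyBuffer used to do by hand. A minimal sketch with an illustrative table name:

```python
# Sketch only: psycopg3's built-in COPY support in place of the removed CopyBuffer.
import psycopg

with psycopg.connect('dbname=test') as conn, conn.cursor() as cur:
    cur.execute('CREATE TABLE IF NOT EXISTS copytable (col_a INT, col_b TEXT)')
    with cur.copy('COPY copytable (col_b) FROM STDIN') as copy:
        # Values containing tabs, newlines or backslashes are escaped by the driver.
        for value in ('foo\tbar', 'sun\nson', '\\N'):
            copy.write_row((value,))
```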
@ -9,6 +9,7 @@ Tests for running the indexing.
"""
import itertools
import pytest
import pytest_asyncio

from nominatim_db.indexer import indexer
from nominatim_db.tokenizer import factory
@ -21,9 +22,8 @@ class IndexerTestDB:
self.postcode_id = itertools.count(700000)

self.conn = conn
self.conn.set_isolation_level(0)
self.conn.autocommit = True
with self.conn.cursor() as cur:
cur.execute('CREATE EXTENSION hstore')
cur.execute("""CREATE TABLE placex (place_id BIGINT,
name HSTORE,
class TEXT,
@ -156,7 +156,8 @@ def test_tokenizer(tokenizer_mock, project_env):


@pytest.mark.parametrize("threads", [1, 15])
def test_index_all_by_rank(test_db, threads, test_tokenizer):
@pytest.mark.asyncio
async def test_index_all_by_rank(test_db, threads, test_tokenizer):
for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline()
@ -165,7 +166,7 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(0, 30)
await idx.index_by_rank(0, 30)

assert test_db.placex_unindexed() == 0
assert test_db.osmline_unindexed() == 0
@ -190,7 +191,8 @@ def test_index_all_by_rank(test_db, threads, test_tokenizer):


@pytest.mark.parametrize("threads", [1, 15])
def test_index_partial_without_30(test_db, threads, test_tokenizer):
@pytest.mark.asyncio
async def test_index_partial_without_30(test_db, threads, test_tokenizer):
for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline()
@ -200,7 +202,7 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):

idx = indexer.Indexer('dbname=test_nominatim_python_unittest',
test_tokenizer, threads)
idx.index_by_rank(4, 15)
await idx.index_by_rank(4, 15)

assert test_db.placex_unindexed() == 19
assert test_db.osmline_unindexed() == 1
@ -211,7 +213,8 @@ def test_index_partial_without_30(test_db, threads, test_tokenizer):


@pytest.mark.parametrize("threads", [1, 15])
def test_index_partial_with_30(test_db, threads, test_tokenizer):
@pytest.mark.asyncio
async def test_index_partial_with_30(test_db, threads, test_tokenizer):
for rank in range(31):
test_db.add_place(rank_address=rank, rank_search=rank)
test_db.add_osmline()
@ -220,7 +223,7 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(28, 30)
await idx.index_by_rank(28, 30)

assert test_db.placex_unindexed() == 27
assert test_db.osmline_unindexed() == 0
@ -230,7 +233,8 @@ def test_index_partial_with_30(test_db, threads, test_tokenizer):
WHERE indexed_status = 0 AND rank_address between 1 and 27""") == 0

@pytest.mark.parametrize("threads", [1, 15])
def test_index_boundaries(test_db, threads, test_tokenizer):
@pytest.mark.asyncio
async def test_index_boundaries(test_db, threads, test_tokenizer):
for rank in range(4, 10):
test_db.add_admin(rank_address=rank, rank_search=rank)
for rank in range(31):
@ -241,7 +245,7 @@ def test_index_boundaries(test_db, threads, test_tokenizer):
assert test_db.osmline_unindexed() == 1

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_boundaries(0, 30)
await idx.index_boundaries(0, 30)

assert test_db.placex_unindexed() == 31
assert test_db.osmline_unindexed() == 1
@ -252,21 +256,23 @@ def test_index_boundaries(test_db, threads, test_tokenizer):


@pytest.mark.parametrize("threads", [1, 15])
def test_index_postcodes(test_db, threads, test_tokenizer):
@pytest.mark.asyncio
async def test_index_postcodes(test_db, threads, test_tokenizer):
for postcode in range(1000):
test_db.add_postcode('de', postcode)
for postcode in range(32000, 33000):
test_db.add_postcode('us', postcode)

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_postcodes()
await idx.index_postcodes()

assert test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""") == 0


@pytest.mark.parametrize("analyse", [True, False])
def test_index_full(test_db, analyse, test_tokenizer):
@pytest.mark.asyncio
async def test_index_full(test_db, analyse, test_tokenizer):
for rank in range(4, 10):
test_db.add_admin(rank_address=rank, rank_search=rank)
for rank in range(31):
@ -276,22 +282,9 @@ def test_index_full(test_db, analyse, test_tokenizer):
test_db.add_postcode('de', postcode)

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, 4)
idx.index_full(analyse=analyse)
await idx.index_full(analyse=analyse)

assert test_db.placex_unindexed() == 0
assert test_db.osmline_unindexed() == 0
assert test_db.scalar("""SELECT count(*) FROM location_postcode
WHERE indexed_status != 0""") == 0


@pytest.mark.parametrize("threads", [1, 15])
def test_index_reopen_connection(test_db, threads, monkeypatch, test_tokenizer):
monkeypatch.setattr(indexer.WorkerPool, "REOPEN_CONNECTIONS_AFTER", 15)

for _ in range(1000):
test_db.add_place(rank_address=30, rank_search=30)

idx = indexer.Indexer('dbname=test_nominatim_python_unittest', test_tokenizer, threads)
idx.index_by_rank(28, 30)

assert test_db.placex_unindexed() == 0
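The indexer tests switch to coroutine functions because Indexer.index_by_rank() and the related entry points are now awaitable. Outside pytest-asyncio a caller would drive them with asyncio.run(); a hedged sketch, with DSN and ranks chosen only for illustration:

```python
# Sketch only: driving the now-async Indexer API outside of pytest-asyncio.
import asyncio
from nominatim_db.indexer import indexer

async def reindex(tokenizer) -> None:
    idx = indexer.Indexer('dbname=test_nominatim_python_unittest', tokenizer, 4)
    await idx.index_by_rank(0, 30)   # awaitable after the psycopg3 port
    await idx.index_postcodes()

# asyncio.run(reindex(my_tokenizer))   # tokenizer construction omitted here
```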
@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefill and test contents
of the table.
"""
from nominatim_db.db.connection import execute_scalar

class MockIcuWordTable:
""" A word table for testing using legacy word table structure.
@ -35,9 +36,9 @@ class MockIcuWordTable:
with self.conn.cursor() as cur:
cur.execute("""INSERT INTO word (word_token, type, word, info)
VALUES (%s, 'S', %s,
json_build_object('class', %s,
'type', %s,
'op', %s))
json_build_object('class', %s::text,
'type', %s::text,
'op', %s::text))
""", (word_token, word, cls, typ, oper))
self.conn.commit()

@ -70,25 +71,22 @@ class MockIcuWordTable:
word = word_tokens[0]
for token in word_tokens:
cur.execute("""INSERT INTO word (word_id, word_token, type, word, info)
VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s))
VALUES (%s, %s, 'H', %s, jsonb_build_object('lookup', %s::text))
""", (word_id, token, word, word_tokens[0]))

self.conn.commit()


def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word")
return execute_scalar(self.conn, "SELECT count(*) FROM word")


def count_special(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE type = 'S'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'S'")


def count_housenumbers(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE type = 'H'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE type = 'H'")


def get_special(self):
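The added ::text casts are needed because psycopg3 binds parameters server-side, so json_build_object() can no longer infer the type of an otherwise untyped parameter the way the old client-side interpolation allowed. A hedged stand-alone version of the same fix:

```python
# Sketch only: explicit casts for untyped parameters inside json_build_object()
# under psycopg3's server-side parameter binding.
import psycopg

with psycopg.connect('dbname=test') as conn, conn.cursor() as cur:
    cur.execute("""SELECT json_build_object('class', %s::text,
                                            'type', %s::text,
                                            'op', %s::text)""",
                ('amenity', 'restaurant', 'in'))
    print(cur.fetchone()[0])
```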
@ -8,6 +8,7 @@
Legacy word table for testing with functions to prefill and test contents
of the table.
"""
from nominatim_db.db.connection import execute_scalar

class MockLegacyWordTable:
""" A word table for testing using legacy word table structure.
@ -58,18 +59,16 @@ class MockLegacyWordTable:

def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word")
return execute_scalar(self.conn, "SELECT count(*) FROM word")


def count_special(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM word WHERE class != 'place'")
return execute_scalar(self.conn, "SELECT count(*) FROM word WHERE class != 'place'")


def get_special(self):
with self.conn.cursor() as cur:
cur.execute("""SELECT word_token, word, class, type, operator
cur.execute("""SELECT word_token, word, class as cls, type, operator
FROM word WHERE class != 'place'""")
result = set((tuple(row) for row in cur))
assert len(result) == cur.rowcount, "Word table has duplicates."
@ -9,8 +9,6 @@ Custom mocks for testing.
"""
import itertools

import psycopg2.extras

from nominatim_db.db import properties

# This must always point to the mock word table for the default tokenizer.
@ -56,7 +54,6 @@ class MockPlacexTable:
admin_level=None, address=None, extratags=None, geom='POINT(10 4)',
country=None, housenumber=None, rank_search=30):
with self.conn.cursor() as cur:
psycopg2.extras.register_hstore(cur)
cur.execute("""INSERT INTO placex (place_id, osm_type, osm_id, class,
type, name, admin_level, address,
housenumber, rank_search,
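psycopg2.extras.register_hstore() disappears from the mocks. With psycopg3, hstore adaptation would instead go through the types registry; a hedged sketch of that route, which is not necessarily what the test suite itself does:

```python
# Sketch only: enabling hstore adaptation with psycopg3.
import psycopg
from psycopg.types import TypeInfo
from psycopg.types.hstore import register_hstore

with psycopg.connect('dbname=test') as conn:
    info = TypeInfo.fetch(conn, 'hstore')   # requires CREATE EXTENSION hstore
    if info is not None:
        register_hstore(info, conn)
    with conn.cursor() as cur:
        cur.execute("SELECT 'a=>1, b=>2'::hstore")
        print(cur.fetchone()[0])   # returned as a plain Python dict
```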
@ -199,16 +199,16 @@ def test_update_sql_functions(db_prop, temp_db_cursor,
assert test_content == set((('1133', ), ))


def test_finalize_import(tokenizer_factory, temp_db_conn,
temp_db_cursor, test_config, sql_preprocessor_cfg):
def test_finalize_import(tokenizer_factory, temp_db_cursor,
test_config, sql_preprocessor_cfg):
tok = tokenizer_factory()
tok.init_new_db(test_config)

assert not temp_db_conn.index_exists('idx_word_word_id')
assert not temp_db_cursor.index_exists('word', 'idx_word_word_id')

tok.finalize_import(test_config)

assert temp_db_conn.index_exists('idx_word_word_id')
assert temp_db_cursor.index_exists('word', 'idx_word_word_id')


def test_check_database(test_config, tokenizer_factory,
@ -8,10 +8,11 @@
Tests for functions to import a new database.
"""
from pathlib import Path
from contextlib import closing

import pytest
import psycopg2
import pytest_asyncio
import psycopg
from psycopg import sql as pysql

from nominatim_db.tools import database_import
from nominatim_db.errors import UsageError
@ -21,10 +22,7 @@ class TestDatabaseSetup:

@pytest.fixture(autouse=True)
def setup_nonexistant_db(self):
conn = psycopg2.connect(database='postgres')

try:
conn.set_isolation_level(0)
with psycopg.connect(dbname='postgres', autocommit=True) as conn:
with conn.cursor() as cur:
cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')

@ -32,22 +30,17 @@ class TestDatabaseSetup:

with conn.cursor() as cur:
cur.execute(f'DROP DATABASE IF EXISTS {self.DBNAME}')
finally:
conn.close()


@pytest.fixture
def cursor(self):
conn = psycopg2.connect(database=self.DBNAME)

try:
with psycopg.connect(dbname=self.DBNAME) as conn:
with conn.cursor() as cur:
yield cur
finally:
conn.close()


def conn(self):
return closing(psycopg2.connect(database=self.DBNAME))
return psycopg.connect(dbname=self.DBNAME)


def test_setup_skeleton(self):
@ -132,7 +125,7 @@ def test_import_osm_data_simple_ignore_no_data(table_factory, osm2pgsql_options)
ignore_errors=True)


def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_options):
def test_import_osm_data_drop(table_factory, temp_db_cursor, tmp_path, osm2pgsql_options):
table_factory('place', content=((1, ), ))
table_factory('planet_osm_nodes')

@ -144,7 +137,7 @@ def test_import_osm_data_drop(table_factory, temp_db_conn, tmp_path, osm2pgsql_o
database_import.import_osm_data(Path('file.pbf'), osm2pgsql_options, drop=True)

assert not flatfile.exists()
assert not temp_db_conn.table_exists('planet_osm_nodes')
assert not temp_db_cursor.table_exists('planet_osm_nodes')


def test_import_osm_data_default_cache(table_factory, osm2pgsql_options, capfd):
@ -178,18 +171,19 @@ def test_truncate_database_tables(temp_db_conn, temp_db_cursor, table_factory, w


@pytest.mark.parametrize("threads", (1, 5))
def test_load_data(dsn, place_row, placex_table, osmline_table,
@pytest.mark.asyncio
async def test_load_data(dsn, place_row, placex_table, osmline_table,
temp_db_cursor, threads):
for func in ('precompute_words', 'getorcreate_housenumber_id', 'make_standard_name'):
temp_db_cursor.execute(f"""CREATE FUNCTION {func} (src TEXT)
RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL
""")
temp_db_cursor.execute(pysql.SQL("""CREATE FUNCTION {} (src TEXT)
RETURNS TEXT AS $$ SELECT 'a'::TEXT $$ LANGUAGE SQL
""").format(pysql.Identifier(func)))
for oid in range(100, 130):
place_row(osm_id=oid)
place_row(osm_type='W', osm_id=342, cls='place', typ='houses',
geom='SRID=4326;LINESTRING(0 0, 10 10)')

database_import.load_data(dsn, threads)
await database_import.load_data(dsn, threads)

assert temp_db_cursor.table_rows('placex') == 30
assert temp_db_cursor.table_rows('location_property_osmline') == 1
@ -241,11 +235,12 @@ class TestSetupSQL:


@pytest.mark.parametrize("drop", [True, False])
def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
@pytest.mark.asyncio
async def test_create_search_indices(self, temp_db_conn, temp_db_cursor, drop):
self.write_sql('indices.sql',
"""CREATE FUNCTION test() RETURNS bool
AS $$ SELECT {{drop}} $$ LANGUAGE SQL""")

database_import.create_search_indices(temp_db_conn, self.config, drop)
await database_import.create_search_indices(temp_db_conn, self.config, drop)

temp_db_cursor.scalar('SELECT test()') == drop
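The database setup fixture now opens its maintenance connection with autocommit=True instead of calling set_isolation_level(0): DROP DATABASE and CREATE DATABASE cannot run inside a transaction block. A small sketch of the pattern, with made-up database names:

```python
# Sketch only: dropping and recreating a test database with psycopg3.
# DROP/CREATE DATABASE must run outside a transaction, hence autocommit=True.
import psycopg

with psycopg.connect(dbname='postgres', autocommit=True) as conn:
    with conn.cursor() as cur:
        cur.execute('DROP DATABASE IF EXISTS test_nominatim')
        cur.execute('CREATE DATABASE test_nominatim')
```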
@ -75,7 +75,8 @@ def test_load_white_and_black_lists(sp_importer):
assert isinstance(black_list, dict) and isinstance(white_list, dict)


def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
def test_create_place_classtype_indexes(temp_db_with_extensions,
temp_db_conn, temp_db_cursor,
table_factory, sp_importer):
"""
Test that _create_place_classtype_indexes() create the
@ -88,10 +89,11 @@ def test_create_place_classtype_indexes(temp_db_with_extensions, temp_db_conn,
table_factory(table_name, 'place_id BIGINT, centroid GEOMETRY')

sp_importer._create_place_classtype_indexes('', phrase_class, phrase_type)
temp_db_conn.commit()

assert check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type)
assert check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type)

def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
def test_create_place_classtype_table(temp_db_conn, temp_db_cursor, placex_table, sp_importer):
"""
Test that _create_place_classtype_table() create
the right place_classtype table.
@ -99,10 +101,12 @@ def test_create_place_classtype_table(temp_db_conn, placex_table, sp_importer):
phrase_class = 'class'
phrase_type = 'type'
sp_importer._create_place_classtype_table('', phrase_class, phrase_type)
temp_db_conn.commit()

assert check_table_exist(temp_db_conn, phrase_class, phrase_type)
assert check_table_exist(temp_db_cursor, phrase_class, phrase_type)

def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_importer):
def test_grant_access_to_web_user(temp_db_conn, temp_db_cursor, table_factory,
def_config, sp_importer):
"""
Test that _grant_access_to_webuser() give
right access to the web user.
@ -114,12 +118,13 @@ def test_grant_access_to_web_user(temp_db_conn, table_factory, def_config, sp_im
table_factory(table_name)

sp_importer._grant_access_to_webuser(phrase_class, phrase_type)
temp_db_conn.commit()

assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, phrase_class, phrase_type)

def test_create_place_classtype_table_and_indexes(
temp_db_conn, def_config, placex_table,
sp_importer):
temp_db_cursor, def_config, placex_table,
sp_importer, temp_db_conn):
"""
Test that _create_place_classtype_table_and_indexes()
create the right place_classtype tables and place_id indexes
@ -129,14 +134,15 @@ def test_create_place_classtype_table_and_indexes(
pairs = set([('class1', 'type1'), ('class2', 'type2')])

sp_importer._create_classtype_table_and_indexes(pairs)
temp_db_conn.commit()

for pair in pairs:
assert check_table_exist(temp_db_conn, pair[0], pair[1])
assert check_placeid_and_centroid_indexes(temp_db_conn, pair[0], pair[1])
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, pair[0], pair[1])
assert check_table_exist(temp_db_cursor, pair[0], pair[1])
assert check_placeid_and_centroid_indexes(temp_db_cursor, pair[0], pair[1])
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, pair[0], pair[1])

def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
temp_db_conn):
temp_db_conn, temp_db_cursor):
"""
Check for the remove_non_existent_phrases_from_db() method.

@ -159,15 +165,14 @@ def test_remove_non_existent_tables_from_db(sp_importer, default_phrases,
"""

sp_importer._remove_non_existent_tables_from_db()
temp_db_conn.commit()

# Changes are not committed yet. Use temp_db_conn for checking results.
with temp_db_conn.cursor(cursor_factory=CursorForTesting) as cur:
assert cur.row_set(query_tables) \
assert temp_db_cursor.row_set(query_tables) \
== {('place_classtype_testclasstypetable_to_keep', )}


@pytest.mark.parametrize("should_replace", [(True), (False)])
def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
def test_import_phrases(monkeypatch, temp_db_cursor, def_config, sp_importer,
placex_table, table_factory, tokenizer_mock,
xml_wiki_content, should_replace):
"""
@ -193,49 +198,49 @@ def test_import_phrases(monkeypatch, temp_db_conn, def_config, sp_importer,
class_test = 'aerialway'
type_test = 'zip_line'

assert check_table_exist(temp_db_conn, class_test, type_test)
assert check_placeid_and_centroid_indexes(temp_db_conn, class_test, type_test)
assert check_grant_access(temp_db_conn, def_config.DATABASE_WEBUSER, class_test, type_test)
assert check_table_exist(temp_db_conn, 'amenity', 'animal_shelter')
assert check_table_exist(temp_db_cursor, class_test, type_test)
assert check_placeid_and_centroid_indexes(temp_db_cursor, class_test, type_test)
assert check_grant_access(temp_db_cursor, def_config.DATABASE_WEBUSER, class_test, type_test)
assert check_table_exist(temp_db_cursor, 'amenity', 'animal_shelter')
if should_replace:
assert not check_table_exist(temp_db_conn, 'wrong_class', 'wrong_type')
assert not check_table_exist(temp_db_cursor, 'wrong_class', 'wrong_type')

assert temp_db_conn.table_exists('place_classtype_amenity_animal_shelter')
assert temp_db_cursor.table_exists('place_classtype_amenity_animal_shelter')
if should_replace:
assert not temp_db_conn.table_exists('place_classtype_wrongclass_wrongtype')
assert not temp_db_cursor.table_exists('place_classtype_wrongclass_wrongtype')

def check_table_exist(temp_db_conn, phrase_class, phrase_type):
def check_table_exist(temp_db_cursor, phrase_class, phrase_type):
"""
Verify that the place_classtype table exists for the given
phrase_class and phrase_type.
"""
return temp_db_conn.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))
return temp_db_cursor.table_exists('place_classtype_{}_{}'.format(phrase_class, phrase_type))


def check_grant_access(temp_db_conn, user, phrase_class, phrase_type):
def check_grant_access(temp_db_cursor, user, phrase_class, phrase_type):
"""
Check that the web user has been granted right access to the
place_classtype table of the given phrase_class and phrase_type.
"""
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)

with temp_db_conn.cursor() as temp_db_cursor:
temp_db_cursor.execute("""
SELECT * FROM information_schema.role_table_grants
WHERE table_name='{}'
AND grantee='{}'
AND privilege_type='SELECT'""".format(table_name, user))
return temp_db_cursor.fetchone()
temp_db_cursor.execute("""
SELECT * FROM information_schema.role_table_grants
WHERE table_name='{}'
AND grantee='{}'
AND privilege_type='SELECT'""".format(table_name, user))
return temp_db_cursor.fetchone()

def check_placeid_and_centroid_indexes(temp_db_conn, phrase_class, phrase_type):
def check_placeid_and_centroid_indexes(temp_db_cursor, phrase_class, phrase_type):
"""
Check that the place_id index and centroid index exist for the
place_classtype table of the given phrase_class and phrase_type.
"""
table_name = 'place_classtype_{}_{}'.format(phrase_class, phrase_type)
index_prefix = 'idx_place_classtype_{}_{}_'.format(phrase_class, phrase_type)

return (
temp_db_conn.index_exists(index_prefix + 'centroid')
temp_db_cursor.index_exists(table_name, index_prefix + 'centroid')
and
temp_db_conn.index_exists(index_prefix + 'place_id')
temp_db_cursor.index_exists(table_name, index_prefix + 'place_id')
)
@ -8,10 +8,10 @@
Tests for migration functions
"""
import pytest
import psycopg2.extras

from nominatim_db.tools import migration
from nominatim_db.errors import UsageError
from nominatim_db.db.connection import server_version_tuple
import nominatim_db.version

from mock_legacy_word_table import MockLegacyWordTable
@ -43,7 +43,6 @@ def test_no_migration_old_versions(temp_db_with_extensions, table_factory, def_c
def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
table_factory, def_config, monkeypatch,
postprocess_mock):
psycopg2.extras.register_hstore(temp_db_cursor)
# don't actually run any migration, except the property table creation
monkeypatch.setattr(migration, '_MIGRATION_FUNCTIONS',
[((3, 5, 0, 99), migration.add_nominatim_property_table)])
@ -63,7 +62,7 @@ def test_set_up_migration_for_36(temp_db_with_extensions, temp_db_cursor,
WHERE property = 'database_version'""")


def test_already_at_version(def_config, property_table):
def test_already_at_version(temp_db_with_extensions, def_config, property_table):

property_table.set('database_version',
str(nominatim_db.version.NOMINATIM_VERSION))
@ -71,8 +70,8 @@ def test_already_at_version(def_config, property_table):
assert migration.migrate(def_config, {}) == 0


def test_run_single_migration(def_config, temp_db_cursor, property_table,
monkeypatch, postprocess_mock):
def test_run_single_migration(temp_db_with_extensions, def_config, temp_db_cursor,
property_table, monkeypatch, postprocess_mock):
oldversion = [x for x in nominatim_db.version.NOMINATIM_VERSION]
oldversion[0] -= 1
property_table.set('database_version',
@ -226,7 +225,7 @@ def test_create_tiger_housenumber_index(temp_db_conn, temp_db_cursor, table_fact
migration.create_tiger_housenumber_index(temp_db_conn)
temp_db_conn.commit()

if temp_db_conn.server_version_tuple() >= (11, 0, 0):
if server_version_tuple(temp_db_conn) >= (11, 0, 0):
assert temp_db_cursor.index_exists('location_property_tiger',
'idx_location_property_tiger_housenumber_migrated')
@ -47,7 +47,7 @@ class MockPostcodeTable:
country_code, postcode,
geometry)
VALUES (nextval('seq_place'), 1, %s, %s,
'SRID=4326;POINT(%s %s)')""",
ST_SetSRID(ST_MakePoint(%s, %s), 4326))""",
(country, postcode, x, y))
self.conn.commit()
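The mock postcode table stops splicing %s placeholders into an EWKT string literal ('SRID=4326;POINT(%s %s)') and builds the geometry with ST_SetSRID(ST_MakePoint(%s, %s), 4326) instead: with server-side binding, parameters cannot appear inside a quoted literal. A hedged stand-alone version (requires PostGIS):

```python
# Sketch only: passing coordinates as real parameters instead of interpolating
# them into an EWKT string, which does not work with psycopg3 binding.
import psycopg

with psycopg.connect('dbname=test') as conn, conn.cursor() as cur:
    cur.execute("SELECT ST_AsText(ST_SetSRID(ST_MakePoint(%s, %s), 4326))",
                (13.4, 52.5))
    print(cur.fetchone()[0])   # POINT(13.4 52.5)
```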
@ -12,6 +12,7 @@ from pathlib import Path
import pytest

from nominatim_db.tools import refresh
from nominatim_db.db.connection import postgis_version_tuple

def test_refresh_import_wikipedia_not_existing(dsn):
assert refresh.import_wikipedia_articles(dsn, Path('.')) == 1
@ -23,13 +24,13 @@ def test_refresh_import_secondary_importance_non_existing(dsn):
def test_refresh_import_secondary_importance_testdb(dsn, src_dir, temp_db_conn, temp_db_cursor):
temp_db_cursor.execute('CREATE EXTENSION postgis')

if temp_db_conn.postgis_version_tuple()[0] < 3:
if postgis_version_tuple(temp_db_conn)[0] < 3:
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') > 0
else:
temp_db_cursor.execute('CREATE EXTENSION postgis_raster')
assert refresh.import_secondary_importance(dsn, src_dir / 'test' / 'testdb') == 0

assert temp_db_conn.table_exists('secondary_importance')
assert temp_db_cursor.table_exists('secondary_importance')


@pytest.mark.parametrize("replace", (True, False))
@ -11,7 +11,9 @@ import tarfile
from textwrap import dedent

import pytest
import pytest_asyncio

from nominatim_db.db.connection import execute_scalar
from nominatim_db.tools import tiger_data, freeze
from nominatim_db.errors import UsageError

@ -31,8 +33,7 @@ class MockTigerTable:
cur.execute("CREATE TABLE place (number INTEGER)")

def count(self):
with self.conn.cursor() as cur:
return cur.scalar("SELECT count(*) FROM tiger")
return execute_scalar(self.conn, "SELECT count(*) FROM tiger")

def row(self):
with self.conn.cursor() as cur:
@ -76,82 +77,91 @@ def csv_factory(tmp_path):


@pytest.mark.parametrize("threads", (1, 5))
def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads):
tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'),
def_config, threads, tokenizer_mock())
@pytest.mark.asyncio
async def test_add_tiger_data(def_config, src_dir, tiger_table, tokenizer_mock, threads):
await tiger_data.add_tiger_data(str(src_dir / 'test' / 'testdb' / 'tiger'),
def_config, threads, tokenizer_mock())

assert tiger_table.count() == 6213


def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_database_frozen(def_config, temp_db_conn, tiger_table, tokenizer_mock,
tmp_path):
freeze.drop_update_tables(temp_db_conn)

with pytest.raises(UsageError) as excinfo:
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())

assert "database frozen" in str(excinfo.value)

assert tiger_table.count() == 0

def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,

@pytest.mark.asyncio
async def test_add_tiger_data_no_files(def_config, tiger_table, tokenizer_mock,
tmp_path):
tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())

assert tiger_table.count() == 0


def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_bad_file(def_config, tiger_table, tokenizer_mock,
tmp_path):
sqlfile = tmp_path / '1010.csv'
sqlfile.write_text("""Random text""")

tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())

assert tiger_table.count() == 0


def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_hnr_nan(def_config, tiger_table, tokenizer_mock,
csv_factory, tmp_path):
csv_factory('file1', hnr_from=99)
csv_factory('file2', hnr_from='L12')
csv_factory('file3', hnr_to='12.4')

tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path), def_config, 1, tokenizer_mock())

assert tiger_table.count() == 1
assert tiger_table.row()['start'] == 99
assert tiger_table.row().start == 99


@pytest.mark.parametrize("threads", (1, 5))
def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path, src_dir, threads):
tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
tar.add(str(src_dir / 'test' / 'testdb' / 'tiger' / '01001.csv'))
tar.close()

tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads,
tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, threads,
tokenizer_mock())

assert tiger_table.count() == 6213


def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_bad_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path):
tarfile = tmp_path / 'sample.tar.gz'
tarfile.write_text("""Random text""")

with pytest.raises(UsageError):
tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock())
await tiger_data.add_tiger_data(str(tarfile), def_config, 1, tokenizer_mock())


def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock,
@pytest.mark.asyncio
async def test_add_tiger_data_empty_tarfile(def_config, tiger_table, tokenizer_mock,
tmp_path):
tar = tarfile.open(str(tmp_path / 'sample.tar.gz'), "w:gz")
tar.add(__file__)
tar.close()

tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1,
tokenizer_mock())
await tiger_data.add_tiger_data(str(tmp_path / 'sample.tar.gz'), def_config, 1,
tokenizer_mock())

assert tiger_table.count() == 0
@ -26,10 +26,14 @@ export DEBIAN_FRONTEND=noninteractive #DOCS:
nlohmann-json3-dev postgresql-14-postgis-3 \
postgresql-contrib-14 postgresql-14-postgis-3-scripts \
libicu-dev python3-dotenv \
python3-psycopg2 python3-psutil python3-jinja2 \
python3-pip python3-psutil python3-jinja2 \
python3-sqlalchemy python3-asyncpg \
python3-icu python3-datrie python3-yaml git

# Some of the Python packages that come with Ubuntu 22.04 are too old,
# so install the latest version from pip:

pip3 install --user psycopg[binary]

#
# System Configuration
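After installing psycopg from pip (the Ubuntu python3-psycopg2 package is no longer pulled in), a quick import check can confirm that the psycopg3 driver is available to the installation user. This is an optional sanity check, not part of the install script:

```python
# Sketch only: verify that psycopg3 is importable and report its versions.
import psycopg

print('psycopg version:', psycopg.__version__)
print('libpq version:', psycopg.pq.version())
```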