mirror of https://github.com/osm-search/Nominatim.git
add type annotations for indexer
parent 8adab2c6ca
commit 5617bffe2f
@@ -6,7 +6,7 @@
 # For a full list of authors see the git log.
 """ Non-blocking database connections.
 """
-from typing import Callable, Any, Optional, List, Iterator
+from typing import Callable, Any, Optional, Iterator, Sequence
 import logging
 import select
 import time
@@ -22,7 +22,7 @@ try:
 except ImportError:
     __has_psycopg2_errors__ = False

-from nominatim.typing import T_cursor
+from nominatim.typing import T_cursor, Query

 LOG = logging.getLogger()

@@ -65,8 +65,8 @@ class DBConnection:
                  ignore_sql_errors: bool = False) -> None:
         self.dsn = dsn

-        self.current_query: Optional[str] = None
-        self.current_params: Optional[List[Any]] = None
+        self.current_query: Optional[Query] = None
+        self.current_params: Optional[Sequence[Any]] = None
         self.ignore_sql_errors = ignore_sql_errors

         self.conn: Optional['psycopg2.connection'] = None
@@ -128,7 +128,7 @@ class DBConnection:
             self.current_query = None
             return

-    def perform(self, sql: str, args: Optional[List[Any]] = None) -> None:
+    def perform(self, sql: Query, args: Optional[Sequence[Any]] = None) -> None:
         """ Send SQL query to the server. Returns immediately without
             blocking.
         """
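Note on the `Query` alias used above (a sketch, not part of the diff): `DBConnection.perform()` in `nominatim.db.async_connection` now accepts anything matching `Query = Union[str, bytes, 'psycopg2.sql.Composable']` from `nominatim.typing`, so both plain SQL strings and composed `psycopg2.sql` objects type-check. The snippet below only illustrates the accepted shapes; it assumes psycopg2 is installed, needs no database connection, and `describe_query` is a made-up helper.

```python
from typing import Any, Optional, Sequence, Union

import psycopg2.sql

# Mirrors the Query alias this commit adds to nominatim.typing.
Query = Union[str, bytes, 'psycopg2.sql.Composable']

def describe_query(query: Query, args: Optional[Sequence[Any]] = None) -> str:
    # Same parameter shapes as the re-annotated DBConnection.perform().
    nargs = 0 if args is None else len(args)
    return f"{type(query).__name__} with {nargs} bound parameter(s)"

# A plain string query with a parameter tuple...
print(describe_query("SELECT place_id FROM placex WHERE place_id = %s", (42,)))
# ...and a composed query built with psycopg2.sql both satisfy the alias.
print(describe_query(psycopg2.sql.SQL("SELECT {} FROM placex")
                     .format(psycopg2.sql.Identifier("place_id"))))
```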
@@ -74,7 +74,7 @@ class Cursor(psycopg2.extras.DictCursor):
         if cascade:
             sql += ' CASCADE'

-        self.execute(pysql.SQL(sql).format(pysql.Identifier(name))) # type: ignore[no-untyped-call]
+        self.execute(pysql.SQL(sql).format(pysql.Identifier(name)))


 class Connection(psycopg2.extensions.connection):
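The dropped `# type: ignore[no-untyped-call]` above is the per-error-code ignore form. A minimal sketch of how that code arises (hypothetical helper, not Nominatim code): under `mypy --disallow-untyped-calls`, calling an unannotated function from typed code is reported as `no-untyped-call`; once the callee is typed (here, via psycopg2 type stubs), the targeted ignore is no longer needed.

```python
def untyped_helper(value):                     # deliberately unannotated
    return value

# With `mypy --disallow-untyped-calls` this call is flagged [no-untyped-call];
# the targeted ignore silences exactly that error and nothing else.
result = untyped_helper(1)  # type: ignore[no-untyped-call]
print(result)
```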
@@ -7,15 +7,18 @@
 """
 Main work horse for indexing (computing addresses) the database.
 """
+from typing import Optional, Any, cast
 import logging
 import time

 import psycopg2.extras

+from nominatim.tokenizer.base import AbstractTokenizer
 from nominatim.indexer.progress import ProgressLogger
 from nominatim.indexer import runners
 from nominatim.db.async_connection import DBConnection, WorkerPool
-from nominatim.db.connection import connect
+from nominatim.db.connection import connect, Connection, Cursor
+from nominatim.typing import DictCursorResults

 LOG = logging.getLogger()

@@ -23,10 +26,11 @@ LOG = logging.getLogger()
 class PlaceFetcher:
     """ Asynchronous connection that fetches place details for processing.
     """
-    def __init__(self, dsn, setup_conn):
-        self.wait_time = 0
-        self.current_ids = None
-        self.conn = DBConnection(dsn, cursor_factory=psycopg2.extras.DictCursor)
+    def __init__(self, dsn: str, setup_conn: Connection) -> None:
+        self.wait_time = 0.0
+        self.current_ids: Optional[DictCursorResults] = None
+        self.conn: Optional[DBConnection] = DBConnection(dsn,
+                                                         cursor_factory=psycopg2.extras.DictCursor)

         with setup_conn.cursor() as cur:
             # need to fetch those manually because register_hstore cannot
@@ -37,7 +41,7 @@ class PlaceFetcher:
             psycopg2.extras.register_hstore(self.conn.conn, oid=hstore_oid,
                                             array_oid=hstore_array_oid)

-    def close(self):
+    def close(self) -> None:
         """ Close the underlying asynchronous connection.
         """
         if self.conn:
@@ -45,44 +49,46 @@ class PlaceFetcher:
             self.conn = None


-    def fetch_next_batch(self, cur, runner):
+    def fetch_next_batch(self, cur: Cursor, runner: runners.Runner) -> bool:
         """ Send a request for the next batch of places.
             If details for the places are required, they will be fetched
             asynchronously.

            Returns true if there is still data available.
        """
-        ids = cur.fetchmany(100)
+        ids = cast(Optional[DictCursorResults], cur.fetchmany(100))

         if not ids:
             self.current_ids = None
             return False

-        if hasattr(runner, 'get_place_details'):
-            runner.get_place_details(self.conn, ids)
-            self.current_ids = []
-        else:
-            self.current_ids = ids
+        assert self.conn is not None
+        self.current_ids = runner.get_place_details(self.conn, ids)

         return True

-    def get_batch(self):
+    def get_batch(self) -> DictCursorResults:
         """ Get the next batch of data, previously requested with
             `fetch_next_batch`.
         """
+        assert self.conn is not None
+        assert self.conn.cursor is not None

         if self.current_ids is not None and not self.current_ids:
             tstart = time.time()
             self.conn.wait()
             self.wait_time += time.time() - tstart
-            self.current_ids = self.conn.cursor.fetchall()
+            self.current_ids = cast(Optional[DictCursorResults],
+                                    self.conn.cursor.fetchall())

-        return self.current_ids
+        return self.current_ids if self.current_ids is not None else []

-    def __enter__(self):
+    def __enter__(self) -> 'PlaceFetcher':
         return self


-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
+        assert self.conn is not None
         self.conn.wait()
         self.close()

@@ -91,13 +97,13 @@ class Indexer:
     """ Main indexing routine.
     """

-    def __init__(self, dsn, tokenizer, num_threads):
+    def __init__(self, dsn: str, tokenizer: AbstractTokenizer, num_threads: int):
         self.dsn = dsn
         self.tokenizer = tokenizer
         self.num_threads = num_threads


-    def has_pending(self):
+    def has_pending(self) -> bool:
         """ Check if any data still needs indexing.
             This function must only be used after the import has finished.
             Otherwise it will be very expensive.
@@ -108,7 +114,7 @@ class Indexer:
                 return cur.rowcount > 0


-    def index_full(self, analyse=True):
+    def index_full(self, analyse: bool = True) -> None:
         """ Index the complete database. This will first index boundaries
             followed by all other objects. When `analyse` is True, then the
             database will be analysed at the appropriate places to
@@ -117,7 +123,7 @@ class Indexer:
         with connect(self.dsn) as conn:
             conn.autocommit = True

-            def _analyze():
+            def _analyze() -> None:
                 if analyse:
                     with conn.cursor() as cur:
                         cur.execute('ANALYZE')
@@ -138,7 +144,7 @@ class Indexer:
             _analyze()


-    def index_boundaries(self, minrank, maxrank):
+    def index_boundaries(self, minrank: int, maxrank: int) -> None:
         """ Index only administrative boundaries within the given rank range.
         """
         LOG.warning("Starting indexing boundaries using %s threads",
@@ -148,7 +154,7 @@ class Indexer:
             for rank in range(max(minrank, 4), min(maxrank, 26)):
                 self._index(runners.BoundaryRunner(rank, analyzer))

-    def index_by_rank(self, minrank, maxrank):
+    def index_by_rank(self, minrank: int, maxrank: int) -> None:
         """ Index all entries of placex in the given rank range (inclusive)
             in order of their address rank.

@@ -168,7 +174,7 @@ class Indexer:
                 self._index(runners.InterpolationRunner(analyzer), 20)


-    def index_postcodes(self):
+    def index_postcodes(self) -> None:
         """Index the entries ofthe location_postcode table.
         """
         LOG.warning("Starting indexing postcodes using %s threads", self.num_threads)
@@ -176,7 +182,7 @@ class Indexer:
         self._index(runners.PostcodeRunner(), 20)


-    def update_status_table(self):
+    def update_status_table(self) -> None:
         """ Update the status in the status table to 'indexed'.
         """
         with connect(self.dsn) as conn:
@@ -185,7 +191,7 @@ class Indexer:

             conn.commit()

-    def _index(self, runner, batch=1):
+    def _index(self, runner: runners.Runner, batch: int = 1) -> None:
         """ Index a single rank or table. `runner` describes the SQL to use
             for indexing. `batch` describes the number of objects that
             should be processed with a single SQL statement
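The `assert self.conn is not None` lines added to `PlaceFetcher` are the usual way to narrow an `Optional[...]` attribute before use. A standalone sketch of the pattern (toy class, not Nominatim code):

```python
from typing import Optional

class Holder:
    """Toy example of the Optional-narrowing pattern used in PlaceFetcher."""

    def __init__(self) -> None:
        self.conn: Optional[str] = "open"   # stands in for the DBConnection attribute

    def close(self) -> None:
        self.conn = None                    # closing sets the attribute to None

    def use(self) -> str:
        # Without the assert, mypy reports that None has no attribute 'upper'.
        assert self.conn is not None
        return self.conn.upper()

holder = Holder()
print(holder.use())
```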
@@ -22,7 +22,7 @@ class ProgressLogger:
         should be reported.
     """

-    def __init__(self, name, total, log_interval=1):
+    def __init__(self, name: str, total: int, log_interval: int = 1) -> None:
         self.name = name
         self.total_places = total
         self.done_places = 0
@@ -30,7 +30,7 @@ class ProgressLogger:
         self.log_interval = log_interval
         self.next_info = INITIAL_PROGRESS if LOG.isEnabledFor(logging.WARNING) else total + 1

-    def add(self, num=1):
+    def add(self, num: int = 1) -> None:
         """ Mark `num` places as processed. Print a log message if the
             logging is at least info and the log interval has passed.
         """
@@ -55,14 +55,14 @@ class ProgressLogger:

             self.next_info += int(places_per_sec) * self.log_interval

-    def done(self):
+    def done(self) -> None:
         """ Print final statistics about the progress.
         """
         rank_end_time = datetime.now()

         if rank_end_time == self.rank_start_time:
-            diff_seconds = 0
-            places_per_sec = self.done_places
+            diff_seconds = 0.0
+            places_per_sec = float(self.done_places)
         else:
             diff_seconds = (rank_end_time - self.rank_start_time).total_seconds()
             places_per_sec = self.done_places / diff_seconds
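The `0.0`/`float(...)` change in `done()` keeps both branches of the conditional assigning floats. A sketch of why (assuming no explicit annotation on the variables): mypy usually infers a variable's type from its first assignment, so an `int` in one branch and a `float` in the other would be flagged as an incompatible assignment.

```python
from datetime import datetime

start = datetime.now()
end = datetime.now()
done_places = 100

if end == start:
    diff_seconds = 0.0                     # float in both branches keeps the inferred type consistent
    places_per_sec = float(done_places)
else:
    diff_seconds = (end - start).total_seconds()
    places_per_sec = done_places / diff_seconds

print(f"{places_per_sec:.1f} places/s over {diff_seconds:.6f}s")
```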
@@ -8,35 +8,49 @@
 Mix-ins that provide the actual commands for the indexer for various indexing
 tasks.
 """
+from typing import Any, List
 import functools

+from typing_extensions import Protocol
 from psycopg2 import sql as pysql
 import psycopg2.extras

 from nominatim.data.place_info import PlaceInfo
+from nominatim.tokenizer.base import AbstractAnalyzer
+from nominatim.db.async_connection import DBConnection
+from nominatim.typing import Query, DictCursorResult, DictCursorResults

 # pylint: disable=C0111

-def _mk_valuelist(template, num):
+def _mk_valuelist(template: str, num: int) -> pysql.Composed:
     return pysql.SQL(',').join([pysql.SQL(template)] * num)

-def _analyze_place(place, analyzer):
+def _analyze_place(place: DictCursorResult, analyzer: AbstractAnalyzer) -> psycopg2.extras.Json:
     return psycopg2.extras.Json(analyzer.process_place(PlaceInfo(place)))


+class Runner(Protocol):
+    def name(self) -> str: ...
+    def sql_count_objects(self) -> Query: ...
+    def sql_get_objects(self) -> Query: ...
+    def get_place_details(self, worker: DBConnection,
+                          ids: DictCursorResults) -> DictCursorResults: ...
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None: ...
+
+
 class AbstractPlacexRunner:
     """ Returns SQL commands for indexing of the placex table.
     """
     SELECT_SQL = pysql.SQL('SELECT place_id FROM placex ')
     UPDATE_LINE = "(%s, %s::hstore, %s::hstore, %s::int, %s::jsonb)"

-    def __init__(self, rank, analyzer):
+    def __init__(self, rank: int, analyzer: AbstractAnalyzer) -> None:
         self.rank = rank
         self.analyzer = analyzer


-    @staticmethod
     @functools.lru_cache(maxsize=1)
-    def _index_sql(num_places):
+    def _index_sql(self, num_places: int) -> pysql.Composed:
         return pysql.SQL(
             """ UPDATE placex
                 SET indexed_status = 0, address = v.addr, token_info = v.ti,
@@ -46,16 +60,17 @@ class AbstractPlacexRunner:
             """).format(_mk_valuelist(AbstractPlacexRunner.UPDATE_LINE, num_places))


-    @staticmethod
-    def get_place_details(worker, ids):
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
         worker.perform("""SELECT place_id, extra.*
                           FROM placex, LATERAL placex_indexing_prepare(placex) as extra
                           WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
+        return []


-    def index_places(self, worker, places):
-        values = []
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
+        values: List[Any] = []
         for place in places:
             for field in ('place_id', 'name', 'address', 'linked_place_id'):
                 values.append(place[field])
@@ -68,15 +83,15 @@ class RankRunner(AbstractPlacexRunner):
     """ Returns SQL commands for indexing one rank within the placex table.
     """

-    def name(self):
+    def name(self) -> str:
         return f"rank {self.rank}"

-    def sql_count_objects(self):
+    def sql_count_objects(self) -> pysql.Composed:
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE rank_address = {} and indexed_status > 0
                          """).format(pysql.Literal(self.rank))

-    def sql_get_objects(self):
+    def sql_get_objects(self) -> pysql.Composed:
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_address = {}
                ORDER BY geometry_sector
@@ -88,17 +103,17 @@ class BoundaryRunner(AbstractPlacexRunner):
         of a certain rank.
     """

-    def name(self):
+    def name(self) -> str:
         return f"boundaries rank {self.rank}"

-    def sql_count_objects(self):
+    def sql_count_objects(self) -> pysql.Composed:
         return pysql.SQL("""SELECT count(*) FROM placex
                             WHERE indexed_status > 0
                               AND rank_search = {}
                               AND class = 'boundary' and type = 'administrative'
                          """).format(pysql.Literal(self.rank))

-    def sql_get_objects(self):
+    def sql_get_objects(self) -> pysql.Composed:
         return self.SELECT_SQL + pysql.SQL(
             """WHERE indexed_status > 0 and rank_search = {}
                      and class = 'boundary' and type = 'administrative'
@@ -111,37 +126,33 @@ class InterpolationRunner:
         location_property_osmline.
     """

-    def __init__(self, analyzer):
+    def __init__(self, analyzer: AbstractAnalyzer) -> None:
         self.analyzer = analyzer


-    @staticmethod
-    def name():
+    def name(self) -> str:
         return "interpolation lines (location_property_osmline)"

-    @staticmethod
-    def sql_count_objects():
+    def sql_count_objects(self) -> str:
         return """SELECT count(*) FROM location_property_osmline
                   WHERE indexed_status > 0"""

-    @staticmethod
-    def sql_get_objects():
+    def sql_get_objects(self) -> str:
         return """SELECT place_id
                   FROM location_property_osmline
                   WHERE indexed_status > 0
                   ORDER BY geometry_sector"""


-    @staticmethod
-    def get_place_details(worker, ids):
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
         worker.perform("""SELECT place_id, get_interpolation_address(address, osm_id) as address
                           FROM location_property_osmline WHERE place_id IN %s""",
                        (tuple((p[0] for p in ids)), ))
+        return []


-    @staticmethod
     @functools.lru_cache(maxsize=1)
-    def _index_sql(num_places):
+    def _index_sql(self, num_places: int) -> pysql.Composed:
         return pysql.SQL("""UPDATE location_property_osmline
                             SET indexed_status = 0, address = v.addr, token_info = v.ti
                             FROM (VALUES {}) as v(id, addr, ti)
@@ -149,8 +160,8 @@ class InterpolationRunner:
                          """).format(_mk_valuelist("(%s, %s::hstore, %s::jsonb)", num_places))


-    def index_places(self, worker, places):
-        values = []
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
+        values: List[Any] = []
         for place in places:
             values.extend((place[x] for x in ('place_id', 'address')))
             values.append(_analyze_place(place, self.analyzer))
@@ -159,26 +170,28 @@ class InterpolationRunner:


-class PostcodeRunner:
+class PostcodeRunner(Runner):
     """ Provides the SQL commands for indexing the location_postcode table.
     """

-    @staticmethod
-    def name():
+    def name(self) -> str:
         return "postcodes (location_postcode)"

-    @staticmethod
-    def sql_count_objects():
+    def sql_count_objects(self) -> str:
         return 'SELECT count(*) FROM location_postcode WHERE indexed_status > 0'

-    @staticmethod
-    def sql_get_objects():
+    def sql_get_objects(self) -> str:
         return """SELECT place_id FROM location_postcode
                   WHERE indexed_status > 0
                   ORDER BY country_code, postcode"""

-    @staticmethod
-    def index_places(worker, ids):
+    def get_place_details(self, worker: DBConnection, ids: DictCursorResults) -> DictCursorResults:
+        return ids
+
+    def index_places(self, worker: DBConnection, places: DictCursorResults) -> None:
         worker.perform(pysql.SQL("""UPDATE location_postcode SET indexed_status = 0
                                     WHERE place_id IN ({})""")
-                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in ids))))
+                       .format(pysql.SQL(',').join((pysql.Literal(i[0]) for i in places))))
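The new `Runner` protocol above gives the runner classes a common structural type: a class satisfies it simply by providing the listed methods. A minimal sketch of how protocol checking works (illustrative names; it uses `typing.Protocol`, which needs Python 3.8+, whereas the commit imports it from `typing_extensions`):

```python
from typing import Protocol

class NamedTask(Protocol):
    """Cut-down stand-in for the Runner protocol."""
    def name(self) -> str: ...
    def sql_count_objects(self) -> str: ...

class DemoRunner:
    # No explicit base class: matching the method signatures is enough.
    def name(self) -> str:
        return "demo"

    def sql_count_objects(self) -> str:
        return "SELECT count(*) FROM demo_table WHERE indexed_status > 0"

def describe(task: NamedTask) -> str:
    # Accepts any object with the right shape, checked statically by mypy.
    return f"{task.name()}: {task.sql_count_objects()}"

print(describe(DemoRunner()))
```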
@@ -9,14 +9,15 @@ Type definitions for typing annotations.

 Complex type definitions are moved here, to keep the source files readable.
 """
-from typing import Union, Mapping, TypeVar, TYPE_CHECKING
+from typing import Any, Union, Mapping, TypeVar, Sequence, TYPE_CHECKING

 # Generics varaible names do not confirm to naming styles, ignore globally here.
-# pylint: disable=invalid-name
+# pylint: disable=invalid-name,abstract-method,multiple-statements,missing-class-docstring

 if TYPE_CHECKING:
     import psycopg2.sql
     import psycopg2.extensions
+    import psycopg2.extras
     import os

 StrPath = Union[str, 'os.PathLike[str]']
@@ -26,4 +27,12 @@ SysEnv = Mapping[str, str]
 # psycopg2-related types

+Query = Union[str, bytes, 'psycopg2.sql.Composable']
+
 T_ResultKey = TypeVar('T_ResultKey', int, str)

+class DictCursorResult(Mapping[str, Any]):
+    def __getitem__(self, x: Union[int, str]) -> Any: ...
+
+DictCursorResults = Sequence[DictCursorResult]
+
 T_cursor = TypeVar('T_cursor', bound='psycopg2.extensions.cursor')
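`DictCursorResult` above is a small structural stub for the rows a psycopg2 `DictCursor` returns, which can be addressed both by column index and by column name. A toy illustration of the shape it models (the `FakeDictRow` class is invented for this example, not a psycopg2 type):

```python
from typing import Any, Iterator, List, Mapping, Union

class FakeDictRow(Mapping[str, Any]):
    """Invented stand-in for a DictCursor row, only to show the access pattern."""

    def __init__(self) -> None:
        self._names: List[str] = ['place_id', 'name']
        self._values: List[Any] = [42, 'Main Street']

    def __getitem__(self, key: Union[int, str]) -> Any:
        # DictCursor rows accept positional and named lookups alike.
        if isinstance(key, int):
            return self._values[key]
        return self._values[self._names.index(key)]

    def __iter__(self) -> Iterator[str]:
        return iter(self._names)

    def __len__(self) -> int:
        return len(self._names)

row = FakeDictRow()
print(row['place_id'], row[1])   # key access and index access both work
```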