class SearchConnection:
    """ Thin wrapper that bundles an asynchronous SQLAlchemy connection
        with the Nominatim table definitions.

        The wrapped connection is stored in the 'connection' attribute,
        the collection of tables in the 't' attribute.
    """

    def __init__(self, conn: AsyncConnection,
                 tables: SearchTables) -> None:
        self.t = tables # pylint: disable=invalid-name
        self.connection = conn


    async def scalar(self, sql: sa.sql.base.Executable,
                     params: Union[Mapping[str, Any], None] = None
                    ) -> Any:
        """ Forward 'sql' and 'params' to the 'scalar()' function of the
            wrapped connection and return whatever it yields.
        """
        value = await self.connection.scalar(sql, params)
        return value


    async def execute(self, sql: sa.sql.base.Executable,
                      params: Union[Mapping[str, Any], Sequence[Mapping[str, Any]], None] = None
                     ) -> sa.engine.Result:
        """ Forward 'sql' and 'params' to the 'execute()' function of the
            wrapped connection and return the result object it yields.
        """
        result = await self.connection.execute(sql, params)
        return result
@@ -97,9 +101,10 @@ class NominatimAPIAsync: await self.setup_database() assert self._engine is not None + assert self._tables is not None async with self._engine.begin() as conn: - yield conn + yield SearchConnection(conn, self._tables) async def status(self) -> StatusResult: diff --git a/nominatim/api/status.py b/nominatim/api/status.py index 560953d3..a992460c 100644 --- a/nominatim/api/status.py +++ b/nominatim/api/status.py @@ -11,9 +11,9 @@ from typing import Optional, cast import datetime as dt import sqlalchemy as sa -from sqlalchemy.ext.asyncio.engine import AsyncConnection import asyncpg +from nominatim.api.connection import SearchConnection from nominatim import version class StatusResult: @@ -28,7 +28,7 @@ class StatusResult: self.database_version: Optional[version.NominatimVersion] = None -async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]: +async def _get_database_date(conn: SearchConnection) -> Optional[dt.datetime]: """ Query the database date. """ sql = sa.text('SELECT lastimportdate FROM import_status LIMIT 1') @@ -40,7 +40,7 @@ async def _get_database_date(conn: AsyncConnection) -> Optional[dt.datetime]: return None -async def _get_database_version(conn: AsyncConnection) -> Optional[version.NominatimVersion]: +async def _get_database_version(conn: SearchConnection) -> Optional[version.NominatimVersion]: sql = sa.text("""SELECT value FROM nominatim_properties WHERE property = 'database_version'""") result = await conn.execute(sql) @@ -51,7 +51,7 @@ async def _get_database_version(conn: AsyncConnection) -> Optional[version.Nomin return None -async def get_status(conn: AsyncConnection) -> StatusResult: +async def get_status(conn: SearchConnection) -> StatusResult: """ Execute a status API call. 
""" status = StatusResult(0, 'OK') diff --git a/nominatim/db/sqlalchemy_schema.py b/nominatim/db/sqlalchemy_schema.py new file mode 100644 index 00000000..17839168 --- /dev/null +++ b/nominatim/db/sqlalchemy_schema.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +SQLAlchemy definitions for all tables used by the frontend. +""" +from typing import Any + +import sqlalchemy as sa +from geoalchemy2 import Geometry +from sqlalchemy.dialects.postgresql import HSTORE, ARRAY, JSONB +from sqlalchemy.dialects.sqlite import JSON as sqlite_json + +#pylint: disable=too-many-instance-attributes +class SearchTables: + """ Data class that holds the tables of the Nominatim database. + """ + + def __init__(self, meta: sa.MetaData, engine_name: str) -> None: + if engine_name == 'postgresql': + Composite: Any = HSTORE + Json: Any = JSONB + IntArray: Any = ARRAY(sa.Integer()) #pylint: disable=invalid-name + elif engine_name == 'sqlite': + Composite = sqlite_json + Json = sqlite_json + IntArray = sqlite_json + else: + raise ValueError("Only 'postgresql' and 'sqlite' engines are supported.") + + self.meta = meta + + self.import_status = sa.Table('import_status', meta, + sa.Column('lastimportdate', sa.DateTime(True), nullable=False), + sa.Column('sequence_id', sa.Integer), + sa.Column('indexed', sa.Boolean)) + + self.properties = sa.Table('nominatim_properties', meta, + sa.Column('property', sa.Text, nullable=False), + sa.Column('value', sa.Text)) + + self.placex = sa.Table('placex', meta, + sa.Column('place_id', sa.BigInteger, nullable=False, unique=True), + sa.Column('parent_place_id', sa.BigInteger), + sa.Column('linked_place_id', sa.BigInteger), + sa.Column('importance', sa.Float), + sa.Column('indexed_date', sa.DateTime), + sa.Column('rank_address', sa.SmallInteger), + 
sa.Column('rank_search', sa.SmallInteger), + sa.Column('partition', sa.SmallInteger), + sa.Column('indexed_status', sa.SmallInteger), + sa.Column('osm_type', sa.String(1), nullable=False), + sa.Column('osm_id', sa.BigInteger, nullable=False), + sa.Column('class', sa.Text, nullable=False, key='class_'), + sa.Column('type', sa.Text, nullable=False), + sa.Column('admin_level', sa.SmallInteger), + sa.Column('name', Composite), + sa.Column('address', Composite), + sa.Column('extratags', Composite), + sa.Column('geometry', Geometry(srid=4326), nullable=False), + sa.Column('wikipedia', sa.Text), + sa.Column('country_code', sa.String(2)), + sa.Column('housenumber', sa.Text), + sa.Column('postcode', sa.Text), + sa.Column('centroid', Geometry(srid=4326, spatial_index=False))) + + self.addressline = sa.Table('place_addressline', meta, + sa.Column('place_id', sa.BigInteger, index=True), + sa.Column('address_place_id', sa.BigInteger, index=True), + sa.Column('distance', sa.Float), + sa.Column('cached_rank_address', sa.SmallInteger), + sa.Column('fromarea', sa.Boolean), + sa.Column('isaddress', sa.Boolean)) + + self.postcode = sa.Table('location_postcode', meta, + sa.Column('place_id', sa.BigInteger, unique=True), + sa.Column('parent_place_id', sa.BigInteger), + sa.Column('rank_search', sa.SmallInteger), + sa.Column('rank_address', sa.SmallInteger), + sa.Column('indexed_status', sa.SmallInteger), + sa.Column('indexed_date', sa.DateTime), + sa.Column('country_code', sa.String(2)), + sa.Column('postcode', sa.Text, index=True), + sa.Column('geometry', Geometry(srid=4326))) + + self.osmline = sa.Table('location_property_osmline', meta, + sa.Column('place_id', sa.BigInteger, nullable=False, unique=True), + sa.Column('osm_id', sa.BigInteger), + sa.Column('parent_place_id', sa.BigInteger), + sa.Column('indexed_date', sa.DateTime), + sa.Column('startnumber', sa.Integer), + sa.Column('endnumber', sa.Integer), + sa.Column('step', sa.SmallInteger), + sa.Column('partition', 
sa.SmallInteger), + sa.Column('indexed_status', sa.SmallInteger), + sa.Column('linegeo', Geometry(srid=4326)), + sa.Column('address', Composite), + sa.Column('postcode', sa.Text), + sa.Column('country_code', sa.String(2))) + + self.word = sa.Table('word', meta, + sa.Column('word_id', sa.Integer), + sa.Column('word_token', sa.Text, nullable=False), + sa.Column('type', sa.Text, nullable=False), + sa.Column('word', sa.Text), + sa.Column('info', Json)) + + self.country_name = sa.Table('country_name', meta, + sa.Column('country_code', sa.String(2)), + sa.Column('name', Composite), + sa.Column('derived_name', Composite), + sa.Column('country_default_language_code', sa.Text), + sa.Column('partition', sa.Integer)) + + self.country_grid = sa.Table('country_osm_grid', meta, + sa.Column('country_code', sa.String(2)), + sa.Column('area', sa.Float), + sa.Column('geometry', Geometry(srid=4326))) + + # The following tables are not necessarily present. + self.search_name = sa.Table('search_name', meta, + sa.Column('place_id', sa.BigInteger, index=True), + sa.Column('importance', sa.Float), + sa.Column('search_rank', sa.SmallInteger), + sa.Column('address_rank', sa.SmallInteger), + sa.Column('name_vector', IntArray, index=True), + sa.Column('nameaddress_vector', IntArray, index=True), + sa.Column('country_code', sa.String(2)), + sa.Column('centroid', Geometry(srid=4326))) + + self.tiger = sa.Table('location_property_tiger', meta, + sa.Column('place_id', sa.BigInteger), + sa.Column('parent_place_id', sa.BigInteger), + sa.Column('startnumber', sa.Integer), + sa.Column('endnumber', sa.Integer), + sa.Column('step', sa.SmallInteger), + sa.Column('partition', sa.SmallInteger), + sa.Column('linegeo', Geometry(srid=4326, spatial_index=False)), + sa.Column('postcode', sa.Text))