From dc99bbb0afb7632b5497d4f6f41acff5b2e635e2 Mon Sep 17 00:00:00 2001 From: Sarah Hoffmann Date: Wed, 24 May 2023 13:52:31 +0200 Subject: [PATCH] implement actual database searches --- .pylintrc | 2 +- nominatim/api/connection.py | 28 +- nominatim/api/results.py | 26 + nominatim/api/search/db_search_fields.py | 49 +- nominatim/api/search/db_searches.py | 598 +++++++++++++++++- nominatim/api/types.py | 25 +- nominatim/typing.py | 2 + test/python/api/conftest.py | 30 + test/python/api/search/test_search_country.py | 61 ++ test/python/api/search/test_search_near.py | 102 +++ test/python/api/search/test_search_places.py | 385 +++++++++++ test/python/api/search/test_search_poi.py | 108 ++++ .../python/api/search/test_search_postcode.py | 97 +++ 13 files changed, 1502 insertions(+), 11 deletions(-) create mode 100644 test/python/api/search/test_search_country.py create mode 100644 test/python/api/search/test_search_near.py create mode 100644 test/python/api/search/test_search_places.py create mode 100644 test/python/api/search/test_search_poi.py create mode 100644 test/python/api/search/test_search_postcode.py diff --git a/.pylintrc b/.pylintrc index f2d3491f..c1384c00 100644 --- a/.pylintrc +++ b/.pylintrc @@ -15,4 +15,4 @@ ignored-classes=NominatimArgs,closing # typed Python is enabled. See also https://github.com/PyCQA/pylint/issues/5273 disable=too-few-public-methods,duplicate-code,too-many-ancestors,bad-option-value,no-self-use,not-context-manager,use-dict-literal,chained-comparison,attribute-defined-outside-init -good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v +good-names=i,j,x,y,m,t,fd,db,cc,x1,x2,y1,y2,pt,k,v,nr diff --git a/nominatim/api/connection.py b/nominatim/api/connection.py index efa4490e..e157d062 100644 --- a/nominatim/api/connection.py +++ b/nominatim/api/connection.py @@ -7,11 +7,13 @@ """ Extended SQLAlchemy connection class that also includes access to the schema. """ -from typing import Any, Mapping, Sequence, Union, Dict, cast +from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set import sqlalchemy as sa +from geoalchemy2 import Geometry from sqlalchemy.ext.asyncio import AsyncConnection +from nominatim.typing import SaFromClause from nominatim.db.sqlalchemy_schema import SearchTables from nominatim.api.logging import log @@ -28,6 +30,7 @@ class SearchConnection: self.connection = conn self.t = tables # pylint: disable=invalid-name self._property_cache = properties + self._classtables: Optional[Set[str]] = None async def scalar(self, sql: sa.sql.base.Executable, @@ -87,3 +90,26 @@ class SearchConnection: raise ValueError(f"DB setting '{name}' not found in database.") return self._property_cache['DB:server_version'] + + + async def get_class_table(self, cls: str, typ: str) -> Optional[SaFromClause]: + """ Lookup up if there is a classtype table for the given category + and return a SQLAlchemy table for it, if it exists. 
+ """ + if self._classtables is None: + res = await self.execute(sa.text("""SELECT tablename FROM pg_tables + WHERE tablename LIKE 'place_classtype_%' + """)) + self._classtables = {r[0] for r in res} + + tablename = f"place_classtype_{cls}_{typ}" + + if tablename not in self._classtables: + return None + + if tablename in self.t.meta.tables: + return self.t.meta.tables[tablename] + + return sa.Table(tablename, self.t.meta, + sa.Column('place_id', sa.BigInteger), + sa.Column('centroid', Geometry(srid=4326, spatial_index=False))) diff --git a/nominatim/api/results.py b/nominatim/api/results.py index 56243e8d..1c313398 100644 --- a/nominatim/api/results.py +++ b/nominatim/api/results.py @@ -179,6 +179,15 @@ class SearchResult(BaseResult): """ A search result for forward geocoding. """ bbox: Optional[Bbox] = None + accuracy: float = 0.0 + + + @property + def ranking(self) -> float: + """ Return the ranking, a combined measure of accuracy and importance. + """ + return (self.accuracy if self.accuracy is not None else 1) \ + - self.calculated_importance() class SearchResults(List[SearchResult]): @@ -306,6 +315,23 @@ def create_from_postcode_row(row: Optional[SaRow], geometry=_filter_geometries(row)) +def create_from_country_row(row: Optional[SaRow], + class_type: Type[BaseResultT]) -> Optional[BaseResultT]: + """ Construct a new result and add the data from the result row + from the fallback country tables. 'class_type' defines + the type of result to return. Returns None if the row is None. + """ + if row is None: + return None + + return class_type(source_table=SourceTable.COUNTRY, + category=('place', 'country'), + centroid=Point.from_wkb(row.centroid.data), + names=row.name, + rank_address=4, rank_search=4, + country_code=row.country_code) + + async def add_result_details(conn: SearchConnection, result: BaseResult, details: LookupDetails) -> None: """ Retrieve more details from the database according to the diff --git a/nominatim/api/search/db_search_fields.py b/nominatim/api/search/db_search_fields.py index 9fcc2c4e..325e08df 100644 --- a/nominatim/api/search/db_search_fields.py +++ b/nominatim/api/search/db_search_fields.py @@ -7,13 +7,13 @@ """ Data structures for more complex fields in abstract search descriptions. """ -from typing import List, Tuple, cast +from typing import List, Tuple, Iterator, cast import dataclasses import sqlalchemy as sa from sqlalchemy.dialects.postgresql import ARRAY -from nominatim.typing import SaFromClause, SaColumn +from nominatim.typing import SaFromClause, SaColumn, SaExpression from nominatim.api.search.query import Token @dataclasses.dataclass @@ -27,6 +27,21 @@ class WeightedStrings: return bool(self.values) + def __iter__(self) -> Iterator[Tuple[str, float]]: + return iter(zip(self.values, self.penalties)) + + + def get_penalty(self, value: str, default: float = 1000.0) -> float: + """ Get the penalty for the given value. Returns the given default + if the value does not exist. + """ + try: + return self.penalties[self.values.index(value)] + except ValueError: + pass + return default + + @dataclasses.dataclass class WeightedCategories: """ A list of class/type tuples together with a penalty. @@ -38,6 +53,36 @@ class WeightedCategories: return bool(self.values) + def __iter__(self) -> Iterator[Tuple[Tuple[str, str], float]]: + return iter(zip(self.values, self.penalties)) + + + def get_penalty(self, value: Tuple[str, str], default: float = 1000.0) -> float: + """ Get the penalty for the given value. 
Returns the given default + if the value does not exist. + """ + try: + return self.penalties[self.values.index(value)] + except ValueError: + pass + return default + + + def sql_restrict(self, table: SaFromClause) -> SaExpression: + """ Return an SQLAlcheny expression that restricts the + class and type columns of the given table to the values + in the list. + Must not be used with an empty list. + """ + assert self.values + if len(self.values) == 1: + return sa.and_(table.c.class_ == self.values[0][0], + table.c.type == self.values[0][1]) + + return sa.or_(*(sa.and_(table.c.class_ == c, table.c.type == t) + for c, t in self.values)) + + @dataclasses.dataclass(order=True) class RankedTokens: """ List of tokens together with the penalty of using it. diff --git a/nominatim/api/search/db_searches.py b/nominatim/api/search/db_searches.py index f0d75ad1..9a94b4f6 100644 --- a/nominatim/api/search/db_searches.py +++ b/nominatim/api/search/db_searches.py @@ -7,13 +7,181 @@ """ Implementation of the acutal database accesses for forward search. """ +from typing import List, Tuple, AsyncIterator import abc +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import ARRAY, array_agg + +from nominatim.typing import SaFromClause, SaScalarSelect, SaColumn, \ + SaExpression, SaSelect, SaRow from nominatim.api.connection import SearchConnection -from nominatim.api.types import SearchDetails +from nominatim.api.types import SearchDetails, DataLayer, GeometryFormat, Bbox import nominatim.api.results as nres from nominatim.api.search.db_search_fields import SearchData, WeightedCategories +#pylint: disable=singleton-comparison +#pylint: disable=too-many-branches,too-many-arguments,too-many-locals,too-many-statements + +def _select_placex(t: SaFromClause) -> SaSelect: + return sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, + t.c.class_, t.c.type, + t.c.address, t.c.extratags, + t.c.housenumber, t.c.postcode, t.c.country_code, + t.c.importance, t.c.wikipedia, + t.c.parent_place_id, t.c.rank_address, t.c.rank_search, + t.c.centroid, + t.c.geometry.ST_Expand(0).label('bbox')) + + +def _add_geometry_columns(sql: SaSelect, col: SaColumn, details: SearchDetails) -> SaSelect: + if not details.geometry_output: + return sql + + out = [] + + if details.geometry_simplification > 0.0: + col = col.ST_SimplifyPreserveTopology(details.geometry_simplification) + + if details.geometry_output & GeometryFormat.GEOJSON: + out.append(col.ST_AsGeoJSON().label('geometry_geojson')) + if details.geometry_output & GeometryFormat.TEXT: + out.append(col.ST_AsText().label('geometry_text')) + if details.geometry_output & GeometryFormat.KML: + out.append(col.ST_AsKML().label('geometry_kml')) + if details.geometry_output & GeometryFormat.SVG: + out.append(col.ST_AsSVG().label('geometry_svg')) + + return sql.add_columns(*out) + + +def _make_interpolation_subquery(table: SaFromClause, inner: SaFromClause, + numerals: List[int], details: SearchDetails) -> SaScalarSelect: + all_ids = array_agg(table.c.place_id) # type: ignore[no-untyped-call] + sql = sa.select(all_ids).where(table.c.parent_place_id == inner.c.place_id) + + if len(numerals) == 1: + sql = sql.where(sa.between(numerals[0], table.c.startnumber, table.c.endnumber))\ + .where((numerals[0] - table.c.startnumber) % table.c.step == 0) + else: + sql = sql.where(sa.or_( + *(sa.and_(sa.between(n, table.c.startnumber, table.c.endnumber), + (n - table.c.startnumber) % table.c.step == 0) + for n in numerals))) + + if details.excluded: + sql = 
sql.where(table.c.place_id.not_in(details.excluded)) + + return sql.scalar_subquery() + + +def _filter_by_layer(table: SaFromClause, layers: DataLayer) -> SaColumn: + orexpr: List[SaExpression] = [] + if layers & DataLayer.ADDRESS and layers & DataLayer.POI: + orexpr.append(table.c.rank_address.between(1, 30)) + elif layers & DataLayer.ADDRESS: + orexpr.append(table.c.rank_address.between(1, 29)) + orexpr.append(sa.and_(table.c.rank_address == 30, + sa.or_(table.c.housenumber != None, + table.c.address.has_key('housename')))) + elif layers & DataLayer.POI: + orexpr.append(sa.and_(table.c.rank_address == 30, + table.c.class_.not_in(('place', 'building')))) + + if layers & DataLayer.MANMADE: + exclude = [] + if not layers & DataLayer.RAILWAY: + exclude.append('railway') + if not layers & DataLayer.NATURAL: + exclude.extend(('natural', 'water', 'waterway')) + orexpr.append(sa.and_(table.c.class_.not_in(tuple(exclude)), + table.c.rank_address == 0)) + else: + include = [] + if layers & DataLayer.RAILWAY: + include.append('railway') + if layers & DataLayer.NATURAL: + include.extend(('natural', 'water', 'waterway')) + orexpr.append(sa.and_(table.c.class_.in_(tuple(include)), + table.c.rank_address == 0)) + + if len(orexpr) == 1: + return orexpr[0] + + return sa.or_(*orexpr) + + +def _interpolated_position(table: SaFromClause, nr: SaColumn) -> SaColumn: + pos = sa.cast(nr - table.c.startnumber, sa.Float) / (table.c.endnumber - table.c.startnumber) + return sa.case( + (table.c.endnumber == table.c.startnumber, table.c.linegeo.ST_Centroid()), + else_=table.c.linegeo.ST_LineInterpolatePoint(pos)).label('centroid') + + +async def _get_placex_housenumbers(conn: SearchConnection, + place_ids: List[int], + details: SearchDetails) -> AsyncIterator[nres.SearchResult]: + t = conn.t.placex + sql = _select_placex(t).where(t.c.place_id.in_(place_ids)) + + sql = _add_geometry_columns(sql, t.c.geometry, details) + + for row in await conn.execute(sql): + result = nres.create_from_placex_row(row, nres.SearchResult) + assert result + result.bbox = Bbox.from_wkb(row.bbox.data) + yield result + + +async def _get_osmline(conn: SearchConnection, place_ids: List[int], + numerals: List[int], + details: SearchDetails) -> AsyncIterator[nres.SearchResult]: + t = conn.t.osmline + values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\ + .data([(n,) for n in numerals]) + sql = sa.select(t.c.place_id, t.c.osm_id, + t.c.parent_place_id, t.c.address, + values.c.nr.label('housenumber'), + _interpolated_position(t, values.c.nr), + t.c.postcode, t.c.country_code)\ + .where(t.c.place_id.in_(place_ids))\ + .join(values, values.c.nr.between(t.c.startnumber, t.c.endnumber)) + + if details.geometry_output: + sub = sql.subquery() + sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details) + + for row in await conn.execute(sql): + result = nres.create_from_osmline_row(row, nres.SearchResult) + assert result + yield result + + +async def _get_tiger(conn: SearchConnection, place_ids: List[int], + numerals: List[int], osm_id: int, + details: SearchDetails) -> AsyncIterator[nres.SearchResult]: + t = conn.t.tiger + values = sa.values(sa.Column('nr', sa.Integer()), name='housenumber')\ + .data([(n,) for n in numerals]) + sql = sa.select(t.c.place_id, t.c.parent_place_id, + sa.literal('W').label('osm_type'), + sa.literal(osm_id).label('osm_id'), + values.c.nr.label('housenumber'), + _interpolated_position(t, values.c.nr), + t.c.postcode)\ + .where(t.c.place_id.in_(place_ids))\ + .join(values, 
values.c.nr.between(t.c.startnumber, t.c.endnumber)) + + if details.geometry_output: + sub = sql.subquery() + sql = _add_geometry_columns(sa.select(sub), sub.c.centroid, details) + + for row in await conn.execute(sql): + result = nres.create_from_tiger_row(row, nres.SearchResult) + assert result + yield result + + class AbstractSearch(abc.ABC): """ Encapuslation of a single lookup in the database. """ @@ -42,7 +210,79 @@ class NearSearch(AbstractSearch): details: SearchDetails) -> nres.SearchResults: """ Find results for the search in the database. """ - return nres.SearchResults([]) + results = nres.SearchResults() + base = await self.search.lookup(conn, details) + + if not base: + return results + + base.sort(key=lambda r: (r.accuracy, r.rank_search)) + max_accuracy = base[0].accuracy + 0.5 + base = nres.SearchResults(r for r in base if r.source_table == nres.SourceTable.PLACEX + and r.accuracy <= max_accuracy + and r.bbox and r.bbox.area < 20) + + if base: + baseids = [b.place_id for b in base[:5] if b.place_id] + + for category, penalty in self.categories: + await self.lookup_category(results, conn, baseids, category, penalty, details) + if len(results) >= details.max_results: + break + + return results + + + async def lookup_category(self, results: nres.SearchResults, + conn: SearchConnection, ids: List[int], + category: Tuple[str, str], penalty: float, + details: SearchDetails) -> None: + """ Find places of the given category near the list of + place ids and add the results to 'results'. + """ + table = await conn.get_class_table(*category) + + t = conn.t.placex.alias('p') + tgeom = conn.t.placex.alias('pgeom') + + sql = _select_placex(t).where(tgeom.c.place_id.in_(ids))\ + .where(t.c.class_ == category[0])\ + .where(t.c.type == category[1]) + + if table is None: + # No classtype table available, do a simplified lookup in placex. + sql = sql.join(tgeom, t.c.geometry.ST_DWithin(tgeom.c.centroid, 0.01))\ + .order_by(tgeom.c.centroid.ST_Distance(t.c.centroid)) + else: + # Use classtype table. We can afford to use a larger + # radius for the lookup. + sql = sql.join(table, t.c.place_id == table.c.place_id)\ + .join(tgeom, + sa.case((sa.and_(tgeom.c.rank_address < 9, + tgeom.c.geometry.ST_GeometryType().in_( + ('ST_Polygon', 'ST_MultiPolygon'))), + tgeom.c.geometry.ST_Contains(table.c.centroid)), + else_ = tgeom.c.centroid.ST_DWithin(table.c.centroid, 0.05)))\ + .order_by(tgeom.c.centroid.ST_Distance(table.c.centroid)) + + if details.countries: + sql = sql.where(t.c.country_code.in_(details.countries)) + if details.min_rank > 0: + sql = sql.where(t.c.rank_address >= details.min_rank) + if details.max_rank < 30: + sql = sql.where(t.c.rank_address <= details.max_rank) + if details.excluded: + sql = sql.where(t.c.place_id.not_in(details.excluded)) + if details.layers is not None: + sql = sql.where(_filter_by_layer(t, details.layers)) + + for row in await conn.execute(sql.limit(details.max_results)): + result = nres.create_from_placex_row(row, nres.SearchResult) + assert result + result.accuracy = self.penalty + penalty + result.bbox = Bbox.from_wkb(row.bbox.data) + results.append(result) + class PoiSearch(AbstractSearch): @@ -58,7 +298,65 @@ class PoiSearch(AbstractSearch): details: SearchDetails) -> nres.SearchResults: """ Find results for the search in the database. 
""" - return nres.SearchResults([]) + t = conn.t.placex + + rows: List[SaRow] = [] + + if details.near and details.near_radius is not None and details.near_radius < 0.2: + # simply search in placex table + sql = _select_placex(t) \ + .where(t.c.linked_place_id == None) \ + .where(t.c.geometry.ST_DWithin(details.near.sql_value(), + details.near_radius)) \ + .order_by(t.c.centroid.ST_Distance(details.near.sql_value())) + + if self.countries: + sql = sql.where(t.c.country_code.in_(self.countries.values)) + + if details.viewbox is not None and details.bounded_viewbox: + sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) + + classtype = self.categories.values + if len(classtype) == 1: + sql = sql.where(t.c.class_ == classtype[0][0]) \ + .where(t.c.type == classtype[0][1]) + else: + sql = sql.where(sa.or_(*(sa.and_(t.c.class_ == cls, t.c.type == typ) + for cls, typ in classtype))) + + rows.extend(await conn.execute(sql.limit(details.max_results))) + else: + # use the class type tables + for category in self.categories.values: + table = await conn.get_class_table(*category) + if table is not None: + sql = _select_placex(t)\ + .join(table, t.c.place_id == table.c.place_id)\ + .where(t.c.class_ == category[0])\ + .where(t.c.type == category[1]) + + if details.viewbox is not None and details.bounded_viewbox: + sql = sql.where(table.c.centroid.intersects(details.viewbox.sql_value())) + + if details.near: + sql = sql.order_by(table.c.centroid.ST_Distance(details.near.sql_value()))\ + .where(table.c.centroid.ST_DWithin(details.near.sql_value(), + details.near_radius or 0.5)) + + if self.countries: + sql = sql.where(t.c.country_code.in_(self.countries.values)) + + rows.extend(await conn.execute(sql.limit(details.max_results))) + + results = nres.SearchResults() + for row in rows: + result = nres.create_from_placex_row(row, nres.SearchResult) + assert result + result.accuracy = self.penalty + self.categories.get_penalty((row.class_, row.type)) + result.bbox = Bbox.from_wkb(row.bbox.data) + results.append(result) + + return results class CountrySearch(AbstractSearch): @@ -73,7 +371,72 @@ class CountrySearch(AbstractSearch): details: SearchDetails) -> nres.SearchResults: """ Find results for the search in the database. """ - return nres.SearchResults([]) + t = conn.t.placex + + sql = _select_placex(t)\ + .where(t.c.country_code.in_(self.countries.values))\ + .where(t.c.rank_address == 4) + + sql = _add_geometry_columns(sql, t.c.geometry, details) + + if details.excluded: + sql = sql.where(t.c.place_id.not_in(details.excluded)) + + if details.viewbox is not None and details.bounded_viewbox: + sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) + + if details.near is not None and details.near_radius is not None: + sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(), + details.near_radius)) + + results = nres.SearchResults() + for row in await conn.execute(sql): + result = nres.create_from_placex_row(row, nres.SearchResult) + assert result + result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0) + results.append(result) + + return results or await self.lookup_in_country_table(conn, details) + + + async def lookup_in_country_table(self, conn: SearchConnection, + details: SearchDetails) -> nres.SearchResults: + """ Look up the country in the fallback country tables. 
+ """ + t = conn.t.country_name + tgrid = conn.t.country_grid + + sql = sa.select(tgrid.c.country_code, + tgrid.c.geometry.ST_Centroid().ST_Collect().ST_Centroid() + .label('centroid'))\ + .where(tgrid.c.country_code.in_(self.countries.values))\ + .group_by(tgrid.c.country_code) + + if details.viewbox is not None and details.bounded_viewbox: + sql = sql.where(tgrid.c.geometry.intersects(details.viewbox.sql_value())) + if details.near is not None and details.near_radius is not None: + sql = sql.where(tgrid.c.geometry.ST_DWithin(details.near.sql_value(), + details.near_radius)) + + sub = sql.subquery('grid') + + sql = sa.select(t.c.country_code, + (t.c.name + + sa.func.coalesce(t.c.derived_name, + sa.cast('', type_=conn.t.types.Composite)) + ).label('name'), + sub.c.centroid)\ + .join(sub, t.c.country_code == sub.c.country_code) + + results = nres.SearchResults() + for row in await conn.execute(sql): + result = nres.create_from_country_row(row, nres.SearchResult) + assert result + result.accuracy = self.penalty + self.countries.get_penalty(row.country_code, 5.0) + results.append(result) + + return results + class PostcodeSearch(AbstractSearch): @@ -91,7 +454,66 @@ class PostcodeSearch(AbstractSearch): details: SearchDetails) -> nres.SearchResults: """ Find results for the search in the database. """ - return nres.SearchResults([]) + t = conn.t.postcode + + sql = sa.select(t.c.place_id, t.c.parent_place_id, + t.c.rank_search, t.c.rank_address, + t.c.postcode, t.c.country_code, + t.c.geometry.label('centroid'))\ + .where(t.c.postcode.in_(self.postcodes.values)) + + sql = _add_geometry_columns(sql, t.c.geometry, details) + + penalty: SaExpression = sa.literal(self.penalty) + + if details.viewbox is not None: + if details.bounded_viewbox: + sql = sql.where(t.c.geometry.intersects(details.viewbox.sql_value())) + else: + penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0), + (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0), + else_=2.0) + + if details.near is not None: + if details.near_radius is not None: + sql = sql.where(t.c.geometry.ST_DWithin(details.near.sql_value(), + details.near_radius)) + sql = sql.order_by(t.c.geometry.ST_Distance(details.near.sql_value())) + + if self.countries: + sql = sql.where(t.c.country_code.in_(self.countries.values)) + + if details.excluded: + sql = sql.where(t.c.place_id.not_in(details.excluded)) + + if self.lookups: + assert len(self.lookups) == 1 + assert self.lookups[0].lookup_type == 'restrict' + tsearch = conn.t.search_name + sql = sql.where(tsearch.c.place_id == t.c.parent_place_id)\ + .where(sa.func.array_cat(tsearch.c.name_vector, + tsearch.c.nameaddress_vector, + type_=ARRAY(sa.Integer)) + .contains(self.lookups[0].tokens)) + + for ranking in self.rankings: + penalty += ranking.sql_penalty(conn.t.search_name) + penalty += sa.case(*((t.c.postcode == v, p) for v, p in self.postcodes), + else_=1.0) + + + sql = sql.add_columns(penalty.label('accuracy')) + sql = sql.order_by('accuracy') + + results = nres.SearchResults() + for row in await conn.execute(sql.limit(details.max_results)): + result = nres.create_from_postcode_row(row, nres.SearchResult) + assert result + result.accuracy = row.accuracy + results.append(result) + + return results + class PlaceSearch(AbstractSearch): @@ -112,4 +534,168 @@ class PlaceSearch(AbstractSearch): details: SearchDetails) -> nres.SearchResults: """ Find results for the search in the database. 
""" - return nres.SearchResults([]) + t = conn.t.placex.alias('p') + tsearch = conn.t.search_name.alias('s') + limit = details.max_results + + sql = sa.select(t.c.place_id, t.c.osm_type, t.c.osm_id, t.c.name, + t.c.class_, t.c.type, + t.c.address, t.c.extratags, + t.c.housenumber, t.c.postcode, t.c.country_code, + t.c.wikipedia, + t.c.parent_place_id, t.c.rank_address, t.c.rank_search, + t.c.centroid, + t.c.geometry.ST_Expand(0).label('bbox'))\ + .where(t.c.place_id == tsearch.c.place_id) + + + sql = _add_geometry_columns(sql, t.c.geometry, details) + + penalty: SaExpression = sa.literal(self.penalty) + for ranking in self.rankings: + penalty += ranking.sql_penalty(tsearch) + + for lookup in self.lookups: + sql = sql.where(lookup.sql_condition(tsearch)) + + if self.countries: + sql = sql.where(tsearch.c.country_code.in_(self.countries.values)) + + if self.postcodes: + tpc = conn.t.postcode + if self.expected_count > 1000: + # Many results expected. Restrict by postcode. + sql = sql.where(sa.select(tpc.c.postcode) + .where(tpc.c.postcode.in_(self.postcodes.values)) + .where(tsearch.c.centroid.ST_DWithin(tpc.c.geometry, 0.12)) + .exists()) + + # Less results, only have a preference for close postcodes + pc_near = sa.select(sa.func.min(tpc.c.geometry.ST_Distance(tsearch.c.centroid)))\ + .where(tpc.c.postcode.in_(self.postcodes.values))\ + .scalar_subquery() + penalty += sa.case((t.c.postcode.in_(self.postcodes.values), 0.0), + else_=sa.func.coalesce(pc_near, 2.0)) + + if details.viewbox is not None: + if details.bounded_viewbox: + sql = sql.where(tsearch.c.centroid.intersects(details.viewbox.sql_value())) + else: + penalty += sa.case((t.c.geometry.intersects(details.viewbox.sql_value()), 0.0), + (t.c.geometry.intersects(details.viewbox_x2.sql_value()), 1.0), + else_=2.0) + + if details.near is not None: + if details.near_radius is not None: + sql = sql.where(tsearch.c.centroid.ST_DWithin(details.near.sql_value(), + details.near_radius)) + sql = sql.add_columns(-tsearch.c.centroid.ST_Distance(details.near.sql_value()) + .label('importance')) + sql = sql.order_by(sa.desc(sa.text('importance'))) + else: + sql = sql.order_by(penalty - sa.case((tsearch.c.importance > 0, tsearch.c.importance), + else_=0.75001-(sa.cast(tsearch.c.search_rank, sa.Float())/40))) + sql = sql.add_columns(t.c.importance) + + + sql = sql.add_columns(penalty.label('accuracy'))\ + .order_by(sa.text('accuracy')) + + if self.housenumbers: + hnr_regexp = f"\\m({'|'.join(self.housenumbers.values)})\\M" + sql = sql.where(tsearch.c.address_rank.between(16, 30))\ + .where(sa.or_(tsearch.c.address_rank < 30, + t.c.housenumber.regexp_match(hnr_regexp, flags='i'))) + + # Cross check for housenumbers, need to do that on a rather large + # set. Worst case there are 40.000 main streets in OSM. 
+ inner = sql.limit(10000).subquery() + + # Housenumbers from placex + thnr = conn.t.placex.alias('hnr') + pid_list = array_agg(thnr.c.place_id) # type: ignore[no-untyped-call] + place_sql = sa.select(pid_list)\ + .where(thnr.c.parent_place_id == inner.c.place_id)\ + .where(thnr.c.housenumber.regexp_match(hnr_regexp, flags='i'))\ + .where(thnr.c.linked_place_id == None)\ + .where(thnr.c.indexed_status == 0) + + if details.excluded: + place_sql = place_sql.where(thnr.c.place_id.not_in(details.excluded)) + if self.qualifiers: + place_sql = place_sql.where(self.qualifiers.sql_restrict(thnr)) + + numerals = [int(n) for n in self.housenumbers.values if n.isdigit()] + interpol_sql: SaExpression + tiger_sql: SaExpression + if numerals and \ + (not self.qualifiers or ('place', 'house') in self.qualifiers.values): + # Housenumbers from interpolations + interpol_sql = _make_interpolation_subquery(conn.t.osmline, inner, + numerals, details) + # Housenumbers from Tiger + tiger_sql = sa.case((inner.c.country_code == 'us', + _make_interpolation_subquery(conn.t.tiger, inner, + numerals, details) + ), else_=None) + else: + interpol_sql = sa.literal(None) + tiger_sql = sa.literal(None) + + unsort = sa.select(inner, place_sql.scalar_subquery().label('placex_hnr'), + interpol_sql.label('interpol_hnr'), + tiger_sql.label('tiger_hnr')).subquery('unsort') + sql = sa.select(unsort)\ + .order_by(sa.case((unsort.c.placex_hnr != None, 1), + (unsort.c.interpol_hnr != None, 2), + (unsort.c.tiger_hnr != None, 3), + else_=4), + unsort.c.accuracy) + else: + sql = sql.where(t.c.linked_place_id == None)\ + .where(t.c.indexed_status == 0) + if self.qualifiers: + sql = sql.where(self.qualifiers.sql_restrict(t)) + if details.excluded: + sql = sql.where(tsearch.c.place_id.not_in(details.excluded)) + if details.min_rank > 0: + sql = sql.where(sa.or_(tsearch.c.address_rank >= details.min_rank, + tsearch.c.search_rank >= details.min_rank)) + if details.max_rank < 30: + sql = sql.where(sa.or_(tsearch.c.address_rank <= details.max_rank, + tsearch.c.search_rank <= details.max_rank)) + if details.layers is not None: + sql = sql.where(_filter_by_layer(t, details.layers)) + + + results = nres.SearchResults() + for row in await conn.execute(sql.limit(limit)): + result = nres.create_from_placex_row(row, nres.SearchResult) + assert result + result.bbox = Bbox.from_wkb(row.bbox.data) + result.accuracy = row.accuracy + if not details.excluded or not result.place_id in details.excluded: + results.append(result) + + if self.housenumbers and row.rank_address < 30: + if row.placex_hnr: + subs = _get_placex_housenumbers(conn, row.placex_hnr, details) + elif row.interpol_hnr: + subs = _get_osmline(conn, row.interpol_hnr, numerals, details) + elif row.tiger_hnr: + subs = _get_tiger(conn, row.tiger_hnr, numerals, row.osm_id, details) + else: + subs = None + + if subs is not None: + async for sub in subs: + assert sub.housenumber + sub.accuracy = result.accuracy + if not any(nr in self.housenumbers.values + for nr in sub.housenumber.split(';')): + sub.accuracy += 0.6 + results.append(sub) + + result.accuracy += 1.0 # penalty for missing housenumber + + return results diff --git a/nominatim/api/types.py b/nominatim/api/types.py index ff7457ec..9042e707 100644 --- a/nominatim/api/types.py +++ b/nominatim/api/types.py @@ -15,6 +15,9 @@ import enum import math from struct import unpack +from geoalchemy2 import WKTElement +import geoalchemy2.functions + from nominatim.errors import UsageError # pylint: 
disable=no-member,too-many-boolean-expressions,too-many-instance-attributes @@ -119,6 +122,12 @@ class Point(NamedTuple): return Point(x, y) + def sql_value(self) -> WKTElement: + """ Create an SQL expression for the point. + """ + return WKTElement(f'POINT({self.x} {self.y})', srid=4326) + + AnyPoint = Union[Point, Tuple[float, float]] @@ -163,12 +172,26 @@ class Bbox: return self.coords[2] + @property + def area(self) -> float: + """ Return the area of the box in WGS84. + """ + return (self.coords[2] - self.coords[0]) * (self.coords[3] - self.coords[1]) + + + def sql_value(self) -> Any: + """ Create an SQL expression for the box. + """ + return geoalchemy2.functions.ST_MakeEnvelope(*self.coords, 4326) + + def contains(self, pt: Point) -> bool: """ Check if the point is inside or on the boundary of the box. """ return self.coords[0] <= pt[0] and self.coords[1] <= pt[1]\ and self.coords[2] >= pt[0] and self.coords[3] >= pt[1] + @staticmethod def from_wkb(wkb: Optional[bytes]) -> 'Optional[Bbox]': """ Create a Bbox from a bounding box polygon as returned by @@ -418,7 +441,7 @@ class SearchDetails(LookupDetails): if self.viewbox is not None: xext = (self.viewbox.maxlon - self.viewbox.minlon)/2 yext = (self.viewbox.maxlat - self.viewbox.minlat)/2 - self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.maxlon - yext, + self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.minlat - yext, self.viewbox.maxlon + xext, self.viewbox.maxlat + yext) diff --git a/nominatim/typing.py b/nominatim/typing.py index bc4c5534..d988fe04 100644 --- a/nominatim/typing.py +++ b/nominatim/typing.py @@ -63,8 +63,10 @@ else: TypeAlias = str SaSelect: TypeAlias = 'sa.Select[Any]' +SaScalarSelect: TypeAlias = 'sa.ScalarSelect[Any]' SaRow: TypeAlias = 'sa.Row[Any]' SaColumn: TypeAlias = 'sa.ColumnElement[Any]' +SaExpression: TypeAlias = 'sa.ColumnElement[bool]' SaLabel: TypeAlias = 'sa.Label[Any]' SaFromClause: TypeAlias = 'sa.FromClause' SaSelectable: TypeAlias = 'sa.Selectable' diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index d8a6dfa0..cfe14e1e 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -12,6 +12,8 @@ import pytest import time import datetime as dt +import sqlalchemy as sa + import nominatim.api as napi from nominatim.db.sql_preprocessor import SQLPreprocessor import nominatim.api.logging as loglib @@ -129,6 +131,34 @@ class APITester: 'geometry': 'SRID=4326;' + geometry}) + def add_country_name(self, country_code, names, partition=0): + self.add_data('country_name', + {'country_code': country_code, + 'name': names, + 'partition': partition}) + + + def add_search_name(self, place_id, **kw): + centroid = kw.get('centroid', (23.0, 34.0)) + self.add_data('search_name', + {'place_id': place_id, + 'importance': kw.get('importance', 0.00001), + 'search_rank': kw.get('search_rank', 30), + 'address_rank': kw.get('address_rank', 30), + 'name_vector': kw.get('names', []), + 'nameaddress_vector': kw.get('address', []), + 'country_code': kw.get('country_code', 'xx'), + 'centroid': 'SRID=4326;POINT(%f %f)' % centroid}) + + + def add_class_type_table(self, cls, typ): + self.async_to_sync( + self.exec_async(sa.text(f"""CREATE TABLE place_classtype_{cls}_{typ} + AS (SELECT place_id, centroid FROM placex + WHERE class = '{cls}' AND type = '{typ}') + """))) + + async def exec_async(self, sql, *args, **kwargs): async with self.api._async_api.begin() as conn: return await conn.execute(sql, *args, **kwargs) diff --git 
a/test/python/api/search/test_search_country.py b/test/python/api/search/test_search_country.py new file mode 100644 index 00000000..bb0abc39 --- /dev/null +++ b/test/python/api/search/test_search_country.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the country searcher. +""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import CountrySearch +from nominatim.api.search.db_search_fields import WeightedStrings + + +def run_search(apiobj, global_penalty, ccodes, + country_penalties=None, details=SearchDetails()): + if country_penalties is None: + country_penalties = [0.0] * len(ccodes) + + class MySearchData: + penalty = global_penalty + countries = WeightedStrings(ccodes, country_penalties) + + search = CountrySearch(MySearchData()) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + return apiobj.async_to_sync(run()) + + +def test_find_from_placex(apiobj): + apiobj.add_placex(place_id=55, class_='boundary', type='administrative', + rank_search=4, rank_address=4, + name={'name': 'Lolaland'}, + country_code='yw', + centroid=(10, 10), + geometry='POLYGON((9.5 9.5, 9.5 10.5, 10.5 10.5, 10.5 9.5, 9.5 9.5))') + + results = run_search(apiobj, 0.5, ['de', 'yw'], [0.0, 0.3]) + + assert len(results) == 1 + assert results[0].place_id == 55 + assert results[0].accuracy == 0.8 + +def test_find_from_fallback_countries(apiobj): + apiobj.add_country('ro', 'POLYGON((0 0, 0 1, 1 1, 1 0, 0 0))') + apiobj.add_country_name('ro', {'name': 'România'}) + + results = run_search(apiobj, 0.0, ['ro']) + + assert len(results) == 1 + assert results[0].names == {'name': 'România'} + + +def test_find_none(apiobj): + assert len(run_search(apiobj, 0.0, ['xx'])) == 0 diff --git a/test/python/api/search/test_search_near.py b/test/python/api/search/test_search_near.py new file mode 100644 index 00000000..cfbdadb2 --- /dev/null +++ b/test/python/api/search/test_search_near.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the near searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import NearSearch, PlaceSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ + FieldLookup, FieldRanking, RankedTokens + + +def run_search(apiobj, global_penalty, cat, cat_penalty=None, + details=SearchDetails()): + + class PlaceSearchData: + penalty = 0.0 + postcodes = WeightedStrings([], []) + countries = WeightedStrings([], []) + housenumbers = WeightedStrings([], []) + qualifiers = WeightedStrings([], []) + lookups = [FieldLookup('name_vector', [56], 'lookup_all')] + rankings = [] + + place_search = PlaceSearch(0.0, PlaceSearchData(), 2) + + if cat_penalty is None: + cat_penalty = [0.0] * len(cat) + + near_search = NearSearch(0.1, WeightedCategories(cat, cat_penalty), place_search) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await near_search.lookup(conn, details) + + results = apiobj.async_to_sync(run()) + results.sort(key=lambda r: r.accuracy) + + return results + + +def test_no_results_inner_query(apiobj): + assert not run_search(apiobj, 0.4, [('this', 'that')]) + + +class TestNearSearch: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=100, country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_search_name(100, names=[56], country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_placex(place_id=101, country_code='mx', + centroid=(-10.3, 56.9)) + apiobj.add_search_name(101, names=[56], country_code='mx', + centroid=(-10.3, 56.9)) + + + def test_near_in_placex(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + centroid=(5.6001, 4.2994)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + centroid=(5.6001, 4.2994)) + + results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + + assert [r.place_id for r in results] == [22] + + + def test_multiple_types_near_in_placex(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + importance=0.002, + centroid=(5.6001, 4.2994)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + importance=0.001, + centroid=(5.6001, 4.2994)) + + results = run_search(apiobj, 0.1, [('amenity', 'bank'), + ('amenity', 'bench')]) + + assert [r.place_id for r in results] == [22, 23] + + + def test_near_in_classtype(self, apiobj): + apiobj.add_placex(place_id=22, class_='amenity', type='bank', + centroid=(5.6, 4.34)) + apiobj.add_placex(place_id=23, class_='amenity', type='bench', + centroid=(5.6, 4.34)) + apiobj.add_class_type_table('amenity', 'bank') + apiobj.add_class_type_table('amenity', 'bench') + + results = run_search(apiobj, 0.1, [('amenity', 'bank')]) + + assert [r.place_id for r in results] == [22] + diff --git a/test/python/api/search/test_search_places.py b/test/python/api/search/test_search_places.py new file mode 100644 index 00000000..df369b81 --- /dev/null +++ b/test/python/api/search/test_search_places.py @@ -0,0 +1,385 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the generic place searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import PlaceSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories,\ + FieldLookup, FieldRanking, RankedTokens + +def run_search(apiobj, global_penalty, lookup, ranking, count=2, + hnrs=[], pcs=[], ccodes=[], quals=[], + details=SearchDetails()): + class MySearchData: + penalty = global_penalty + postcodes = WeightedStrings(pcs, [0.0] * len(pcs)) + countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) + housenumbers = WeightedStrings(hnrs, [0.0] * len(hnrs)) + qualifiers = WeightedCategories(quals, [0.0] * len(quals)) + lookups = lookup + rankings = ranking + + search = PlaceSearch(0.0, MySearchData(), count) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + results = apiobj.async_to_sync(run()) + results.sort(key=lambda r: r.accuracy) + + return results + + +class TestNameOnlySearches: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=100, country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_search_name(100, names=[1,2,10,11], country_code='us', + centroid=(5.6, 4.3)) + apiobj.add_placex(place_id=101, country_code='mx', + centroid=(-10.3, 56.9)) + apiobj.add_search_name(101, names=[1,2,20,21], country_code='mx', + centroid=(-10.3, 56.9)) + + + @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + @pytest.mark.parametrize('rank,res', [([10], [100, 101]), + ([20], [101, 100])]) + def test_lookup_all_match(self, apiobj, lookup_type, rank, res): + lookup = FieldLookup('name_vector', [1,2], lookup_type) + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert [r.place_id for r in results] == res + + + @pytest.mark.parametrize('lookup_type', ['lookup_all', 'restrict']) + def test_lookup_all_partial_match(self, apiobj, lookup_type): + lookup = FieldLookup('name_vector', [1,20], lookup_type) + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert len(results) == 1 + assert results[0].place_id == 101 + + @pytest.mark.parametrize('rank,res', [([10], [100, 101]), + ([20], [101, 100])]) + def test_lookup_any_match(self, apiobj, rank, res): + lookup = FieldLookup('name_vector', [11,21], 'lookup_any') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, rank)]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert [r.place_id for r in results] == res + + + def test_lookup_any_partial_match(self, apiobj): + lookup = FieldLookup('name_vector', [20], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + + assert len(results) == 1 + assert results[0].place_id == 101 + + + @pytest.mark.parametrize('cc,res', [('us', 100), ('mx', 101)]) + def test_lookup_restrict_country(self, apiobj, cc, res): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], ccodes=[cc]) + + assert [r.place_id for r in results] == [res] + + + def test_lookup_restrict_placeid(self, apiobj): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, 
[RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], + details=SearchDetails(excluded=[101])) + + assert [r.place_id for r in results] == [100] + + + @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON, + napi.GeometryFormat.KML, + napi.GeometryFormat.SVG, + napi.GeometryFormat.TEXT]) + def test_return_geometries(self, apiobj, geom): + lookup = FieldLookup('name_vector', [20], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], + details=SearchDetails(geometry_output=geom)) + + assert geom.name.lower() in results[0].geometry + + + @pytest.mark.parametrize('viewbox', ['5.0,4.0,6.0,5.0', '5.7,4.0,6.0,5.0']) + def test_prefer_viewbox(self, apiobj, viewbox): + lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + assert [r.place_id for r in results] == [101, 100] + + results = run_search(apiobj, 0.1, [lookup], [ranking], + details=SearchDetails.from_kwargs({'viewbox': viewbox})) + assert [r.place_id for r in results] == [100, 101] + + + def test_force_viewbox(self, apiobj): + lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + + details=SearchDetails.from_kwargs({'viewbox': '5.0,4.0,6.0,5.0', + 'bounded_viewbox': True}) + + results = run_search(apiobj, 0.1, [lookup], [], details=details) + assert [r.place_id for r in results] == [100] + + + def test_prefer_near(self, apiobj): + lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.9, [RankedTokens(0.0, [21])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking]) + assert [r.place_id for r in results] == [101, 100] + + results = run_search(apiobj, 0.1, [lookup], [ranking], + details=SearchDetails.from_kwargs({'near': '5.6,4.3'})) + results.sort(key=lambda r: -r.importance) + assert [r.place_id for r in results] == [100, 101] + + + def test_force_near(self, apiobj): + lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + + details=SearchDetails.from_kwargs({'near': '5.6,4.3', + 'near_radius': 0.11}) + + results = run_search(apiobj, 0.1, [lookup], [], details=details) + + assert [r.place_id for r in results] == [100] + + +class TestStreetWithHousenumber: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=1, class_='place', type='house', + parent_place_id=1000, + housenumber='20 a', country_code='es') + apiobj.add_placex(place_id=2, class_='place', type='house', + parent_place_id=1000, + housenumber='21;22', country_code='es') + apiobj.add_placex(place_id=1000, class_='highway', type='residential', + rank_search=26, rank_address=26, + country_code='es') + apiobj.add_search_name(1000, names=[1,2,10,11], + search_rank=26, address_rank=26, + country_code='es') + apiobj.add_placex(place_id=91, class_='place', type='house', + parent_place_id=2000, + housenumber='20', country_code='pt') + apiobj.add_placex(place_id=92, class_='place', type='house', + parent_place_id=2000, + housenumber='22', country_code='pt') + apiobj.add_placex(place_id=93, class_='place', type='house', + parent_place_id=2000, + housenumber='24', country_code='pt') + apiobj.add_placex(place_id=2000, class_='highway', type='residential', + rank_search=26, rank_address=26, + country_code='pt') + apiobj.add_search_name(2000, names=[1,2,20,21], + search_rank=26, address_rank=26, + country_code='pt') + + + 
@pytest.mark.parametrize('hnr,res', [('20', [91, 1]), ('20 a', [1]), + ('21', [2]), ('22', [2, 92]), + ('24', [93]), ('25', [])]) + def test_lookup_by_single_housenumber(self, apiobj, hnr, res): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=[hnr]) + + assert [r.place_id for r in results] == res + [1000, 2000] + + + @pytest.mark.parametrize('cc,res', [('es', [2, 1000]), ('pt', [92, 2000])]) + def test_lookup_with_country_restriction(self, apiobj, cc, res): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + ccodes=[cc]) + + assert [r.place_id for r in results] == res + + + def test_lookup_exclude_housenumber_placeid(self, apiobj): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + details=SearchDetails(excluded=[92])) + + assert [r.place_id for r in results] == [2, 1000, 2000] + + + def test_lookup_exclude_street_placeid(self, apiobj): + lookup = FieldLookup('name_vector', [1,2], 'lookup_all') + ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, [lookup], [ranking], hnrs=['22'], + details=SearchDetails(excluded=[1000])) + + assert [r.place_id for r in results] == [2, 92, 2000] + + + @pytest.mark.parametrize('geom', [napi.GeometryFormat.GEOJSON, + napi.GeometryFormat.KML, + napi.GeometryFormat.SVG, + napi.GeometryFormat.TEXT]) + def test_return_geometries(self, apiobj, geom): + lookup = FieldLookup('name_vector', [1, 2], 'lookup_all') + + results = run_search(apiobj, 0.1, [lookup], [], hnrs=['20', '21', '22'], + details=SearchDetails(geometry_output=geom)) + + assert results + assert all(geom.name.lower() in r.geometry for r in results) + + +class TestInterpolations: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=990, class_='highway', type='service', + rank_search=27, rank_address=27, + centroid=(10.0, 10.0), + geometry='LINESTRING(9.995 10, 10.005 10)') + apiobj.add_search_name(990, names=[111], + search_rank=27, address_rank=27) + apiobj.add_placex(place_id=991, class_='place', type='house', + parent_place_id=990, + rank_search=30, rank_address=30, + housenumber='23', + centroid=(10.0, 10.00002)) + apiobj.add_osmline(place_id=992, + parent_place_id=990, + startnumber=21, endnumber=29, step=2, + centroid=(10.0, 10.00001), + geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)') + + + @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) + def test_lookup_housenumber(self, apiobj, hnr, res): + lookup = FieldLookup('name_vector', [111], 'lookup_all') + + results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) + + assert [r.place_id for r in results] == res + [990] + + +class TestTiger: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=990, class_='highway', type='service', + rank_search=27, rank_address=27, + country_code='us', + centroid=(10.0, 10.0), + geometry='LINESTRING(9.995 10, 10.005 10)') + apiobj.add_search_name(990, names=[111], country_code='us', + search_rank=27, address_rank=27) + apiobj.add_placex(place_id=991, class_='place', type='house', + 
parent_place_id=990, + rank_search=30, rank_address=30, + housenumber='23', + country_code='us', + centroid=(10.0, 10.00002)) + apiobj.add_tiger(place_id=992, + parent_place_id=990, + startnumber=21, endnumber=29, step=2, + centroid=(10.0, 10.00001), + geometry='LINESTRING(9.995 10.00001, 10.005 10.00001)') + + + @pytest.mark.parametrize('hnr,res', [('21', [992]), ('22', []), ('23', [991])]) + def test_lookup_housenumber(self, apiobj, hnr, res): + lookup = FieldLookup('name_vector', [111], 'lookup_all') + + results = run_search(apiobj, 0.1, [lookup], [], hnrs=[hnr]) + + assert [r.place_id for r in results] == res + [990] + + +class TestLayersRank30: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_placex(place_id=223, class_='place', type='house', + housenumber='1', + rank_address=30, + rank_search=30) + apiobj.add_search_name(223, names=[34], + importance=0.0009, + address_rank=30, search_rank=30) + apiobj.add_placex(place_id=224, class_='amenity', type='toilet', + rank_address=30, + rank_search=30) + apiobj.add_search_name(224, names=[34], + importance=0.0008, + address_rank=30, search_rank=30) + apiobj.add_placex(place_id=225, class_='man_made', type='tower', + rank_address=0, + rank_search=30) + apiobj.add_search_name(225, names=[34], + importance=0.0007, + address_rank=0, search_rank=30) + apiobj.add_placex(place_id=226, class_='railway', type='station', + rank_address=0, + rank_search=30) + apiobj.add_search_name(226, names=[34], + importance=0.0006, + address_rank=0, search_rank=30) + apiobj.add_placex(place_id=227, class_='natural', type='cave', + rank_address=0, + rank_search=30) + apiobj.add_search_name(227, names=[34], + importance=0.0005, + address_rank=0, search_rank=30) + + + @pytest.mark.parametrize('layer,res', [(napi.DataLayer.ADDRESS, [223]), + (napi.DataLayer.POI, [224]), + (napi.DataLayer.ADDRESS | napi.DataLayer.POI, [223, 224]), + (napi.DataLayer.MANMADE, [225]), + (napi.DataLayer.RAILWAY, [226]), + (napi.DataLayer.NATURAL, [227]), + (napi.DataLayer.MANMADE | napi.DataLayer.NATURAL, [225, 227]), + (napi.DataLayer.MANMADE | napi.DataLayer.RAILWAY, [225, 226])]) + def test_layers_rank30(self, apiobj, layer, res): + lookup = FieldLookup('name_vector', [34], 'lookup_any') + + results = run_search(apiobj, 0.1, [lookup], [], + details=SearchDetails(layers=layer)) + + assert [r.place_id for r in results] == res diff --git a/test/python/api/search/test_search_poi.py b/test/python/api/search/test_search_poi.py new file mode 100644 index 00000000..b80c0752 --- /dev/null +++ b/test/python/api/search/test_search_poi.py @@ -0,0 +1,108 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the POI searcher. 
+""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import PoiSearch +from nominatim.api.search.db_search_fields import WeightedStrings, WeightedCategories + + +def run_search(apiobj, global_penalty, poitypes, poi_penalties=None, + ccodes=[], details=SearchDetails()): + if poi_penalties is None: + poi_penalties = [0.0] * len(poitypes) + + class MySearchData: + penalty = global_penalty + qualifiers = WeightedCategories(poitypes, poi_penalties) + countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) + + search = PoiSearch(MySearchData()) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + return apiobj.async_to_sync(run()) + + +@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), + ('5.0, 4.59933', 1)]) +def test_simple_near_search_in_placex(apiobj, coord, pid): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + centroid=(5.0, 4.6)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + centroid=(34.3, 56.1)) + + details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.001}) + + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + + assert [r.place_id for r in results] == [pid] + + +@pytest.mark.parametrize('coord,pid', [('34.3, 56.100021', 2), + ('34.3, 56.4', 2), + ('5.0, 4.59933', 1)]) +def test_simple_near_search_in_classtype(apiobj, coord, pid): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + centroid=(5.0, 4.6)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + centroid=(34.3, 56.1)) + apiobj.add_class_type_table('highway', 'bus_stop') + + details = SearchDetails.from_kwargs({'near': coord, 'near_radius': 0.5}) + + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], details=details) + + assert [r.place_id for r in results] == [pid] + + +class TestPoiSearchWithRestrictions: + + @pytest.fixture(autouse=True, params=["placex", "classtype"]) + def fill_database(self, apiobj, request): + apiobj.add_placex(place_id=1, class_='highway', type='bus_stop', + country_code='au', + centroid=(34.3, 56.10003)) + apiobj.add_placex(place_id=2, class_='highway', type='bus_stop', + country_code='nz', + centroid=(34.3, 56.1)) + if request.param == 'classtype': + apiobj.add_class_type_table('highway', 'bus_stop') + self.args = {'near': '34.3, 56.4', 'near_radius': 0.5} + else: + self.args = {'near': '34.3, 56.100021', 'near_radius': 0.001} + + + def test_unrestricted(self, apiobj): + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + details=SearchDetails.from_kwargs(self.args)) + + assert [r.place_id for r in results] == [1, 2] + + + def test_restict_country(self, apiobj): + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + ccodes=['de', 'nz'], + details=SearchDetails.from_kwargs(self.args)) + + assert [r.place_id for r in results] == [2] + + + def test_restrict_by_viewbox(self, apiobj): + args = {'bounded_viewbox': True, 'viewbox': '34.299,56.0,34.3001,56.10001'} + args.update(self.args) + results = run_search(apiobj, 0.1, [('highway', 'bus_stop')], [0.5], + ccodes=['de', 'nz'], + details=SearchDetails.from_kwargs(args)) + + assert [r.place_id for r in results] == [2] diff --git a/test/python/api/search/test_search_postcode.py b/test/python/api/search/test_search_postcode.py new file mode 100644 index 00000000..a43bc897 --- /dev/null +++ 
b/test/python/api/search/test_search_postcode.py @@ -0,0 +1,97 @@ +# SPDX-License-Identifier: GPL-3.0-or-later +# +# This file is part of Nominatim. (https://nominatim.org) +# +# Copyright (C) 2023 by the Nominatim developer community. +# For a full list of authors see the git log. +""" +Tests for running the postcode searcher. +""" +import pytest + +import nominatim.api as napi +from nominatim.api.types import SearchDetails +from nominatim.api.search.db_searches import PostcodeSearch +from nominatim.api.search.db_search_fields import WeightedStrings, FieldLookup, \ + FieldRanking, RankedTokens + +def run_search(apiobj, global_penalty, pcs, pc_penalties=None, + ccodes=[], lookup=[], ranking=[], details=SearchDetails()): + if pc_penalties is None: + pc_penalties = [0.0] * len(pcs) + + class MySearchData: + penalty = global_penalty + postcodes = WeightedStrings(pcs, pc_penalties) + countries = WeightedStrings(ccodes, [0.0] * len(ccodes)) + lookups = lookup + rankings = ranking + + search = PostcodeSearch(0.0, MySearchData()) + + async def run(): + async with apiobj.api._async_api.begin() as conn: + return await search.lookup(conn, details) + + return apiobj.async_to_sync(run()) + + +def test_postcode_only_search(apiobj): + apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') + apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') + + results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1]) + + assert len(results) == 2 + assert [r.place_id for r in results] == [100, 101] + + +def test_postcode_with_country(apiobj): + apiobj.add_postcode(place_id=100, country_code='ch', postcode='12345') + apiobj.add_postcode(place_id=101, country_code='pl', postcode='12 345') + + results = run_search(apiobj, 0.3, ['12345', '12 345'], [0.0, 0.1], + ccodes=['de', 'pl']) + + assert len(results) == 1 + assert results[0].place_id == 101 + + +class TestPostcodeSearchWithAddress: + + @pytest.fixture(autouse=True) + def fill_database(self, apiobj): + apiobj.add_postcode(place_id=100, country_code='ch', + parent_place_id=1000, postcode='12345') + apiobj.add_postcode(place_id=101, country_code='pl', + parent_place_id=2000, postcode='12345') + apiobj.add_placex(place_id=1000, class_='place', type='village', + rank_search=22, rank_address=22, + country_code='ch') + apiobj.add_search_name(1000, names=[1,2,10,11], + search_rank=22, address_rank=22, + country_code='ch') + apiobj.add_placex(place_id=2000, class_='place', type='village', + rank_search=22, rank_address=22, + country_code='pl') + apiobj.add_search_name(2000, names=[1,2,20,21], + search_rank=22, address_rank=22, + country_code='pl') + + + def test_lookup_both(self, apiobj): + lookup = FieldLookup('name_vector', [1,2], 'restrict') + ranking = FieldRanking('name_vector', 0.3, [RankedTokens(0.0, [10])]) + + results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup], ranking=[ranking]) + + assert [r.place_id for r in results] == [100, 101] + + + def test_restrict_by_name(self, apiobj): + lookup = FieldLookup('name_vector', [10], 'restrict') + + results = run_search(apiobj, 0.1, ['12345'], lookup=[lookup]) + + assert [r.place_id for r in results] == [100] +
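
For orientation, the following usage sketch (not part of the commit above) shows how one of the new search classes might be driven outside the test harness. It mirrors the run_search helper in test_search_country.py: the CountryOnlySearchData stand-in, the project directory, and the use of NominatimAPIAsync with its begin()/close() methods are assumptions based on the surrounding test fixtures and the library's async API, not something this patch introduces, so treat it as illustrative only.

import asyncio
from pathlib import Path

import nominatim.api as napi
import nominatim.api.results as nres
from nominatim.api.types import SearchDetails
from nominatim.api.search.db_searches import CountrySearch
from nominatim.api.search.db_search_fields import WeightedStrings


class CountryOnlySearchData:
    """ Minimal stand-in for a full SearchData object; CountrySearch only
        reads the 'penalty' and 'countries' fields (the same trick the
        MySearchData helpers in the tests above use).
    """
    penalty = 0.0
    countries = WeightedStrings(['de', 'ch'], [0.0, 0.1])


async def find_countries(project_dir: Path) -> nres.SearchResults:
    # Assumes project_dir points at an existing Nominatim project directory
    # with an imported database.
    api = napi.NominatimAPIAsync(project_dir)
    try:
        search = CountrySearch(CountryOnlySearchData())
        # begin() yields the SearchConnection that lookup() expects,
        # just like apiobj.api._async_api.begin() in the tests.
        async with api.begin() as conn:
            return await search.lookup(conn, SearchDetails())
    finally:
        await api.close()


if __name__ == '__main__':
    for result in asyncio.run(find_countries(Path('.'))):
        # place_id is None for fallback country-table results.
        print(result.place_id, result.accuracy, result.ranking)

The same pattern applies to PoiSearch, PostcodeSearch, PlaceSearch and NearSearch; only the fields the stand-in object needs to provide differ, as the run_search helpers in the individual test modules show.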