implement search builder

This commit is contained in:
Sarah Hoffmann 2023-05-23 11:20:34 +02:00
parent 3bf489cd7c
commit c42273a4db
7 changed files with 1208 additions and 3 deletions

View File

@ -0,0 +1,322 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Convertion from token assignment to an abstract DB search.
"""
from typing import Optional, List, Tuple, Iterator
import heapq
from nominatim.api.types import SearchDetails, DataLayer
from nominatim.api.search.query import QueryStruct, TokenType, TokenRange, BreakType
from nominatim.api.search.token_assignment import TokenAssignment
import nominatim.api.search.db_search_fields as dbf
import nominatim.api.search.db_searches as dbs
from nominatim.api.logging import log
class SearchBuilder:
""" Build the abstract search queries from token assignments.
"""
def __init__(self, query: QueryStruct, details: SearchDetails) -> None:
self.query = query
self.details = details
@property
def configured_for_country(self) -> bool:
""" Return true if the search details are configured to
allow countries in the result.
"""
return self.details.min_rank <= 4 and self.details.max_rank >= 4 \
and self.details.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_postcode(self) -> bool:
""" Return true if the search details are configured to
allow postcodes in the result.
"""
return self.details.min_rank <= 5 and self.details.max_rank >= 11\
and self.details.layer_enabled(DataLayer.ADDRESS)
@property
def configured_for_housenumbers(self) -> bool:
""" Return true if the search details are configured to
allow addresses in the result.
"""
return self.details.max_rank >= 30 \
and self.details.layer_enabled(DataLayer.ADDRESS)
def build(self, assignment: TokenAssignment) -> Iterator[dbs.AbstractSearch]:
""" Yield all possible abstract searches for the given token assignment.
"""
sdata = self.get_search_data(assignment)
if sdata is None:
return
categories = self.get_search_categories(assignment)
if assignment.name is None:
if categories and not sdata.postcodes:
sdata.qualifiers = categories
categories = None
builder = self.build_poi_search(sdata)
else:
builder = self.build_special_search(sdata, assignment.address,
bool(categories))
else:
builder = self.build_name_search(sdata, assignment.name, assignment.address,
bool(categories))
if categories:
penalty = min(categories.penalties)
categories.penalties = [p - penalty for p in categories.penalties]
for search in builder:
yield dbs.NearSearch(penalty, categories, search)
else:
yield from builder
def build_poi_search(self, sdata: dbf.SearchData) -> Iterator[dbs.AbstractSearch]:
""" Build abstract search query for a simple category search.
This kind of search requires an additional geographic constraint.
"""
if not sdata.housenumbers \
and ((self.details.viewbox and self.details.bounded_viewbox) or self.details.near):
yield dbs.PoiSearch(sdata)
def build_special_search(self, sdata: dbf.SearchData,
address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]:
""" Build abstract search queries for searches that do not involve
a named place.
"""
if sdata.qualifiers or sdata.housenumbers:
# No special searches over housenumbers or qualifiers supported.
return
if sdata.countries and not address and not sdata.postcodes \
and self.configured_for_country:
yield dbs.CountrySearch(sdata)
if sdata.postcodes and (is_category or self.configured_for_postcode):
if address:
sdata.lookups = [dbf.FieldLookup('nameaddress_vector',
[t.token for r in address
for t in self.query.get_partials_list(r)],
'restrict')]
yield dbs.PostcodeSearch(0.4, sdata)
def build_name_search(self, sdata: dbf.SearchData,
name: TokenRange, address: List[TokenRange],
is_category: bool) -> Iterator[dbs.AbstractSearch]:
""" Build abstract search queries for simple name or address searches.
"""
if is_category or not sdata.housenumbers or self.configured_for_housenumbers:
sdata.rankings.append(self.get_name_ranking(name))
name_penalty = sdata.rankings[-1].normalize_penalty()
for penalty, count, lookup in self.yield_lookups(name, address):
sdata.lookups = lookup
yield dbs.PlaceSearch(penalty + name_penalty, sdata, count)
def yield_lookups(self, name: TokenRange, address: List[TokenRange])\
-> Iterator[Tuple[float, int, List[dbf.FieldLookup]]]:
""" Yield all variants how the given name and address should best
be searched for. This takes into account how frequent the terms
are and tries to find a lookup that optimizes index use.
"""
penalty = 0.0 # extra penalty currently unused
name_partials = self.query.get_partials_list(name)
exp_name_count = min(t.count for t in name_partials)
addr_partials = []
for trange in address:
addr_partials.extend(self.query.get_partials_list(trange))
addr_tokens = [t.token for t in addr_partials]
partials_indexed = all(t.is_indexed for t in name_partials) \
and all(t.is_indexed for t in addr_partials)
if (len(name_partials) > 3 or exp_name_count < 1000) and partials_indexed:
# Lookup by name partials, use address partials to restrict results.
lookup = [dbf.FieldLookup('name_vector',
[t.token for t in name_partials], 'lookup_all')]
if addr_tokens:
lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
yield penalty, exp_name_count, lookup
return
exp_addr_count = min(t.count for t in addr_partials) if addr_partials else exp_name_count
if exp_addr_count < 1000 and partials_indexed:
# Lookup by address partials and restrict results through name terms.
yield penalty, exp_addr_count,\
[dbf.FieldLookup('name_vector', [t.token for t in name_partials], 'restrict'),
dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all')]
return
# Partial term to frequent. Try looking up by rare full names first.
name_fulls = self.query.get_tokens(name, TokenType.WORD)
rare_names = list(filter(lambda t: t.count < 1000, name_fulls))
# At this point drop unindexed partials from the address.
# This might yield wrong results, nothing we can do about that.
if not partials_indexed:
addr_tokens = [t.token for t in addr_partials if t.is_indexed]
log().var_dump('before', penalty)
penalty += 1.2 * sum(t.penalty for t in addr_partials if not t.is_indexed)
log().var_dump('after', penalty)
if rare_names:
# Any of the full names applies with all of the partials from the address
lookup = [dbf.FieldLookup('name_vector', [t.token for t in rare_names], 'lookup_any')]
if addr_tokens:
lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'restrict'))
yield penalty, sum(t.count for t in rare_names), lookup
# To catch remaining results, lookup by name and address
if all(t.is_indexed for t in name_partials):
lookup = [dbf.FieldLookup('name_vector',
[t.token for t in name_partials], 'lookup_all')]
else:
# we don't have the partials, try with the non-rare names
non_rare_names = [t.token for t in name_fulls if t.count >= 1000]
if not non_rare_names:
return
lookup = [dbf.FieldLookup('name_vector', non_rare_names, 'lookup_any')]
if addr_tokens:
lookup.append(dbf.FieldLookup('nameaddress_vector', addr_tokens, 'lookup_all'))
yield penalty + 0.1 * max(0, 5 - len(name_partials) - len(addr_tokens)),\
min(exp_name_count, exp_addr_count), lookup
def get_name_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
""" Create a ranking expression for a name term in the given range.
"""
name_fulls = self.query.get_tokens(trange, TokenType.WORD)
ranks = [dbf.RankedTokens(t.penalty, [t.token]) for t in name_fulls]
ranks.sort(key=lambda r: r.penalty)
# Fallback, sum of penalty for partials
name_partials = self.query.get_partials_list(trange)
default = sum(t.penalty for t in name_partials) + 0.2
return dbf.FieldRanking('name_vector', default, ranks)
def get_addr_ranking(self, trange: TokenRange) -> dbf.FieldRanking:
""" Create a list of ranking expressions for an address term
for the given ranges.
"""
todo: List[Tuple[int, int, dbf.RankedTokens]] = []
heapq.heappush(todo, (0, trange.start, dbf.RankedTokens(0.0, [])))
ranks: List[dbf.RankedTokens] = []
while todo: # pylint: disable=too-many-nested-blocks
neglen, pos, rank = heapq.heappop(todo)
for tlist in self.query.nodes[pos].starting:
if tlist.ttype in (TokenType.PARTIAL, TokenType.WORD):
if tlist.end < trange.end:
chgpenalty = PENALTY_WORDCHANGE[self.query.nodes[tlist.end].btype]
if tlist.ttype == TokenType.PARTIAL:
penalty = rank.penalty + chgpenalty \
+ max(t.penalty for t in tlist.tokens)
heapq.heappush(todo, (neglen - 1, tlist.end,
dbf.RankedTokens(penalty, rank.tokens)))
else:
for t in tlist.tokens:
heapq.heappush(todo, (neglen - 1, tlist.end,
rank.with_token(t, chgpenalty)))
elif tlist.end == trange.end:
if tlist.ttype == TokenType.PARTIAL:
ranks.append(dbf.RankedTokens(rank.penalty
+ max(t.penalty for t in tlist.tokens),
rank.tokens))
else:
ranks.extend(rank.with_token(t, 0.0) for t in tlist.tokens)
if len(ranks) >= 10:
# Too many variants, bail out and only add
# Worst-case Fallback: sum of penalty of partials
name_partials = self.query.get_partials_list(trange)
default = sum(t.penalty for t in name_partials) + 0.2
ranks.append(dbf.RankedTokens(rank.penalty + default, []))
# Bail out of outer loop
todo.clear()
break
ranks.sort(key=lambda r: len(r.tokens))
default = ranks[0].penalty + 0.3
del ranks[0]
ranks.sort(key=lambda r: r.penalty)
return dbf.FieldRanking('nameaddress_vector', default, ranks)
def get_search_data(self, assignment: TokenAssignment) -> Optional[dbf.SearchData]:
""" Collect the tokens for the non-name search fields in the
assignment.
"""
sdata = dbf.SearchData()
sdata.penalty = assignment.penalty
if assignment.country:
tokens = self.query.get_tokens(assignment.country, TokenType.COUNTRY)
if self.details.countries:
tokens = [t for t in tokens if t.lookup_word in self.details.countries]
if not tokens:
return None
sdata.set_strings('countries', tokens)
elif self.details.countries:
sdata.countries = dbf.WeightedStrings(self.details.countries,
[0.0] * len(self.details.countries))
if assignment.housenumber:
sdata.set_strings('housenumbers',
self.query.get_tokens(assignment.housenumber,
TokenType.HOUSENUMBER))
if assignment.postcode:
sdata.set_strings('postcodes',
self.query.get_tokens(assignment.postcode,
TokenType.POSTCODE))
if assignment.qualifier:
sdata.set_qualifiers(self.query.get_tokens(assignment.qualifier,
TokenType.QUALIFIER))
if assignment.address:
sdata.set_ranking([self.get_addr_ranking(r) for r in assignment.address])
else:
sdata.rankings = []
return sdata
def get_search_categories(self,
assignment: TokenAssignment) -> Optional[dbf.WeightedCategories]:
""" Collect tokens for category search or use the categories
requested per parameter.
Returns None if no category search is requested.
"""
if assignment.category:
tokens = [t for t in self.query.get_tokens(assignment.category,
TokenType.CATEGORY)
if not self.details.categories
or t.get_category() in self.details.categories]
return dbf.WeightedCategories([t.get_category() for t in tokens],
[t.penalty for t in tokens])
if self.details.categories:
return dbf.WeightedCategories(self.details.categories,
[0.0] * len(self.details.categories))
return None
PENALTY_WORDCHANGE = {
BreakType.START: 0.0,
BreakType.END: 0.0,
BreakType.PHRASE: 0.0,
BreakType.WORD: 0.1,
BreakType.PART: 0.2,
BreakType.TOKEN: 0.4
}

View File

@ -0,0 +1,167 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Data structures for more complex fields in abstract search descriptions.
"""
from typing import List, Tuple, cast
import dataclasses
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
from nominatim.typing import SaFromClause, SaColumn
from nominatim.api.search.query import Token
@dataclasses.dataclass
class WeightedStrings:
""" A list of strings together with a penalty.
"""
values: List[str]
penalties: List[float]
def __bool__(self) -> bool:
return bool(self.values)
@dataclasses.dataclass
class WeightedCategories:
""" A list of class/type tuples together with a penalty.
"""
values: List[Tuple[str, str]]
penalties: List[float]
def __bool__(self) -> bool:
return bool(self.values)
@dataclasses.dataclass(order=True)
class RankedTokens:
""" List of tokens together with the penalty of using it.
"""
penalty: float
tokens: List[int]
def with_token(self, t: Token, transition_penalty: float) -> 'RankedTokens':
""" Create a new RankedTokens list with the given token appended.
The tokens penalty as well as the given transision penalty
are added to the overall penalty.
"""
return RankedTokens(self.penalty + t.penalty + transition_penalty,
self.tokens + [t.token])
@dataclasses.dataclass
class FieldRanking:
""" A list of rankings to be applied sequentially until one matches.
The matched ranking determines the penalty. If none matches a
default penalty is applied.
"""
column: str
default: float
rankings: List[RankedTokens]
def normalize_penalty(self) -> float:
""" Reduce the default and ranking penalties, such that the minimum
penalty is 0. Return the penalty that was subtracted.
"""
if self.rankings:
min_penalty = min(self.default, min(r.penalty for r in self.rankings))
else:
min_penalty = self.default
if min_penalty > 0.0:
self.default -= min_penalty
for ranking in self.rankings:
ranking.penalty -= min_penalty
return min_penalty
def sql_penalty(self, table: SaFromClause) -> SaColumn:
""" Create an SQL expression for the rankings.
"""
assert self.rankings
col = table.c[self.column]
return sa.case(*((col.contains(r.tokens),r.penalty) for r in self.rankings),
else_=self.default)
@dataclasses.dataclass
class FieldLookup:
""" A list of tokens to be searched for. The column names the database
column to search in and the lookup_type the operator that is applied.
'lookup_all' requires all tokens to match. 'lookup_any' requires
one of the tokens to match. 'restrict' requires to match all tokens
but avoids the use of indexes.
"""
column: str
tokens: List[int]
lookup_type: str
def sql_condition(self, table: SaFromClause) -> SaColumn:
""" Create an SQL expression for the given match condition.
"""
col = table.c[self.column]
if self.lookup_type == 'lookup_all':
return col.contains(self.tokens)
if self.lookup_type == 'lookup_any':
return cast(SaColumn, col.overlap(self.tokens))
return sa.func.array_cat(col, sa.text('ARRAY[]::integer[]'),
type_=ARRAY(sa.Integer())).contains(self.tokens)
class SearchData:
""" Search fields derived from query and token assignment
to be used with the SQL queries.
"""
penalty: float
lookups: List[FieldLookup] = []
rankings: List[FieldRanking]
housenumbers: WeightedStrings = WeightedStrings([], [])
postcodes: WeightedStrings = WeightedStrings([], [])
countries: WeightedStrings = WeightedStrings([], [])
qualifiers: WeightedCategories = WeightedCategories([], [])
def set_strings(self, field: str, tokens: List[Token]) -> None:
""" Set on of the WeightedStrings properties from the given
token list. Adapt the global penalty, so that the
minimum penalty is 0.
"""
if tokens:
min_penalty = min(t.penalty for t in tokens)
self.penalty += min_penalty
wstrs = WeightedStrings([t.lookup_word for t in tokens],
[t.penalty - min_penalty for t in tokens])
setattr(self, field, wstrs)
def set_qualifiers(self, tokens: List[Token]) -> None:
""" Set the qulaifier field from the given tokens.
"""
if tokens:
min_penalty = min(t.penalty for t in tokens)
self.penalty += min_penalty
self.qualifiers = WeightedCategories([t.get_category() for t in tokens],
[t.penalty - min_penalty for t in tokens])
def set_ranking(self, rankings: List[FieldRanking]) -> None:
""" Set the list of rankings and normalize the ranking.
"""
self.rankings = []
for ranking in rankings:
if ranking.rankings:
self.penalty += ranking.normalize_penalty()
self.rankings.append(ranking)
else:
self.penalty += ranking.default

View File

@ -0,0 +1,115 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Implementation of the acutal database accesses for forward search.
"""
import abc
from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails
import nominatim.api.results as nres
from nominatim.api.search.db_search_fields import SearchData, WeightedCategories
class AbstractSearch(abc.ABC):
""" Encapuslation of a single lookup in the database.
"""
def __init__(self, penalty: float) -> None:
self.penalty = penalty
@abc.abstractmethod
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
class NearSearch(AbstractSearch):
""" Category search of a place type near the result of another search.
"""
def __init__(self, penalty: float, categories: WeightedCategories,
search: AbstractSearch) -> None:
super().__init__(penalty)
self.search = search
self.categories = categories
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
return nres.SearchResults([])
class PoiSearch(AbstractSearch):
""" Category search in a geographic area.
"""
def __init__(self, sdata: SearchData) -> None:
super().__init__(sdata.penalty)
self.categories = sdata.qualifiers
self.countries = sdata.countries
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
return nres.SearchResults([])
class CountrySearch(AbstractSearch):
""" Search for a country name or country code.
"""
def __init__(self, sdata: SearchData) -> None:
super().__init__(sdata.penalty)
self.countries = sdata.countries
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
return nres.SearchResults([])
class PostcodeSearch(AbstractSearch):
""" Search for a postcode.
"""
def __init__(self, extra_penalty: float, sdata: SearchData) -> None:
super().__init__(sdata.penalty + extra_penalty)
self.countries = sdata.countries
self.postcodes = sdata.postcodes
self.lookups = sdata.lookups
self.rankings = sdata.rankings
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
return nres.SearchResults([])
class PlaceSearch(AbstractSearch):
""" Generic search for an address or named place.
"""
def __init__(self, extra_penalty: float, sdata: SearchData, expected_count: int) -> None:
super().__init__(sdata.penalty + extra_penalty)
self.countries = sdata.countries
self.postcodes = sdata.postcodes
self.housenumbers = sdata.housenumbers
self.qualifiers = sdata.qualifiers
self.lookups = sdata.lookups
self.rankings = sdata.rankings
self.expected_count = expected_count
async def lookup(self, conn: SearchConnection,
details: SearchDetails) -> nres.SearchResults:
""" Find results for the search in the database.
"""
return nres.SearchResults([])

View File

@ -169,7 +169,10 @@ class QueryNode:
and ending at the node 'end'. Returns 'None' if no such
tokens exist.
"""
return next((t.tokens for t in self.starting if t.end == end and t.ttype == ttype), None)
for tlist in self.starting:
if tlist.end == end and tlist.ttype == ttype:
return tlist.tokens
return None
@dataclasses.dataclass

View File

@ -7,13 +7,18 @@
"""
Complex datatypes used by the Nominatim API.
"""
from typing import Optional, Union, Tuple, NamedTuple, TypeVar, Type, Dict, Any
from typing import Optional, Union, Tuple, NamedTuple, TypeVar, Type, Dict, \
Any, List, Sequence
from collections import abc
import dataclasses
import enum
import math
from struct import unpack
from nominatim.errors import UsageError
# pylint: disable=no-member,too-many-boolean-expressions,too-many-instance-attributes
@dataclasses.dataclass
class PlaceID:
""" Reference an object by Nominatim's internal ID.
@ -85,6 +90,36 @@ class Point(NamedTuple):
return Point(x, y)
@staticmethod
def from_param(inp: Any) -> 'Point':
""" Create a point from an input parameter. The parameter
may be given as a point, a string or a sequence of
strings or floats. Raises a UsageError if the format is
not correct.
"""
if isinstance(inp, Point):
return inp
seq: Sequence[str]
if isinstance(inp, str):
seq = inp.split(',')
elif isinstance(inp, abc.Sequence):
seq = inp
if len(seq) != 2:
raise UsageError('Point parameter needs 2 coordinates.')
try:
x, y = filter(math.isfinite, map(float, seq))
except ValueError as exc:
raise UsageError('Point parameter needs to be numbers.') from exc
if x < -180.0 or x > 180.0 or y < -90.0 or y > 90.0:
raise UsageError('Point coordinates invalid.')
return Point(x, y)
AnyPoint = Union[Point, Tuple[float, float]]
WKB_BBOX_HEADER_LE = b'\x01\x03\x00\x00\x20\xE6\x10\x00\x00\x01\x00\x00\x00\x05\x00\x00\x00'
@ -128,6 +163,12 @@ class Bbox:
return self.coords[2]
def contains(self, pt: Point) -> bool:
""" Check if the point is inside or on the boundary of the box.
"""
return self.coords[0] <= pt[0] and self.coords[1] <= pt[1]\
and self.coords[2] >= pt[0] and self.coords[3] >= pt[1]
@staticmethod
def from_wkb(wkb: Optional[bytes]) -> 'Optional[Bbox]':
""" Create a Bbox from a bounding box polygon as returned by
@ -156,6 +197,38 @@ class Bbox:
pt[0] + buffer, pt[1] + buffer)
@staticmethod
def from_param(inp: Any) -> 'Bbox':
""" Return a Bbox from an input parameter. The box may be
given as a Bbox, a string or a list or strings or integer.
Raises a UsageError if the format is incorrect.
"""
if isinstance(inp, Bbox):
return inp
seq: Sequence[str]
if isinstance(inp, str):
seq = inp.split(',')
elif isinstance(inp, abc.Sequence):
seq = inp
if len(seq) != 4:
raise UsageError('Bounding box parameter needs 4 coordinates.')
try:
x1, y1, x2, y2 = filter(math.isfinite, map(float, seq))
except ValueError as exc:
raise UsageError('Bounding box parameter needs to be numbers.') from exc
if x1 < -180.0 or x1 > 180.0 or y1 < -90.0 or y1 > 90.0 \
or x2 < -180.0 or x2 > 180.0 or y2 < -90.0 or y2 > 90.0:
raise UsageError('Bounding box coordinates invalid.')
if x1 == x2 or y1 == y2:
raise UsageError('Bounding box with invalid parameters.')
return Bbox(min(x1, x2), min(y1, y2), max(x1, x2), max(y1, y2))
class GeometryFormat(enum.Flag):
""" Geometry output formats supported by Nominatim.
"""
@ -176,6 +249,47 @@ class DataLayer(enum.Flag):
NATURAL = enum.auto()
def format_country(cc: Any) -> List[str]:
""" Extract a list of country codes from the input which may be either
a string or list of strings. Filters out all values that are not
a two-letter string.
"""
clist: Sequence[str]
if isinstance(cc, str):
clist = cc.split(',')
elif isinstance(cc, abc.Sequence):
clist = cc
else:
raise UsageError("Parameter 'country' needs to be a comma-separated list "
"or a Python list of strings.")
return [cc.lower() for cc in clist if isinstance(cc, str) and len(cc) == 2]
def format_excluded(ids: Any) -> List[int]:
""" Extract a list of place ids from the input which may be either
a string or a list of strings or ints. Ignores empty value but
throws a UserError on anything that cannot be converted to int.
"""
plist: Sequence[str]
if isinstance(ids, str):
plist = ids.split(',')
elif isinstance(ids, abc.Sequence):
plist = ids
else:
raise UsageError("Parameter 'excluded' needs to be a comma-separated list "
"or a Python list of numbers.")
if any(not isinstance(i, int) or (isinstance(i, str) and not i.isdigit()) for i in plist):
raise UsageError("Parameter 'excluded' only takes place IDs.")
return [int(id) for id in plist if id]
def format_categories(categories: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
""" Extract a list of categories. Currently a noop.
"""
return categories
TParam = TypeVar('TParam', bound='LookupDetails') # pylint: disable=invalid-name
@dataclasses.dataclass
@ -244,3 +358,92 @@ class ReverseDetails(LookupDetails):
layers: DataLayer = DataLayer.ADDRESS | DataLayer.POI
""" Filter which kind of data to include.
"""
@dataclasses.dataclass
class SearchDetails(LookupDetails):
""" Collection of parameters for the search call.
"""
max_results: int = 10
""" Maximum number of results to be returned. The actual number of results
may be less.
"""
min_rank: int = dataclasses.field(default=0,
metadata={'transform': lambda v: max(0, min(v, 30))}
)
""" Lowest address rank to return.
"""
max_rank: int = dataclasses.field(default=30,
metadata={'transform': lambda v: max(0, min(v, 30))}
)
""" Highest address rank to return.
"""
layers: Optional[DataLayer] = None
""" Filter which kind of data to include. When 'None' (the default) then
filtering by layers is disabled.
"""
countries: List[str] = dataclasses.field(default_factory=list,
metadata={'transform': format_country})
""" Restrict search results to the given countries. An empty list (the
default) will disable this filter.
"""
excluded: List[int] = dataclasses.field(default_factory=list,
metadata={'transform': format_excluded})
""" List of OSM objects to exclude from the results. Currenlty only
works when the internal place ID is given.
An empty list (the default) will disable this filter.
"""
viewbox: Optional[Bbox] = dataclasses.field(default=None,
metadata={'transform': Bbox.from_param})
""" Focus the search on a given map area.
"""
bounded_viewbox: bool = False
""" Use 'viewbox' as a filter and restrict results to places within the
given area.
"""
near: Optional[Point] = dataclasses.field(default=None,
metadata={'transform': Point.from_param})
""" Order results by distance to the given point.
"""
near_radius: Optional[float] = None
""" Use near point as a filter and drop results outside the given
radius. Radius is given in degrees WSG84.
"""
categories: List[Tuple[str, str]] = dataclasses.field(default_factory=list,
metadata={'transform': format_categories})
""" Restrict search to places with one of the given class/type categories.
An empty list (the default) will disable this filter.
"""
def __post_init__(self) -> None:
if self.viewbox is not None:
xext = (self.viewbox.maxlon - self.viewbox.minlon)/2
yext = (self.viewbox.maxlat - self.viewbox.minlat)/2
self.viewbox_x2 = Bbox(self.viewbox.minlon - xext, self.viewbox.maxlon - yext,
self.viewbox.maxlon + xext, self.viewbox.maxlat + yext)
def restrict_min_max_rank(self, new_min: int, new_max: int) -> None:
""" Change the min_rank and max_rank fields to respect the
given boundaries.
"""
assert new_min <= new_max
self.min_rank = max(self.min_rank, new_min)
self.max_rank = min(self.max_rank, new_max)
def is_impossible(self) -> bool:
""" Check if the parameter configuration is contradictionary and
cannot yield any results.
"""
return (self.min_rank > self.max_rank
or (self.bounded_viewbox
and self.viewbox is not None and self.near is not None
and self.viewbox.contains(self.near))
or self.layers is not None and not self.layers)
def layer_enabled(self, layer: DataLayer) -> bool:
""" Check if the given layer has been choosen. Also returns
true when layer restriction has been disabled completely.
"""
return self.layers is None or bool(self.layers & layer)

View File

@ -0,0 +1,395 @@
# SPDX-License-Identifier: GPL-3.0-or-later
#
# This file is part of Nominatim. (https://nominatim.org)
#
# Copyright (C) 2023 by the Nominatim developer community.
# For a full list of authors see the git log.
"""
Tests for creating abstract searches from token assignments.
"""
import pytest
from nominatim.api.search.query import Token, TokenRange, BreakType, PhraseType, TokenType, QueryStruct, Phrase
from nominatim.api.search.db_search_builder import SearchBuilder
from nominatim.api.search.token_assignment import TokenAssignment
from nominatim.api.types import SearchDetails
import nominatim.api.search.db_searches as dbs
class MyToken(Token):
def get_category(self):
return 'this', 'that'
def make_query(*args):
q = None
for tlist in args:
if q is None:
q = QueryStruct([Phrase(PhraseType.NONE, '')])
else:
q.add_node(BreakType.WORD, PhraseType.NONE)
start = len(q.nodes) - 1
for end, ttype, tinfo in tlist:
for tid, word in tinfo:
q.add_token(TokenRange(start, end), ttype,
MyToken(0.5 if ttype == TokenType.PARTIAL else 0.0, tid, 1, word, True))
q.add_node(BreakType.END, PhraseType.NONE)
return q
def test_country_search():
q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.CountrySearch)
assert set(search.countries.values) == {'de', 'en'}
def test_country_search_with_country_restriction():
q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'en,fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.CountrySearch)
assert set(search.countries.values) == {'en'}
def test_country_search_with_confllicting_country_restriction():
q = make_query([(1, TokenType.COUNTRY, [(2, 'de'), (3, 'en')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'fr'}))
searches = list(builder.build(TokenAssignment(country=TokenRange(0, 1))))
assert len(searches) == 0
def test_postcode_search_simple():
q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PostcodeSearch)
assert search.postcodes.values == ['2367']
assert not search.countries.values
assert not search.lookups
assert not search.rankings
def test_postcode_with_country():
q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])],
[(2, TokenType.COUNTRY, [(1, 'xx')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),
country=TokenRange(1, 2))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PostcodeSearch)
assert search.postcodes.values == ['2367']
assert search.countries.values == ['xx']
assert not search.lookups
assert not search.rankings
def test_postcode_with_address():
q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])],
[(2, TokenType.PARTIAL, [(100, 'word')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),
address=[TokenRange(1, 2)])))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PostcodeSearch)
assert search.postcodes.values == ['2367']
assert not search.countries
assert search.lookups
assert not search.rankings
def test_postcode_with_address_with_full_word():
q = make_query([(1, TokenType.POSTCODE, [(34, '2367')])],
[(2, TokenType.PARTIAL, [(100, 'word')]),
(2, TokenType.WORD, [(1, 'full')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(postcode=TokenRange(0, 1),
address=[TokenRange(1, 2)])))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PostcodeSearch)
assert search.postcodes.values == ['2367']
assert not search.countries
assert search.lookups
assert len(search.rankings) == 1
@pytest.mark.parametrize('kwargs', [{'viewbox': '0,0,1,1', 'bounded_viewbox': True},
{'near': '10,10'}])
def test_category_only(kwargs):
q = make_query([(1, TokenType.CATEGORY, [(2, 'foo')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs(kwargs))
searches = list(builder.build(TokenAssignment(category=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PoiSearch)
assert search.categories.values == [('this', 'that')]
@pytest.mark.parametrize('kwargs', [{'viewbox': '0,0,1,1'},
{}])
def test_category_skipped(kwargs):
q = make_query([(1, TokenType.CATEGORY, [(2, 'foo')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs(kwargs))
searches = list(builder.build(TokenAssignment(category=TokenRange(0, 1))))
assert len(searches) == 0
def test_name_only_search():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert not search.countries.values
assert not search.housenumbers.values
assert not search.qualifiers.values
assert len(search.lookups) == 1
assert len(search.rankings) == 1
def test_name_with_qualifier():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])],
[(2, TokenType.QUALIFIER, [(55, 'hotel')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1),
qualifier=TokenRange(1, 2))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert not search.countries.values
assert not search.housenumbers.values
assert search.qualifiers.values == [('this', 'that')]
assert len(search.lookups) == 1
assert len(search.rankings) == 1
def test_name_with_housenumber_search():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])],
[(2, TokenType.HOUSENUMBER, [(66, '66')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1),
housenumber=TokenRange(1, 2))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert not search.countries.values
assert search.housenumbers.values == ['66']
assert len(search.lookups) == 1
assert len(search.rankings) == 1
def test_name_and_address():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])],
[(2, TokenType.PARTIAL, [(2, 'b')]),
(2, TokenType.WORD, [(101, 'b')])],
[(3, TokenType.PARTIAL, [(3, 'c')]),
(3, TokenType.WORD, [(102, 'c')])]
)
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1),
address=[TokenRange(1, 2),
TokenRange(2, 3)])))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert not search.countries.values
assert not search.housenumbers.values
assert len(search.lookups) == 2
assert len(search.rankings) == 3
def test_name_and_complex_address():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])],
[(2, TokenType.PARTIAL, [(2, 'b')]),
(3, TokenType.WORD, [(101, 'bc')])],
[(3, TokenType.PARTIAL, [(3, 'c')])],
[(4, TokenType.PARTIAL, [(4, 'd')]),
(4, TokenType.WORD, [(103, 'd')])]
)
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1),
address=[TokenRange(1, 2),
TokenRange(2, 4)])))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert not search.countries.values
assert not search.housenumbers.values
assert len(search.lookups) == 2
assert len(search.rankings) == 2
def test_name_only_near_search():
q = make_query([(1, TokenType.CATEGORY, [(88, 'g')])],
[(2, TokenType.PARTIAL, [(1, 'a')]),
(2, TokenType.WORD, [(100, 'a')])])
builder = SearchBuilder(q, SearchDetails())
searches = list(builder.build(TokenAssignment(name=TokenRange(1, 2),
category=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.NearSearch)
assert isinstance(search.search, dbs.PlaceSearch)
def test_name_only_search_with_category():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'categories': [('foo', 'bar')]}))
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.NearSearch)
assert isinstance(search.search, dbs.PlaceSearch)
def test_name_only_search_with_countries():
q = make_query([(1, TokenType.PARTIAL, [(1, 'a')]),
(1, TokenType.WORD, [(100, 'a')])])
builder = SearchBuilder(q, SearchDetails.from_kwargs({'countries': 'de,en'}))
searches = list(builder.build(TokenAssignment(name=TokenRange(0, 1))))
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert not search.postcodes.values
assert set(search.countries.values) == {'de', 'en'}
assert not search.housenumbers.values
def make_counted_searches(name_part, name_full, address_part, address_full):
q = QueryStruct([Phrase(PhraseType.NONE, '')])
for i in range(2):
q.add_node(BreakType.WORD, PhraseType.NONE)
q.add_node(BreakType.END, PhraseType.NONE)
q.add_token(TokenRange(0, 1), TokenType.PARTIAL,
MyToken(0.5, 1, name_part, 'name_part', True))
q.add_token(TokenRange(0, 1), TokenType.WORD,
MyToken(0, 101, name_full, 'name_full', True))
q.add_token(TokenRange(1, 2), TokenType.PARTIAL,
MyToken(0.5, 2, address_part, 'address_part', True))
q.add_token(TokenRange(1, 2), TokenType.WORD,
MyToken(0, 102, address_full, 'address_full', True))
builder = SearchBuilder(q, SearchDetails())
return list(builder.build(TokenAssignment(name=TokenRange(0, 1),
address=[TokenRange(1, 2)])))
def test_infrequent_partials_in_name():
searches = make_counted_searches(1, 1, 1, 1)
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert len(search.lookups) == 2
assert len(search.rankings) == 2
assert set((l.column, l.lookup_type) for l in search.lookups) == \
{('name_vector', 'lookup_all'), ('nameaddress_vector', 'restrict')}
def test_frequent_partials_in_name_but_not_in_address():
searches = make_counted_searches(10000, 1, 1, 1)
assert len(searches) == 1
search = searches[0]
assert isinstance(search, dbs.PlaceSearch)
assert len(search.lookups) == 2
assert len(search.rankings) == 2
assert set((l.column, l.lookup_type) for l in search.lookups) == \
{('nameaddress_vector', 'lookup_all'), ('name_vector', 'restrict')}
def test_frequent_partials_in_name_and_address():
searches = make_counted_searches(10000, 1, 10000, 1)
assert len(searches) == 2
assert all(isinstance(s, dbs.PlaceSearch) for s in searches)
searches.sort(key=lambda s: s.penalty)
assert set((l.column, l.lookup_type) for l in searches[0].lookups) == \
{('name_vector', 'lookup_any'), ('nameaddress_vector', 'restrict')}
assert set((l.column, l.lookup_type) for l in searches[1].lookups) == \
{('nameaddress_vector', 'lookup_all'), ('name_vector', 'lookup_all')}

View File

@ -29,7 +29,7 @@ def make_query(*args):
start = len(q.nodes) - 1
for end, ttype in tlist:
q.add_token(TokenRange(start, end), ttype, [dummy])
q.add_token(TokenRange(start, end), ttype, dummy)
q.add_node(BreakType.END, PhraseType.NONE)