Merge pull request #3172 from lonvia/query-timeout

Introduce timeouts for queries
This commit is contained in:
Sarah Hoffmann 2023-08-25 10:00:22 +02:00 committed by GitHub
commit d5b6042118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 8 deletions

View File

@ -9,6 +9,7 @@ Extended SQLAlchemy connection class that also includes access to the schema.
"""
from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
Awaitable, Callable, TypeVar
import asyncio
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection
@ -34,6 +35,14 @@ class SearchConnection:
self.t = tables # pylint: disable=invalid-name
self._property_cache = properties
self._classtables: Optional[Set[str]] = None
self.query_timeout: Optional[int] = None
def set_query_timeout(self, timeout: Optional[int]) -> None:
    """ Remember the time limit to apply to queries sent over this
        connection. The value is stored in `query_timeout`; `None`
        means that no limit is applied.
    """
    self.query_timeout = timeout
async def scalar(self, sql: sa.sql.base.Executable,
@ -42,7 +51,7 @@ class SearchConnection:
""" Execute a 'scalar()' query on the connection.
"""
log().sql(self.connection, sql, params)
return await self.connection.scalar(sql, params)
return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
async def execute(self, sql: 'sa.Executable',
@ -51,7 +60,7 @@ class SearchConnection:
""" Execute a 'execute()' query on the connection.
"""
log().sql(self.connection, sql, params)
return await self.connection.execute(sql, params)
return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
async def get_property(self, name: str, cached: bool = True) -> str:

View File

@ -36,6 +36,8 @@ class NominatimAPIAsync:
environ: Optional[Mapping[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.config = Configuration(project_dir, environ)
self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
if self.config.QUERY_TIMEOUT else None
self.server_version = 0
if sys.version_info >= (3, 10):
@ -128,6 +130,7 @@ class NominatimAPIAsync:
"""
try:
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
status = await get_status(conn)
except (PGCORE_ERROR, sa.exc.OperationalError):
return StatusResult(700, 'Database connection failed')
@ -142,6 +145,7 @@ class NominatimAPIAsync:
"""
details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords:
await make_query_analyzer(conn)
return await get_detailed_place(conn, place, details)
@ -154,6 +158,7 @@ class NominatimAPIAsync:
"""
details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords:
await make_query_analyzer(conn)
return SearchResults(filter(None,
@ -173,6 +178,7 @@ class NominatimAPIAsync:
details = ntyp.ReverseDetails.from_kwargs(params)
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords:
await make_query_analyzer(conn)
geocoder = ReverseGeocoder(conn, details)
@ -187,7 +193,10 @@ class NominatimAPIAsync:
raise UsageError('Nothing to search for.')
async with self.begin() as conn:
geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params))
conn.set_query_timeout(self.query_timeout)
geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
return await geocoder.lookup(phrases)
@ -204,6 +213,7 @@ class NominatimAPIAsync:
""" Find an address using structured search.
"""
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
details = ntyp.SearchDetails.from_kwargs(params)
phrases: List[Phrase] = []
@ -244,7 +254,9 @@ class NominatimAPIAsync:
if amenity:
details.layers |= ntyp.DataLayer.POI
geocoder = ForwardGeocoder(conn, details)
geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup(phrases)
@ -260,6 +272,7 @@ class NominatimAPIAsync:
details = ntyp.SearchDetails.from_kwargs(params)
async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if near_query:
phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')]
else:
@ -267,7 +280,9 @@ class NominatimAPIAsync:
if details.keywords:
await make_query_analyzer(conn)
geocoder = ForwardGeocoder(conn, details)
geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup_pois(categories, phrases)

View File

@ -9,6 +9,7 @@ Public interface to the search code.
"""
from typing import List, Any, Optional, Iterator, Tuple
import itertools
import datetime as dt
from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails
@ -24,9 +25,11 @@ class ForwardGeocoder:
""" Main class responsible for place search.
"""
def __init__(self, conn: SearchConnection, params: SearchDetails) -> None:
def __init__(self, conn: SearchConnection,
params: SearchDetails, timeout: Optional[int]) -> None:
self.conn = conn
self.params = params
self.timeout = dt.timedelta(seconds=timeout or 1000000)
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
@ -71,6 +74,7 @@ class ForwardGeocoder:
"""
log().section('Execute database searches')
results = SearchResults()
end_time = dt.datetime.now() + self.timeout
num_results = 0
min_ranking = 1000.0
@ -85,6 +89,8 @@ class ForwardGeocoder:
log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
num_results = len(results)
prev_penalty = search.penalty
if dt.datetime.now() >= end_time:
break
if results:
min_ranking = min(r.ranking for r in results)

View File

@ -37,6 +37,17 @@ async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable
resp.content_type = exception.content_type
async def timeout_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument
                                exception: TimeoutError, #pylint: disable=unused-argument
                                _: Any) -> None:
    """ Error handler for query timeouts.

        Always answers with HTTP status 503 and a fixed plain-text
        message. Neither the request nor the exception is inspected.
    """
    resp.status = 503
    resp.text = "Query took too long to process."
    resp.content_type = 'text/plain; charset=utf-8'
class ParamWrapper(api_impl.ASGIAdaptor):
""" Adaptor class for server glue to Falcon framework.
"""
@ -139,6 +150,7 @@ def get_application(project_dir: Path,
app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler)
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:

View File

@ -7,14 +7,14 @@
"""
Server implementation using the starlette webserver framework.
"""
from typing import Any, Optional, Mapping, Callable, cast, Coroutine
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
from pathlib import Path
import datetime as dt
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.exceptions import HTTPException
from starlette.responses import Response
from starlette.responses import Response, PlainTextResponse
from starlette.requests import Request
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@ -110,6 +110,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
return response
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Turn a query timeout into a plain-text response with
        HTTP status 503.
    """
    message = "Query took too long to process."
    return PlainTextResponse(message, status_code=503)
def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None,
debug: bool = True) -> Starlette:
@ -136,10 +143,15 @@ def get_application(project_dir: Path,
if log_file:
middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
TimeoutError: timeout_error
}
async def _shutdown() -> None:
await app.state.API.close()
app = Starlette(debug=debug, routes=routes, middleware=middleware,
exception_handlers=exceptions,
on_shutdown=[_shutdown])
app.state.API = NominatimAPIAsync(project_dir, environ)

View File

@ -214,6 +214,16 @@ NOMINATIM_SERVE_LEGACY_URLS=yes
# of connections _per worker_.
NOMINATIM_API_POOL_SIZE=10
# Timeout in seconds after which a single query to the database is cancelled.
# The user receives a 503 response when a query times out.
# When empty, timeouts are disabled.
NOMINATIM_QUERY_TIMEOUT=10
# Maximum time in seconds a single request is allowed to take. When the
# timeout is exceeded, the available results are returned.
# When empty, timeouts are disabled.
NOMINATIM_REQUEST_TIMEOUT=60
# Search elements just within countries
# If, despite not finding a point within the static grid of countries, it
# finds a geometry of a region, do not return the geometry. Return "Unable