Merge pull request #3172 from lonvia/query-timeout

Introduce timeouts for queries
This commit is contained in:
Sarah Hoffmann 2023-08-25 10:00:22 +02:00 committed by GitHub
commit d5b6042118
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 72 additions and 8 deletions

View File

@ -9,6 +9,7 @@ Extended SQLAlchemy connection class that also includes access to the schema.
""" """
from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \ from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \
Awaitable, Callable, TypeVar Awaitable, Callable, TypeVar
import asyncio
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.asyncio import AsyncConnection
@ -34,6 +35,14 @@ class SearchConnection:
self.t = tables # pylint: disable=invalid-name self.t = tables # pylint: disable=invalid-name
self._property_cache = properties self._property_cache = properties
self._classtables: Optional[Set[str]] = None self._classtables: Optional[Set[str]] = None
self.query_timeout: Optional[int] = None
def set_query_timeout(self, timeout: Optional[int]) -> None:
    """ Remember the time limit to apply to queries sent over this
        connection. The value is stored as is and later handed to
        asyncio.wait_for() when a query is executed; 'None' means
        that queries are never cancelled.
    """
    self.query_timeout = timeout
async def scalar(self, sql: sa.sql.base.Executable, async def scalar(self, sql: sa.sql.base.Executable,
@ -42,7 +51,7 @@ class SearchConnection:
""" Execute a 'scalar()' query on the connection. """ Execute a 'scalar()' query on the connection.
""" """
log().sql(self.connection, sql, params) log().sql(self.connection, sql, params)
return await self.connection.scalar(sql, params) return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout)
async def execute(self, sql: 'sa.Executable', async def execute(self, sql: 'sa.Executable',
@ -51,7 +60,7 @@ class SearchConnection:
""" Execute a 'execute()' query on the connection. """ Execute a 'execute()' query on the connection.
""" """
log().sql(self.connection, sql, params) log().sql(self.connection, sql, params)
return await self.connection.execute(sql, params) return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout)
async def get_property(self, name: str, cached: bool = True) -> str: async def get_property(self, name: str, cached: bool = True) -> str:

View File

@ -36,6 +36,8 @@ class NominatimAPIAsync:
environ: Optional[Mapping[str, str]] = None, environ: Optional[Mapping[str, str]] = None,
loop: Optional[asyncio.AbstractEventLoop] = None) -> None: loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.config = Configuration(project_dir, environ) self.config = Configuration(project_dir, environ)
self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \
if self.config.QUERY_TIMEOUT else None
self.server_version = 0 self.server_version = 0
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
@ -128,6 +130,7 @@ class NominatimAPIAsync:
""" """
try: try:
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
status = await get_status(conn) status = await get_status(conn)
except (PGCORE_ERROR, sa.exc.OperationalError): except (PGCORE_ERROR, sa.exc.OperationalError):
return StatusResult(700, 'Database connection failed') return StatusResult(700, 'Database connection failed')
@ -142,6 +145,7 @@ class NominatimAPIAsync:
""" """
details = ntyp.LookupDetails.from_kwargs(params) details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await make_query_analyzer(conn) await make_query_analyzer(conn)
return await get_detailed_place(conn, place, details) return await get_detailed_place(conn, place, details)
@ -154,6 +158,7 @@ class NominatimAPIAsync:
""" """
details = ntyp.LookupDetails.from_kwargs(params) details = ntyp.LookupDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await make_query_analyzer(conn) await make_query_analyzer(conn)
return SearchResults(filter(None, return SearchResults(filter(None,
@ -173,6 +178,7 @@ class NominatimAPIAsync:
details = ntyp.ReverseDetails.from_kwargs(params) details = ntyp.ReverseDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if details.keywords: if details.keywords:
await make_query_analyzer(conn) await make_query_analyzer(conn)
geocoder = ReverseGeocoder(conn, details) geocoder = ReverseGeocoder(conn, details)
@ -187,7 +193,10 @@ class NominatimAPIAsync:
raise UsageError('Nothing to search for.') raise UsageError('Nothing to search for.')
async with self.begin() as conn: async with self.begin() as conn:
geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params)) conn.set_query_timeout(self.query_timeout)
geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params),
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')] phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')]
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
@ -204,6 +213,7 @@ class NominatimAPIAsync:
""" Find an address using structured search. """ Find an address using structured search.
""" """
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
details = ntyp.SearchDetails.from_kwargs(params) details = ntyp.SearchDetails.from_kwargs(params)
phrases: List[Phrase] = [] phrases: List[Phrase] = []
@ -244,7 +254,9 @@ class NominatimAPIAsync:
if amenity: if amenity:
details.layers |= ntyp.DataLayer.POI details.layers |= ntyp.DataLayer.POI
geocoder = ForwardGeocoder(conn, details) geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup(phrases) return await geocoder.lookup(phrases)
@ -260,6 +272,7 @@ class NominatimAPIAsync:
details = ntyp.SearchDetails.from_kwargs(params) details = ntyp.SearchDetails.from_kwargs(params)
async with self.begin() as conn: async with self.begin() as conn:
conn.set_query_timeout(self.query_timeout)
if near_query: if near_query:
phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')] phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')]
else: else:
@ -267,7 +280,9 @@ class NominatimAPIAsync:
if details.keywords: if details.keywords:
await make_query_analyzer(conn) await make_query_analyzer(conn)
geocoder = ForwardGeocoder(conn, details) geocoder = ForwardGeocoder(conn, details,
self.config.get_int('REQUEST_TIMEOUT') \
if self.config.REQUEST_TIMEOUT else None)
return await geocoder.lookup_pois(categories, phrases) return await geocoder.lookup_pois(categories, phrases)

View File

@ -9,6 +9,7 @@ Public interface to the search code.
""" """
from typing import List, Any, Optional, Iterator, Tuple from typing import List, Any, Optional, Iterator, Tuple
import itertools import itertools
import datetime as dt
from nominatim.api.connection import SearchConnection from nominatim.api.connection import SearchConnection
from nominatim.api.types import SearchDetails from nominatim.api.types import SearchDetails
@ -24,9 +25,11 @@ class ForwardGeocoder:
""" Main class responsible for place search. """ Main class responsible for place search.
""" """
def __init__(self, conn: SearchConnection, params: SearchDetails) -> None: def __init__(self, conn: SearchConnection,
params: SearchDetails, timeout: Optional[int]) -> None:
self.conn = conn self.conn = conn
self.params = params self.params = params
self.timeout = dt.timedelta(seconds=timeout or 1000000)
self.query_analyzer: Optional[AbstractQueryAnalyzer] = None self.query_analyzer: Optional[AbstractQueryAnalyzer] = None
@ -71,6 +74,7 @@ class ForwardGeocoder:
""" """
log().section('Execute database searches') log().section('Execute database searches')
results = SearchResults() results = SearchResults()
end_time = dt.datetime.now() + self.timeout
num_results = 0 num_results = 0
min_ranking = 1000.0 min_ranking = 1000.0
@ -85,6 +89,8 @@ class ForwardGeocoder:
log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:])) log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:]))
num_results = len(results) num_results = len(results)
prev_penalty = search.penalty prev_penalty = search.penalty
if dt.datetime.now() >= end_time:
break
if results: if results:
min_ranking = min(r.ranking for r in results) min_ranking = min(r.ranking for r in results)

View File

@ -37,6 +37,17 @@ async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable
resp.content_type = exception.content_type resp.content_type = exception.content_type
async def timeout_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument
                                exception: TimeoutError, #pylint: disable=unused-argument
                                _: Any) -> None:
    """ Error handler for query timeouts.

        Answers with a fixed plain-text '503' response. Neither the
        request nor the exception are inspected: a TimeoutError carries
        no message or content type worth forwarding to the client.
    """
    resp.status = 503
    resp.text = "Query took too long to process."
    resp.content_type = 'text/plain; charset=utf-8'
class ParamWrapper(api_impl.ASGIAdaptor): class ParamWrapper(api_impl.ASGIAdaptor):
""" Adaptor class for server glue to Falcon framework. """ Adaptor class for server glue to Falcon framework.
""" """
@ -139,6 +150,7 @@ def get_application(project_dir: Path,
app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware) middleware=middleware)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler)
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS') legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES: for name, func in api_impl.ROUTES:

View File

@ -7,14 +7,14 @@
""" """
Server implementation using the starlette webserver framework. Server implementation using the starlette webserver framework.
""" """
from typing import Any, Optional, Mapping, Callable, cast, Coroutine from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable
from pathlib import Path from pathlib import Path
import datetime as dt import datetime as dt
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Route from starlette.routing import Route
from starlette.exceptions import HTTPException from starlette.exceptions import HTTPException
from starlette.responses import Response from starlette.responses import Response, PlainTextResponse
from starlette.requests import Request from starlette.requests import Request
from starlette.middleware import Middleware from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
@ -110,6 +110,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware):
return response return response
async def timeout_error(request: Request, #pylint: disable=unused-argument
                        _: Exception) -> Response:
    """ Convert a query timeout into a plain-text '503' response.
        Request and exception are not looked at, the message is fixed.
    """
    response = PlainTextResponse("Query took too long to process.",
                                 status_code=503)
    return response
def get_application(project_dir: Path, def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None, environ: Optional[Mapping[str, str]] = None,
debug: bool = True) -> Starlette: debug: bool = True) -> Starlette:
@ -136,10 +143,15 @@ def get_application(project_dir: Path,
if log_file: if log_file:
middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file)) middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file))
exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = {
TimeoutError: timeout_error
}
async def _shutdown() -> None: async def _shutdown() -> None:
await app.state.API.close() await app.state.API.close()
app = Starlette(debug=debug, routes=routes, middleware=middleware, app = Starlette(debug=debug, routes=routes, middleware=middleware,
exception_handlers=exceptions,
on_shutdown=[_shutdown]) on_shutdown=[_shutdown])
app.state.API = NominatimAPIAsync(project_dir, environ) app.state.API = NominatimAPIAsync(project_dir, environ)

View File

@ -214,6 +214,16 @@ NOMINATIM_SERVE_LEGACY_URLS=yes
# of connections _per worker_. # of connections _per worker_.
NOMINATIM_API_POOL_SIZE=10 NOMINATIM_API_POOL_SIZE=10
# Timeout in seconds after which a single query to the database is cancelled.
# The user receives a 503 response, when a query times out.
# When empty, then timeouts are disabled.
NOMINATIM_QUERY_TIMEOUT=10
# Maximum time a single request is allowed to take. When the timeout is
# exceeded, the available results are returned.
# When empty, then timeouts are disabled.
NOMINATIM_REQUEST_TIMEOUT=60
# Search elements just within countries # Search elements just within countries
# If, despite not finding a point within the static grid of countries, it # If, despite not finding a point within the static grid of countries, it
# finds a geometry of a region, do not return the geometry. Return "Unable # finds a geometry of a region, do not return the geometry. Return "Unable