diff --git a/nominatim/api/connection.py b/nominatim/api/connection.py index bf217314..405213e9 100644 --- a/nominatim/api/connection.py +++ b/nominatim/api/connection.py @@ -9,6 +9,7 @@ Extended SQLAlchemy connection class that also includes access to the schema. """ from typing import cast, Any, Mapping, Sequence, Union, Dict, Optional, Set, \ Awaitable, Callable, TypeVar +import asyncio import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncConnection @@ -34,6 +35,14 @@ class SearchConnection: self.t = tables # pylint: disable=invalid-name self._property_cache = properties self._classtables: Optional[Set[str]] = None + self.query_timeout: Optional[int] = None + + + def set_query_timeout(self, timeout: Optional[int]) -> None: + """ Set the timeout after which a query over this connection + is cancelled. + """ + self.query_timeout = timeout async def scalar(self, sql: sa.sql.base.Executable, @@ -42,7 +51,7 @@ class SearchConnection: """ Execute a 'scalar()' query on the connection. """ log().sql(self.connection, sql, params) - return await self.connection.scalar(sql, params) + return await asyncio.wait_for(self.connection.scalar(sql, params), self.query_timeout) async def execute(self, sql: 'sa.Executable', @@ -51,7 +60,7 @@ class SearchConnection: """ Execute a 'execute()' query on the connection. 
""" log().sql(self.connection, sql, params) - return await self.connection.execute(sql, params) + return await asyncio.wait_for(self.connection.execute(sql, params), self.query_timeout) async def get_property(self, name: str, cached: bool = True) -> str: diff --git a/nominatim/api/core.py b/nominatim/api/core.py index 1690b9f5..0f1dd715 100644 --- a/nominatim/api/core.py +++ b/nominatim/api/core.py @@ -36,6 +36,8 @@ class NominatimAPIAsync: environ: Optional[Mapping[str, str]] = None, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: self.config = Configuration(project_dir, environ) + self.query_timeout = self.config.get_int('QUERY_TIMEOUT') \ + if self.config.QUERY_TIMEOUT else None self.server_version = 0 if sys.version_info >= (3, 10): @@ -128,6 +130,7 @@ class NominatimAPIAsync: """ try: async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) status = await get_status(conn) except (PGCORE_ERROR, sa.exc.OperationalError): return StatusResult(700, 'Database connection failed') @@ -142,6 +145,7 @@ class NominatimAPIAsync: """ details = ntyp.LookupDetails.from_kwargs(params) async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) if details.keywords: await make_query_analyzer(conn) return await get_detailed_place(conn, place, details) @@ -154,6 +158,7 @@ class NominatimAPIAsync: """ details = ntyp.LookupDetails.from_kwargs(params) async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) if details.keywords: await make_query_analyzer(conn) return SearchResults(filter(None, @@ -173,6 +178,7 @@ class NominatimAPIAsync: details = ntyp.ReverseDetails.from_kwargs(params) async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) if details.keywords: await make_query_analyzer(conn) geocoder = ReverseGeocoder(conn, details) @@ -187,7 +193,10 @@ class NominatimAPIAsync: raise UsageError('Nothing to search for.') async with self.begin() as conn: - geocoder = ForwardGeocoder(conn, 
ntyp.SearchDetails.from_kwargs(params)) + conn.set_query_timeout(self.query_timeout) + geocoder = ForwardGeocoder(conn, ntyp.SearchDetails.from_kwargs(params), + self.config.get_int('REQUEST_TIMEOUT') \ + if self.config.REQUEST_TIMEOUT else None) phrases = [Phrase(PhraseType.NONE, p.strip()) for p in query.split(',')] return await geocoder.lookup(phrases) @@ -204,6 +213,7 @@ class NominatimAPIAsync: """ Find an address using structured search. """ async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) details = ntyp.SearchDetails.from_kwargs(params) phrases: List[Phrase] = [] @@ -244,7 +254,9 @@ class NominatimAPIAsync: if amenity: details.layers |= ntyp.DataLayer.POI - geocoder = ForwardGeocoder(conn, details) + geocoder = ForwardGeocoder(conn, details, + self.config.get_int('REQUEST_TIMEOUT') \ + if self.config.REQUEST_TIMEOUT else None) return await geocoder.lookup(phrases) @@ -260,6 +272,7 @@ class NominatimAPIAsync: details = ntyp.SearchDetails.from_kwargs(params) async with self.begin() as conn: + conn.set_query_timeout(self.query_timeout) if near_query: phrases = [Phrase(PhraseType.NONE, p) for p in near_query.split(',')] else: @@ -267,7 +280,9 @@ class NominatimAPIAsync: if details.keywords: await make_query_analyzer(conn) - geocoder = ForwardGeocoder(conn, details) + geocoder = ForwardGeocoder(conn, details, + self.config.get_int('REQUEST_TIMEOUT') \ + if self.config.REQUEST_TIMEOUT else None) return await geocoder.lookup_pois(categories, phrases) diff --git a/nominatim/api/search/geocoder.py b/nominatim/api/search/geocoder.py index 564e3d8d..f88bffbd 100644 --- a/nominatim/api/search/geocoder.py +++ b/nominatim/api/search/geocoder.py @@ -9,6 +9,7 @@ Public interface to the search code. 
""" from typing import List, Any, Optional, Iterator, Tuple import itertools +import datetime as dt from nominatim.api.connection import SearchConnection from nominatim.api.types import SearchDetails @@ -24,9 +25,11 @@ class ForwardGeocoder: """ Main class responsible for place search. """ - def __init__(self, conn: SearchConnection, params: SearchDetails) -> None: + def __init__(self, conn: SearchConnection, + params: SearchDetails, timeout: Optional[int]) -> None: self.conn = conn self.params = params + self.timeout = dt.timedelta(seconds=timeout or 1000000) self.query_analyzer: Optional[AbstractQueryAnalyzer] = None @@ -71,6 +74,7 @@ class ForwardGeocoder: """ log().section('Execute database searches') results = SearchResults() + end_time = dt.datetime.now() + self.timeout num_results = 0 min_ranking = 1000.0 @@ -85,6 +89,8 @@ class ForwardGeocoder: log().result_dump('Results', ((r.accuracy, r) for r in results[num_results:])) num_results = len(results) prev_penalty = search.penalty + if dt.datetime.now() >= end_time: + break if results: min_ranking = min(r.ranking for r in results) diff --git a/nominatim/server/falcon/server.py b/nominatim/server/falcon/server.py index e551e542..f1030f5c 100644 --- a/nominatim/server/falcon/server.py +++ b/nominatim/server/falcon/server.py @@ -37,6 +37,17 @@ async def nominatim_error_handler(req: Request, resp: Response, #pylint: disable resp.content_type = exception.content_type +async def timeout_error_handler(req: Request, resp: Response, #pylint: disable=unused-argument + exception: TimeoutError, #pylint: disable=unused-argument + _: Any) -> None: + """ Error handler for query timeouts. Always answers with a fixed + plain-text 503 response; the exception itself is not inspected. + """ + resp.status = 503 + resp.text = "Query took too long to process." + resp.content_type = 'text/plain; charset=utf-8' + + class ParamWrapper(api_impl.ASGIAdaptor): """ Adaptor class for server glue to Falcon framework. 
""" @@ -139,6 +150,7 @@ def get_application(project_dir: Path, app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), middleware=middleware) app.add_error_handler(HTTPNominatimError, nominatim_error_handler) + app.add_error_handler(TimeoutError, timeout_error_handler) legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS') for name, func in api_impl.ROUTES: diff --git a/nominatim/server/starlette/server.py b/nominatim/server/starlette/server.py index 5567ac9c..19a9943c 100644 --- a/nominatim/server/starlette/server.py +++ b/nominatim/server/starlette/server.py @@ -7,14 +7,14 @@ """ Server implementation using the starlette webserver framework. """ -from typing import Any, Optional, Mapping, Callable, cast, Coroutine +from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable from pathlib import Path import datetime as dt from starlette.applications import Starlette from starlette.routing import Route from starlette.exceptions import HTTPException -from starlette.responses import Response +from starlette.responses import Response, PlainTextResponse from starlette.requests import Request from starlette.middleware import Middleware from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint @@ -110,6 +110,13 @@ class FileLoggingMiddleware(BaseHTTPMiddleware): return response +async def timeout_error(request: Request, #pylint: disable=unused-argument + _: Exception) -> Response: + """ Error handler for query timeouts. 
+ """ + return PlainTextResponse("Query took too long to process.", status_code=503) + + def get_application(project_dir: Path, environ: Optional[Mapping[str, str]] = None, debug: bool = True) -> Starlette: @@ -136,10 +143,15 @@ def get_application(project_dir: Path, if log_file: middleware.append(Middleware(FileLoggingMiddleware, file_name=log_file)) + exceptions: Dict[Any, Callable[[Request, Exception], Awaitable[Response]]] = { + TimeoutError: timeout_error + } + async def _shutdown() -> None: await app.state.API.close() app = Starlette(debug=debug, routes=routes, middleware=middleware, + exception_handlers=exceptions, on_shutdown=[_shutdown]) app.state.API = NominatimAPIAsync(project_dir, environ) diff --git a/settings/env.defaults b/settings/env.defaults index c4739e78..ff0a7648 100644 --- a/settings/env.defaults +++ b/settings/env.defaults @@ -214,6 +214,16 @@ NOMINATIM_SERVE_LEGACY_URLS=yes # of connections _per worker_. NOMINATIM_API_POOL_SIZE=10 +# Timeout in seconds after which a single query to the database is cancelled. +# The user receives a 503 response when a query times out. +# When empty, timeouts are disabled. +NOMINATIM_QUERY_TIMEOUT=10 + +# Maximum time in seconds a single request is allowed to take. When the +# timeout is exceeded, the results available so far are returned. +# When empty, timeouts are disabled. +NOMINATIM_REQUEST_TIMEOUT=60 + # Search elements just within countries # If, despite not finding a point within the static grid of countries, it # finds a geometry of a region, do not return the geometry. Return "Unable