diff --git a/src/nominatim_api/server/falcon/server.py b/src/nominatim_api/server/falcon/server.py index b58c1cfa..13e79311 100644 --- a/src/nominatim_api/server/falcon/server.py +++ b/src/nominatim_api/server/falcon/server.py @@ -147,12 +147,36 @@ class FileLoggingMiddleware: f'{resource.name} "{params}"\n') -class APIShutdown: - """ Middleware that closes any open database connections. +class APIMiddleware: + """ Middleware managing the Nominatim database connection. """ - def __init__(self, api: NominatimAPIAsync) -> None: - self.api = api + def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None: + self.api = NominatimAPIAsync(project_dir, environ) + self.app: Optional[App] = None + + @property + def config(self) -> Configuration: + """ Get the configuration for Nominatim. + """ + return self.api.config + + def set_app(self, app: App) -> None: + """ Set the Falcon application this middleware is connected to. + """ + self.app = app + + async def process_startup(self, *_: Any) -> None: + """ Process the ASGI lifespan startup event. + """ + assert self.app is not None + legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS') + formatter = load_format_dispatcher('v1', self.api.config.project_dir) + for name, func in await api_impl.get_routes(self.api): + endpoint = EndpointWrapper(name, func, self.api, formatter) + self.app.add_route(f"/{name}", endpoint) + if legacy_urls: + self.app.add_route(f"/{name}.php", endpoint) async def process_shutdown(self, *_: Any) -> None: """Process the ASGI lifespan shutdown event. @@ -164,28 +188,22 @@ def get_application(project_dir: Path, environ: Optional[Mapping[str, str]] = None) -> App: """ Create a Nominatim Falcon ASGI application. """ - api = NominatimAPIAsync(project_dir, environ) + apimw = APIMiddleware(project_dir, environ) - middleware: List[object] = [APIShutdown(api)] - log_file = api.config.LOG_FILE + middleware: List[object] = [apimw] + log_file = apimw.config.LOG_FILE if log_file: middleware.append(FileLoggingMiddleware(log_file)) - app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), + app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'), middleware=middleware) + + apimw.set_app(app) app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(TimeoutError, timeout_error_handler) # different from TimeoutError in Python <= 3.10 app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type] - legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS') - formatter = load_format_dispatcher('v1', project_dir) - for name, func in api_impl.ROUTES: - endpoint = EndpointWrapper(name, func, api, formatter) - app.add_route(f"/{name}", endpoint) - if legacy_urls: - app.add_route(f"/{name}.php", endpoint) - return app diff --git a/src/nominatim_api/server/starlette/server.py b/src/nominatim_api/server/starlette/server.py index 48f0207a..e6c97693 100644 --- a/src/nominatim_api/server/starlette/server.py +++ b/src/nominatim_api/server/starlette/server.py @@ -7,10 +7,12 @@ """ Server implementation using the starlette webserver framework. """ -from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable +from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \ + Awaitable, AsyncIterator from pathlib import Path import datetime as dt import asyncio +import contextlib from starlette.applications import Starlette from starlette.routing import Route @@ -66,7 +68,7 @@ class ParamWrapper(ASGIAdaptor): return cast(Configuration, self.request.app.state.API.config) def formatting(self) -> FormatDispatcher: - return cast(FormatDispatcher, self.request.app.state.API.formatter) + return cast(FormatDispatcher, self.request.app.state.formatter) def _wrap_endpoint(func: EndpointFunc)\ @@ -132,14 +134,6 @@ def get_application(project_dir: Path, """ config = Configuration(project_dir, environ) - routes = [] - legacy_urls = config.get_bool('SERVE_LEGACY_URLS') - for name, func in api_impl.ROUTES: - endpoint = _wrap_endpoint(func) - routes.append(Route(f"/{name}", endpoint=endpoint)) - if legacy_urls: - routes.append(Route(f"/{name}.php", endpoint=endpoint)) - middleware = [] if config.get_bool('CORS_NOACCESSCONTROL'): middleware.append(Middleware(CORSMiddleware, @@ -156,14 +150,26 @@ def get_application(project_dir: Path, asyncio.TimeoutError: timeout_error } - async def _shutdown() -> None: + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> AsyncIterator[Any]: + app.state.API = NominatimAPIAsync(project_dir, environ) + config = app.state.API.config + + legacy_urls = config.get_bool('SERVE_LEGACY_URLS') + for name, func in await api_impl.get_routes(app.state.API): + endpoint = _wrap_endpoint(func) + app.routes.append(Route(f"/{name}", endpoint=endpoint)) + if legacy_urls: + app.routes.append(Route(f"/{name}.php", endpoint=endpoint)) + + yield + await app.state.API.close() - app = Starlette(debug=debug, routes=routes, middleware=middleware, + app = Starlette(debug=debug, middleware=middleware, exception_handlers=exceptions, - on_shutdown=[_shutdown]) + lifespan=lifespan) - app.state.API = NominatimAPIAsync(project_dir, environ) app.state.formatter = load_format_dispatcher('v1', project_dir) return app diff --git a/src/nominatim_api/v1/__init__.py b/src/nominatim_api/v1/__init__.py index f49de2a2..2304994f 100644 --- a/src/nominatim_api/v1/__init__.py +++ b/src/nominatim_api/v1/__init__.py @@ -8,4 +8,4 @@ Implementation of API version v1 (aka the legacy version). """ -from .server_glue import ROUTES as ROUTES +from .server_glue import get_routes as get_routes diff --git a/src/nominatim_api/v1/server_glue.py b/src/nominatim_api/v1/server_glue.py index ee502d8b..a6450bf2 100644 --- a/src/nominatim_api/v1/server_glue.py +++ b/src/nominatim_api/v1/server_glue.py @@ -8,7 +8,7 @@ Generic part of the server implementation of the v1 API. Combine with the scaffolding provided for the various Python ASGI frameworks. """ -from typing import Optional, Any, Type, Dict, cast +from typing import Optional, Any, Type, Dict, cast, Sequence, Tuple from functools import reduce import dataclasses from urllib.parse import urlencode @@ -25,7 +25,8 @@ from ..results import DetailedResult, ReverseResults, SearchResult, SearchResult from ..localization import Locales from . import helpers from ..server import content_types as ct -from ..server.asgi_adaptor import ASGIAdaptor +from ..server.asgi_adaptor import ASGIAdaptor, EndpointFunc +from ..sql.async_core_library import PGCORE_ERROR def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200, @@ -417,12 +418,25 @@ async def polygons_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any: return build_response(params, params.formatting().format_result(results, fmt, {})) -ROUTES = [ - ('status', status_endpoint), - ('details', details_endpoint), - ('reverse', reverse_endpoint), - ('lookup', lookup_endpoint), - ('search', search_endpoint), - ('deletable', deletable_endpoint), - ('polygons', polygons_endpoint), -] +async def get_routes(api: NominatimAPIAsync) -> Sequence[Tuple[str, EndpointFunc]]: + routes = [ + ('status', status_endpoint), + ('details', details_endpoint), + ('reverse', reverse_endpoint), + ('lookup', lookup_endpoint), + ('deletable', deletable_endpoint), + ('polygons', polygons_endpoint), + ] + + def has_search_name(conn: sa.engine.Connection) -> bool: + insp = sa.inspect(conn) + return insp.has_table('search_name') + + try: + async with api.begin() as conn: + if await conn.connection.run_sync(has_search_name): + routes.append(('search', search_endpoint)) + except (PGCORE_ERROR, sa.exc.OperationalError): + pass # ignored + + return routes