diff --git a/src/nominatim_api/server/falcon/server.py b/src/nominatim_api/server/falcon/server.py index b58c1cfa..e252e3c8 100644 --- a/src/nominatim_api/server/falcon/server.py +++ b/src/nominatim_api/server/falcon/server.py @@ -147,12 +147,36 @@ class FileLoggingMiddleware: f'{resource.name} "{params}"\n') -class APIShutdown: - """ Middleware that closes any open database connections. +class APIMiddleware: + """ Middleware managing the Nominatim database connection. """ - def __init__(self, api: NominatimAPIAsync) -> None: - self.api = api + def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None: + self.api = NominatimAPIAsync(project_dir, environ) + self.app: Optional[App] = None + + @property + def config(self) -> Configuration: + """ Get the configuration for Nominatim. + """ + return self.api.config + + def set_app(self, app: App) -> None: + """ Set the Falcon application this middleware is connected to. + """ + self.app = app + + async def process_startup(self, *_: Any) -> None: + """ Process the ASGI lifespan startup event. + """ + assert self.app is not None + legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS') + formatter = load_format_dispatcher('v1', self.api.config.project_dir) + for name, func in api_impl.ROUTES: + endpoint = EndpointWrapper(name, func, self.api, formatter) + self.app.add_route(f"/{name}", endpoint) + if legacy_urls: + self.app.add_route(f"/{name}.php", endpoint) async def process_shutdown(self, *_: Any) -> None: """Process the ASGI lifespan shutdown event. @@ -164,28 +188,22 @@ def get_application(project_dir: Path, environ: Optional[Mapping[str, str]] = None) -> App: """ Create a Nominatim Falcon ASGI application. """ - api = NominatimAPIAsync(project_dir, environ) + apimw = APIMiddleware(project_dir, environ) - middleware: List[object] = [APIShutdown(api)] - log_file = api.config.LOG_FILE + middleware: List[object] = [apimw] + log_file = apimw.config.LOG_FILE if log_file: middleware.append(FileLoggingMiddleware(log_file)) - app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), + app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'), middleware=middleware) + + apimw.set_app(app) app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(TimeoutError, timeout_error_handler) # different from TimeoutError in Python <= 3.10 app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type] - legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS') - formatter = load_format_dispatcher('v1', project_dir) - for name, func in api_impl.ROUTES: - endpoint = EndpointWrapper(name, func, api, formatter) - app.add_route(f"/{name}", endpoint) - if legacy_urls: - app.add_route(f"/{name}.php", endpoint) - return app diff --git a/src/nominatim_api/server/starlette/server.py b/src/nominatim_api/server/starlette/server.py index 48f0207a..9d014920 100644 --- a/src/nominatim_api/server/starlette/server.py +++ b/src/nominatim_api/server/starlette/server.py @@ -11,6 +11,7 @@ from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awai from pathlib import Path import datetime as dt import asyncio +import contextlib from starlette.applications import Starlette from starlette.routing import Route @@ -66,7 +67,7 @@ class ParamWrapper(ASGIAdaptor): return cast(Configuration, self.request.app.state.API.config) def formatting(self) -> FormatDispatcher: - return cast(FormatDispatcher, self.request.app.state.API.formatter) + return cast(FormatDispatcher, self.request.app.state.formatter) def _wrap_endpoint(func: EndpointFunc)\ @@ -132,14 +133,6 @@ def get_application(project_dir: Path, """ config = Configuration(project_dir, environ) - routes = [] - legacy_urls = config.get_bool('SERVE_LEGACY_URLS') - for name, func in api_impl.ROUTES: - endpoint = _wrap_endpoint(func) - routes.append(Route(f"/{name}", endpoint=endpoint)) - if legacy_urls: - routes.append(Route(f"/{name}.php", endpoint=endpoint)) - middleware = [] if config.get_bool('CORS_NOACCESSCONTROL'): middleware.append(Middleware(CORSMiddleware, @@ -156,14 +149,26 @@ def get_application(project_dir: Path, asyncio.TimeoutError: timeout_error } - async def _shutdown() -> None: + @contextlib.asynccontextmanager + async def lifespan(app: Starlette) -> None: + app.state.API = NominatimAPIAsync(project_dir, environ) + config = app.state.API.config + + legacy_urls = config.get_bool('SERVE_LEGACY_URLS') + for name, func in api_impl.ROUTES: + endpoint = _wrap_endpoint(func) + app.routes.append(Route(f"/{name}", endpoint=endpoint)) + if legacy_urls: + app.routes.append(Route(f"/{name}.php", endpoint=endpoint)) + + yield + await app.state.API.close() - app = Starlette(debug=debug, routes=routes, middleware=middleware, + app = Starlette(debug=debug, middleware=middleware, exception_handlers=exceptions, - on_shutdown=[_shutdown]) + lifespan=lifespan) - app.state.API = NominatimAPIAsync(project_dir, environ) app.state.formatter = load_format_dispatcher('v1', project_dir) return app