move server route creation into async function

This commit is contained in:
Sarah Hoffmann 2024-11-13 21:27:14 +01:00
parent ae8694a6a6
commit 754ff15ebd
2 changed files with 52 additions and 29 deletions

View File

@ -147,12 +147,36 @@ class FileLoggingMiddleware:
f'{resource.name} "{params}"\n') f'{resource.name} "{params}"\n')
class APIShutdown: class APIMiddleware:
""" Middleware that closes any open database connections. """ Middleware managing the Nominatim database connection.
""" """
def __init__(self, api: NominatimAPIAsync) -> None: def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
self.api = api self.api = NominatimAPIAsync(project_dir, environ)
self.app: Optional[App] = None
@property
def config(self) -> Configuration:
""" Get the configuration for Nominatim.
"""
return self.api.config
def set_app(self, app: App) -> None:
""" Set the Falcon application this middleware is connected to.
"""
self.app = app
async def process_startup(self, *_: Any) -> None:
""" Process the ASGI lifespan startup event.
"""
assert self.app is not None
legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', self.api.config.project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, self.api, formatter)
self.app.add_route(f"/{name}", endpoint)
if legacy_urls:
self.app.add_route(f"/{name}.php", endpoint)
async def process_shutdown(self, *_: Any) -> None: async def process_shutdown(self, *_: Any) -> None:
"""Process the ASGI lifespan shutdown event. """Process the ASGI lifespan shutdown event.
@ -164,28 +188,22 @@ def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> App: environ: Optional[Mapping[str, str]] = None) -> App:
""" Create a Nominatim Falcon ASGI application. """ Create a Nominatim Falcon ASGI application.
""" """
api = NominatimAPIAsync(project_dir, environ) apimw = APIMiddleware(project_dir, environ)
middleware: List[object] = [APIShutdown(api)] middleware: List[object] = [apimw]
log_file = api.config.LOG_FILE log_file = apimw.config.LOG_FILE
if log_file: if log_file:
middleware.append(FileLoggingMiddleware(log_file)) middleware.append(FileLoggingMiddleware(log_file))
app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware) middleware=middleware)
apimw.set_app(app)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler) app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10 # different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type] app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)
return app return app

View File

@ -11,6 +11,7 @@ from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awai
from pathlib import Path from pathlib import Path
import datetime as dt import datetime as dt
import asyncio import asyncio
import contextlib
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Route from starlette.routing import Route
@ -66,7 +67,7 @@ class ParamWrapper(ASGIAdaptor):
return cast(Configuration, self.request.app.state.API.config) return cast(Configuration, self.request.app.state.API.config)
def formatting(self) -> FormatDispatcher: def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter) return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\ def _wrap_endpoint(func: EndpointFunc)\
@ -132,14 +133,6 @@ def get_application(project_dir: Path,
""" """
config = Configuration(project_dir, environ) config = Configuration(project_dir, environ)
routes = []
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
routes.append(Route(f"/{name}.php", endpoint=endpoint))
middleware = [] middleware = []
if config.get_bool('CORS_NOACCESSCONTROL'): if config.get_bool('CORS_NOACCESSCONTROL'):
middleware.append(Middleware(CORSMiddleware, middleware.append(Middleware(CORSMiddleware,
@ -156,14 +149,26 @@ def get_application(project_dir: Path,
asyncio.TimeoutError: timeout_error asyncio.TimeoutError: timeout_error
} }
async def _shutdown() -> None: @contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> None:
app.state.API = NominatimAPIAsync(project_dir, environ)
config = app.state.API.config
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
app.routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
yield
await app.state.API.close() await app.state.API.close()
app = Starlette(debug=debug, routes=routes, middleware=middleware, app = Starlette(debug=debug, middleware=middleware,
exception_handlers=exceptions, exception_handlers=exceptions,
on_shutdown=[_shutdown]) lifespan=lifespan)
app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir) app.state.formatter = load_format_dispatcher('v1', project_dir)
return app return app