move server route creation into async function

This commit is contained in:
Sarah Hoffmann 2024-11-13 21:27:14 +01:00
parent ae8694a6a6
commit 754ff15ebd
2 changed files with 52 additions and 29 deletions

View File

@ -147,12 +147,36 @@ class FileLoggingMiddleware:
f'{resource.name} "{params}"\n')
class APIShutdown:
""" Middleware that closes any open database connections.
class APIMiddleware:
""" Middleware managing the Nominatim database connection.
"""
def __init__(self, api: NominatimAPIAsync) -> None:
self.api = api
def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
self.api = NominatimAPIAsync(project_dir, environ)
self.app: Optional[App] = None
@property
def config(self) -> Configuration:
""" Get the configuration for Nominatim.
"""
return self.api.config
def set_app(self, app: App) -> None:
""" Set the Falcon application this middleware is connected to.
"""
self.app = app
async def process_startup(self, *_: Any) -> None:
""" Process the ASGI lifespan startup event.
"""
assert self.app is not None
legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', self.api.config.project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, self.api, formatter)
self.app.add_route(f"/{name}", endpoint)
if legacy_urls:
self.app.add_route(f"/{name}.php", endpoint)
async def process_shutdown(self, *_: Any) -> None:
"""Process the ASGI lifespan shutdown event.
@ -164,28 +188,22 @@ def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> App:
""" Create a Nominatim Falcon ASGI application.
"""
api = NominatimAPIAsync(project_dir, environ)
apimw = APIMiddleware(project_dir, environ)
middleware: List[object] = [APIShutdown(api)]
log_file = api.config.LOG_FILE
middleware: List[object] = [apimw]
log_file = apimw.config.LOG_FILE
if log_file:
middleware.append(FileLoggingMiddleware(log_file))
app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'),
app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware)
apimw.set_app(app)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)
return app

View File

@ -11,6 +11,7 @@ from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awai
from pathlib import Path
import datetime as dt
import asyncio
import contextlib
from starlette.applications import Starlette
from starlette.routing import Route
@ -66,7 +67,7 @@ class ParamWrapper(ASGIAdaptor):
return cast(Configuration, self.request.app.state.API.config)
def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter)
return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\
@ -132,14 +133,6 @@ def get_application(project_dir: Path,
"""
config = Configuration(project_dir, environ)
routes = []
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
routes.append(Route(f"/{name}.php", endpoint=endpoint))
middleware = []
if config.get_bool('CORS_NOACCESSCONTROL'):
middleware.append(Middleware(CORSMiddleware,
@ -156,14 +149,26 @@ def get_application(project_dir: Path,
asyncio.TimeoutError: timeout_error
}
async def _shutdown() -> None:
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> None:
app.state.API = NominatimAPIAsync(project_dir, environ)
config = app.state.API.config
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
app.routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
yield
await app.state.API.close()
app = Starlette(debug=debug, routes=routes, middleware=middleware,
app = Starlette(debug=debug, middleware=middleware,
exception_handlers=exceptions,
on_shutdown=[_shutdown])
lifespan=lifespan)
app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir)
return app