Merge pull request #3588 from lonvia/optional-reverse-api

Add support for adding endpoints to server conditionally
This commit is contained in:
Sarah Hoffmann 2024-11-14 19:33:57 +01:00 committed by GitHub
commit 3acd7df5c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 80 additions and 42 deletions

View File

@ -147,12 +147,36 @@ class FileLoggingMiddleware:
f'{resource.name} "{params}"\n') f'{resource.name} "{params}"\n')
class APIShutdown: class APIMiddleware:
""" Middleware that closes any open database connections. """ Middleware managing the Nominatim database connection.
""" """
def __init__(self, api: NominatimAPIAsync) -> None: def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]]) -> None:
self.api = api self.api = NominatimAPIAsync(project_dir, environ)
self.app: Optional[App] = None
@property
def config(self) -> Configuration:
""" Get the configuration for Nominatim.
"""
return self.api.config
def set_app(self, app: App) -> None:
""" Set the Falcon application this middleware is connected to.
"""
self.app = app
async def process_startup(self, *_: Any) -> None:
""" Process the ASGI lifespan startup event.
"""
assert self.app is not None
legacy_urls = self.api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', self.api.config.project_dir)
for name, func in await api_impl.get_routes(self.api):
endpoint = EndpointWrapper(name, func, self.api, formatter)
self.app.add_route(f"/{name}", endpoint)
if legacy_urls:
self.app.add_route(f"/{name}.php", endpoint)
async def process_shutdown(self, *_: Any) -> None: async def process_shutdown(self, *_: Any) -> None:
"""Process the ASGI lifespan shutdown event. """Process the ASGI lifespan shutdown event.
@ -164,28 +188,22 @@ def get_application(project_dir: Path,
environ: Optional[Mapping[str, str]] = None) -> App: environ: Optional[Mapping[str, str]] = None) -> App:
""" Create a Nominatim Falcon ASGI application. """ Create a Nominatim Falcon ASGI application.
""" """
api = NominatimAPIAsync(project_dir, environ) apimw = APIMiddleware(project_dir, environ)
middleware: List[object] = [APIShutdown(api)] middleware: List[object] = [apimw]
log_file = api.config.LOG_FILE log_file = apimw.config.LOG_FILE
if log_file: if log_file:
middleware.append(FileLoggingMiddleware(log_file)) middleware.append(FileLoggingMiddleware(log_file))
app = App(cors_enable=api.config.get_bool('CORS_NOACCESSCONTROL'), app = App(cors_enable=apimw.config.get_bool('CORS_NOACCESSCONTROL'),
middleware=middleware) middleware=middleware)
apimw.set_app(app)
app.add_error_handler(HTTPNominatimError, nominatim_error_handler) app.add_error_handler(HTTPNominatimError, nominatim_error_handler)
app.add_error_handler(TimeoutError, timeout_error_handler) app.add_error_handler(TimeoutError, timeout_error_handler)
# different from TimeoutError in Python <= 3.10 # different from TimeoutError in Python <= 3.10
app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type] app.add_error_handler(asyncio.TimeoutError, timeout_error_handler) # type: ignore[arg-type]
legacy_urls = api.config.get_bool('SERVE_LEGACY_URLS')
formatter = load_format_dispatcher('v1', project_dir)
for name, func in api_impl.ROUTES:
endpoint = EndpointWrapper(name, func, api, formatter)
app.add_route(f"/{name}", endpoint)
if legacy_urls:
app.add_route(f"/{name}.php", endpoint)
return app return app

View File

@ -7,10 +7,12 @@
""" """
Server implementation using the starlette webserver framework. Server implementation using the starlette webserver framework.
""" """
from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, Awaitable from typing import Any, Optional, Mapping, Callable, cast, Coroutine, Dict, \
Awaitable, AsyncIterator
from pathlib import Path from pathlib import Path
import datetime as dt import datetime as dt
import asyncio import asyncio
import contextlib
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Route from starlette.routing import Route
@ -66,7 +68,7 @@ class ParamWrapper(ASGIAdaptor):
return cast(Configuration, self.request.app.state.API.config) return cast(Configuration, self.request.app.state.API.config)
def formatting(self) -> FormatDispatcher: def formatting(self) -> FormatDispatcher:
return cast(FormatDispatcher, self.request.app.state.API.formatter) return cast(FormatDispatcher, self.request.app.state.formatter)
def _wrap_endpoint(func: EndpointFunc)\ def _wrap_endpoint(func: EndpointFunc)\
@ -132,14 +134,6 @@ def get_application(project_dir: Path,
""" """
config = Configuration(project_dir, environ) config = Configuration(project_dir, environ)
routes = []
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in api_impl.ROUTES:
endpoint = _wrap_endpoint(func)
routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
routes.append(Route(f"/{name}.php", endpoint=endpoint))
middleware = [] middleware = []
if config.get_bool('CORS_NOACCESSCONTROL'): if config.get_bool('CORS_NOACCESSCONTROL'):
middleware.append(Middleware(CORSMiddleware, middleware.append(Middleware(CORSMiddleware,
@ -156,14 +150,26 @@ def get_application(project_dir: Path,
asyncio.TimeoutError: timeout_error asyncio.TimeoutError: timeout_error
} }
async def _shutdown() -> None: @contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[Any]:
app.state.API = NominatimAPIAsync(project_dir, environ)
config = app.state.API.config
legacy_urls = config.get_bool('SERVE_LEGACY_URLS')
for name, func in await api_impl.get_routes(app.state.API):
endpoint = _wrap_endpoint(func)
app.routes.append(Route(f"/{name}", endpoint=endpoint))
if legacy_urls:
app.routes.append(Route(f"/{name}.php", endpoint=endpoint))
yield
await app.state.API.close() await app.state.API.close()
app = Starlette(debug=debug, routes=routes, middleware=middleware, app = Starlette(debug=debug, middleware=middleware,
exception_handlers=exceptions, exception_handlers=exceptions,
on_shutdown=[_shutdown]) lifespan=lifespan)
app.state.API = NominatimAPIAsync(project_dir, environ)
app.state.formatter = load_format_dispatcher('v1', project_dir) app.state.formatter = load_format_dispatcher('v1', project_dir)
return app return app

View File

@ -8,4 +8,4 @@
Implementation of API version v1 (aka the legacy version). Implementation of API version v1 (aka the legacy version).
""" """
from .server_glue import ROUTES as ROUTES from .server_glue import get_routes as get_routes

View File

@ -8,7 +8,7 @@
Generic part of the server implementation of the v1 API. Generic part of the server implementation of the v1 API.
Combine with the scaffolding provided for the various Python ASGI frameworks. Combine with the scaffolding provided for the various Python ASGI frameworks.
""" """
from typing import Optional, Any, Type, Dict, cast from typing import Optional, Any, Type, Dict, cast, Sequence, Tuple
from functools import reduce from functools import reduce
import dataclasses import dataclasses
from urllib.parse import urlencode from urllib.parse import urlencode
@ -25,7 +25,8 @@ from ..results import DetailedResult, ReverseResults, SearchResult, SearchResult
from ..localization import Locales from ..localization import Locales
from . import helpers from . import helpers
from ..server import content_types as ct from ..server import content_types as ct
from ..server.asgi_adaptor import ASGIAdaptor from ..server.asgi_adaptor import ASGIAdaptor, EndpointFunc
from ..sql.async_core_library import PGCORE_ERROR
def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200, def build_response(adaptor: ASGIAdaptor, output: str, status: int = 200,
@ -417,12 +418,25 @@ async def polygons_endpoint(api: NominatimAPIAsync, params: ASGIAdaptor) -> Any:
return build_response(params, params.formatting().format_result(results, fmt, {})) return build_response(params, params.formatting().format_result(results, fmt, {}))
ROUTES = [ async def get_routes(api: NominatimAPIAsync) -> Sequence[Tuple[str, EndpointFunc]]:
('status', status_endpoint), routes = [
('details', details_endpoint), ('status', status_endpoint),
('reverse', reverse_endpoint), ('details', details_endpoint),
('lookup', lookup_endpoint), ('reverse', reverse_endpoint),
('search', search_endpoint), ('lookup', lookup_endpoint),
('deletable', deletable_endpoint), ('deletable', deletable_endpoint),
('polygons', polygons_endpoint), ('polygons', polygons_endpoint),
] ]
def has_search_name(conn: sa.engine.Connection) -> bool:
insp = sa.inspect(conn)
return insp.has_table('search_name')
try:
async with api.begin() as conn:
if await conn.connection.run_sync(has_search_name):
routes.append(('search', search_endpoint))
except (PGCORE_ERROR, sa.exc.OperationalError):
pass # ignored
return routes