diff --git a/src/nominatim_api/core.py b/src/nominatim_api/core.py index 6c4c37d7..ac579862 100644 --- a/src/nominatim_api/core.py +++ b/src/nominatim_api/core.py @@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes This class shares most of the functions with its synchronous version. There are some additional functions or parameters, which are documented below. + + This class should usually be used as a context manager in 'with' context. """ def __init__(self, project_dir: Path, environ: Optional[Mapping[str, str]] = None, @@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes await self._engine.dispose() + async def __aenter__(self) -> 'NominatimAPIAsync': + return self + + + async def __aexit__(self, *_: Any) -> None: + await self.close() + + @contextlib.asynccontextmanager async def begin(self) -> AsyncIterator[SearchConnection]: """ Create a new connection with automatic transaction handling. @@ -351,6 +361,8 @@ class NominatimAPI: """ This class provides a thin synchronous wrapper around the asynchronous Nominatim functions. It creates its own event loop and runs each synchronous function call to completion using that loop. + + This class should usually be used as a context manager in 'with' context. """ def __init__(self, project_dir: Path, @@ -376,8 +388,17 @@ class NominatimAPI: This function also closes the asynchronous worker loop making the NominatimAPI object unusable. """ - self._loop.run_until_complete(self._async_api.close()) - self._loop.close() + if not self._loop.is_closed(): + self._loop.run_until_complete(self._async_api.close()) + self._loop.close() + + + def __enter__(self) -> 'NominatimAPI': + return self + + + def __exit__(self, *_: Any) -> None: + self.close() @property diff --git a/src/nominatim_db/clicmd/api.py b/src/nominatim_db/clicmd/api.py index ddc10ff9..dcbbb24b 100644 --- a/src/nominatim_db/clicmd/api.py +++ b/src/nominatim_db/clicmd/api.py @@ -180,29 +180,32 @@ class APISearch: raise UsageError(f"Unsupported format '{args.format}'. " 'Use --list-formats to see supported formats.') - api = napi.NominatimAPI(args.project_dir) - params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10), - 'address_details': True, # needed for display name - 'geometry_output': _get_geometry_output(args), - 'geometry_simplification': args.polygon_threshold, - 'countries': args.countrycodes, - 'excluded': args.exclude_place_ids, - 'viewbox': args.viewbox, - 'bounded_viewbox': args.bounded, - 'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE) - } + try: + with napi.NominatimAPI(args.project_dir) as api: + params: Dict[str, Any] = {'max_results': args.limit + min(args.limit, 10), + 'address_details': True, # needed for display name + 'geometry_output': _get_geometry_output(args), + 'geometry_simplification': args.polygon_threshold, + 'countries': args.countrycodes, + 'excluded': args.exclude_place_ids, + 'viewbox': args.viewbox, + 'bounded_viewbox': args.bounded, + 'locales': _get_locales(args, api.config.DEFAULT_LANGUAGE) + } - if args.query: - results = api.search(args.query, **params) - else: - results = api.search_address(amenity=args.amenity, - street=args.street, - city=args.city, - county=args.county, - state=args.state, - postalcode=args.postalcode, - country=args.country, - **params) + if args.query: + results = api.search(args.query, **params) + else: + results = api.search_address(amenity=args.amenity, + street=args.street, + city=args.city, + county=args.county, + state=args.state, + postalcode=args.postalcode, + country=args.country, + **params) + except napi.UsageError as ex: + raise UsageError(ex) from ex if args.dedupe and len(results) > 1: results = deduplicate_results(results, args.limit) @@ -260,14 +263,19 @@ class APIReverse: if args.lat is None or args.lon is None: raise UsageError("lat' and 'lon' parameters are required.") - api = napi.NominatimAPI(args.project_dir) - result = api.reverse(napi.Point(args.lon, args.lat), - max_rank=zoom_to_rank(args.zoom or 18), - layers=_get_layers(args, napi.DataLayer.ADDRESS | napi.DataLayer.POI), - address_details=True, # needed for display name - geometry_output=_get_geometry_output(args), - geometry_simplification=args.polygon_threshold, - locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) + layers = _get_layers(args, napi.DataLayer.ADDRESS | napi.DataLayer.POI) + + try: + with napi.NominatimAPI(args.project_dir) as api: + result = api.reverse(napi.Point(args.lon, args.lat), + max_rank=zoom_to_rank(args.zoom or 18), + layers=layers, + address_details=True, # needed for display name + geometry_output=_get_geometry_output(args), + geometry_simplification=args.polygon_threshold, + locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) + except napi.UsageError as ex: + raise UsageError(ex) from ex if args.format == 'debug': print(loglib.get_and_disable()) @@ -323,12 +331,15 @@ class APILookup: places = [napi.OsmID(o[0], int(o[1:])) for o in args.ids] - api = napi.NominatimAPI(args.project_dir) - results = api.lookup(places, - address_details=True, # needed for display name - geometry_output=_get_geometry_output(args), - geometry_simplification=args.polygon_threshold or 0.0, - locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) + try: + with napi.NominatimAPI(args.project_dir) as api: + results = api.lookup(places, + address_details=True, # needed for display name + geometry_output=_get_geometry_output(args), + geometry_simplification=args.polygon_threshold or 0.0, + locales=_get_locales(args, api.config.DEFAULT_LANGUAGE)) + except napi.UsageError as ex: + raise UsageError(ex) from ex if args.format == 'debug': print(loglib.get_and_disable()) @@ -410,17 +421,20 @@ class APIDetails: raise UsageError('One of the arguments --node/-n --way/-w ' '--relation/-r --place_id/-p is required/') - api = napi.NominatimAPI(args.project_dir) - locales = _get_locales(args, api.config.DEFAULT_LANGUAGE) - result = api.details(place, - address_details=args.addressdetails, - linked_places=args.linkedplaces, - parented_places=args.hierarchy, - keywords=args.keywords, - geometry_output=napi.GeometryFormat.GEOJSON - if args.polygon_geojson - else napi.GeometryFormat.NONE, - locales=locales) + try: + with napi.NominatimAPI(args.project_dir) as api: + locales = _get_locales(args, api.config.DEFAULT_LANGUAGE) + result = api.details(place, + address_details=args.addressdetails, + linked_places=args.linkedplaces, + parented_places=args.hierarchy, + keywords=args.keywords, + geometry_output=napi.GeometryFormat.GEOJSON + if args.polygon_geojson + else napi.GeometryFormat.NONE, + locales=locales) + except napi.UsageError as ex: + raise UsageError(ex) from ex if args.format == 'debug': print(loglib.get_and_disable()) @@ -465,7 +479,11 @@ class APIStatus: raise UsageError(f"Unsupported format '{args.format}'. " 'Use --list-formats to see supported formats.') - status = napi.NominatimAPI(args.project_dir).status() + try: + with napi.NominatimAPI(args.project_dir) as api: + status = api.status() + except napi.UsageError as ex: + raise UsageError(ex) from ex if args.format == 'debug': print(loglib.get_and_disable()) diff --git a/test/python/api/conftest.py b/test/python/api/conftest.py index a902e264..0c770980 100644 --- a/test/python/api/conftest.py +++ b/test/python/api/conftest.py @@ -9,6 +9,7 @@ Helper fixtures for API call tests. """ from pathlib import Path import pytest +import pytest_asyncio import time import datetime as dt @@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path): for api in testapis: api.close() + + +@pytest_asyncio.fixture +async def api(temp_db): + async with napi.NominatimAPIAsync(Path('/invalid')) as api: + yield api diff --git a/test/python/api/search/test_icu_query_analyzer.py b/test/python/api/search/test_icu_query_analyzer.py index 8e5480fc..7f88879c 100644 --- a/test/python/api/search/test_icu_query_analyzer.py +++ b/test/python/api/search/test_icu_query_analyzer.py @@ -40,10 +40,9 @@ async def conn(table_factory): table_factory('word', definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB') - api = NominatimAPIAsync(Path('/invalid'), {}) - async with api.begin() as conn: - yield conn - await api.close() + async with NominatimAPIAsync(Path('/invalid'), {}) as api: + async with api.begin() as conn: + yield conn @pytest.mark.asyncio diff --git a/test/python/api/search/test_legacy_query_analyzer.py b/test/python/api/search/test_legacy_query_analyzer.py index 92de8706..0e967c10 100644 --- a/test/python/api/search/test_legacy_query_analyzer.py +++ b/test/python/api/search/test_legacy_query_analyzer.py @@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor): temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT) RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""") - api = NominatimAPIAsync(Path('/invalid'), {}) - async with api.begin() as conn: - yield conn - await api.close() + async with NominatimAPIAsync(Path('/invalid'), {}) as api: + async with api.begin() as conn: + yield conn @pytest.mark.asyncio diff --git a/test/python/api/search/test_query_analyzer_factory.py b/test/python/api/search/test_query_analyzer_factory.py index 9545a88f..42220b55 100644 --- a/test/python/api/search/test_query_analyzer_factory.py +++ b/test/python/api/search/test_query_analyzer_factory.py @@ -11,41 +11,35 @@ from pathlib import Path import pytest -from nominatim_api import NominatimAPIAsync from nominatim_api.search.query_analyzer_factory import make_query_analyzer from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer @pytest.mark.asyncio -async def test_import_icu_tokenizer(table_factory): +async def test_import_icu_tokenizer(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('tokenizer', 'icu'), ('tokenizer_import_normalisation', ':: lower();'), ('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '"))) - api = NominatimAPIAsync(Path('/invalid'), {}) async with api.begin() as conn: ana = await make_query_analyzer(conn) assert isinstance(ana, ICUQueryAnalyzer) - await api.close() @pytest.mark.asyncio -async def test_import_missing_property(table_factory): - api = NominatimAPIAsync(Path('/invalid'), {}) +async def test_import_missing_property(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT') async with api.begin() as conn: with pytest.raises(ValueError, match='Property.*not found'): await make_query_analyzer(conn) - await api.close() @pytest.mark.asyncio -async def test_import_missing_module(table_factory): - api = NominatimAPIAsync(Path('/invalid'), {}) +async def test_import_missing_module(table_factory, api): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('tokenizer', 'missing'),)) @@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory): async with api.begin() as conn: with pytest.raises(RuntimeError, match='Tokenizer not found'): await make_query_analyzer(conn) - await api.close() - diff --git a/test/python/api/test_api_connection.py b/test/python/api/test_api_connection.py index 3c4fc61b..f62b0d9e 100644 --- a/test/python/api/test_api_connection.py +++ b/test/python/api/test_api_connection.py @@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions. """ from pathlib import Path import pytest -import pytest_asyncio import sqlalchemy as sa -from nominatim_api import NominatimAPIAsync - -@pytest_asyncio.fixture -async def apiobj(temp_db): - """ Create an asynchronous SQLAlchemy engine for the test DB. - """ - api = NominatimAPIAsync(Path('/invalid'), {}) - yield api - await api.close() - @pytest.mark.asyncio -async def test_run_scalar(apiobj, table_factory): +async def test_run_scalar(api, table_factory): table_factory('foo', definition='that TEXT', content=(('a', ),)) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a' @pytest.mark.asyncio -async def test_run_execute(apiobj, table_factory): +async def test_run_execute(api, table_factory): table_factory('foo', definition='that TEXT', content=(('a', ),)) - async with apiobj.begin() as conn: + async with api.begin() as conn: result = await conn.execute(sa.text('SELECT * FROM foo')) assert result.fetchone()[0] == 'a' @pytest.mark.asyncio -async def test_get_property_existing_cached(apiobj, table_factory): +async def test_get_property_existing_cached(api, table_factory): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('dbv', '96723'), )) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.get_property('dbv') == '96723' await conn.execute(sa.text('TRUNCATE nominatim_properties')) @@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory): @pytest.mark.asyncio -async def test_get_property_existing_uncached(apiobj, table_factory): +async def test_get_property_existing_uncached(api, table_factory): table_factory('nominatim_properties', definition='property TEXT, value TEXT', content=(('dbv', '96723'), )) - async with apiobj.begin() as conn: + async with api.begin() as conn: assert await conn.get_property('dbv') == '96723' await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'")) @@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory): @pytest.mark.asyncio @pytest.mark.parametrize('param', ['foo', 'DB:server_version']) -async def test_get_property_missing(apiobj, table_factory, param): +async def test_get_property_missing(api, table_factory, param): table_factory('nominatim_properties', definition='property TEXT, value TEXT') - async with apiobj.begin() as conn: + async with api.begin() as conn: with pytest.raises(ValueError): await conn.get_property(param) @pytest.mark.asyncio -async def test_get_db_property_existing(apiobj): - async with apiobj.begin() as conn: +async def test_get_db_property_existing(api): + async with api.begin() as conn: assert await conn.get_db_property('server_version') > 0 @pytest.mark.asyncio -async def test_get_db_property_existing(apiobj): - async with apiobj.begin() as conn: +async def test_get_db_property_existing(api): + async with api.begin() as conn: with pytest.raises(ValueError): await conn.get_db_property('dfkgjd.rijg') diff --git a/test/python/api/test_api_deletable_v1.py b/test/python/api/test_api_deletable_v1.py index 649dd8fc..9e113886 100644 --- a/test/python/api/test_api_deletable_v1.py +++ b/test/python/api/test_api_deletable_v1.py @@ -11,19 +11,10 @@ import json from pathlib import Path import pytest -import pytest_asyncio from fake_adaptor import FakeAdaptor, FakeError, FakeResponse import nominatim_api.v1.server_glue as glue -import nominatim_api as napi - -@pytest_asyncio.fixture -async def api(): - api = napi.NominatimAPIAsync(Path('/invalid')) - yield api - await api.close() - class TestDeletableEndPoint: @@ -61,4 +52,3 @@ class TestDeletableEndPoint: {'place_id': 3, 'country_code': 'cd', 'name': None, 'osm_id': 781, 'osm_type': 'R', 'class': 'landcover', 'type': 'grass'}] - diff --git a/test/python/api/test_api_polygons_v1.py b/test/python/api/test_api_polygons_v1.py index 558be813..ac2b4cb9 100644 --- a/test/python/api/test_api_polygons_v1.py +++ b/test/python/api/test_api_polygons_v1.py @@ -12,19 +12,10 @@ import datetime as dt from pathlib import Path import pytest -import pytest_asyncio from fake_adaptor import FakeAdaptor, FakeError, FakeResponse import nominatim_api.v1.server_glue as glue -import nominatim_api as napi - -@pytest_asyncio.fixture -async def api(): - api = napi.NominatimAPIAsync(Path('/invalid')) - yield api - await api.close() - class TestPolygonsEndPoint: