make NominatimAPI[Async] a context manager

If close() isn't properly called, it can lead to odd error messages
about uncaught exceptions.
This commit is contained in:
Sarah Hoffmann 2024-08-19 11:31:38 +02:00
parent 8b41b80bff
commit c2594aca40
8 changed files with 53 additions and 65 deletions

View File

@ -38,6 +38,8 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
This class shares most of the functions with its synchronous
version. There are some additional functions or parameters,
which are documented below.
This class should usually be used as a context manager in 'with' context.
"""
def __init__(self, project_dir: Path,
environ: Optional[Mapping[str, str]] = None,
@ -166,6 +168,14 @@ class NominatimAPIAsync: #pylint: disable=too-many-instance-attributes
await self._engine.dispose()
async def __aenter__(self) -> 'NominatimAPIAsync':
return self
async def __aexit__(self, *_: Any) -> None:
await self.close()
@contextlib.asynccontextmanager
async def begin(self) -> AsyncIterator[SearchConnection]:
""" Create a new connection with automatic transaction handling.
@ -351,6 +361,8 @@ class NominatimAPI:
""" This class provides a thin synchronous wrapper around the asynchronous
Nominatim functions. It creates its own event loop and runs each
synchronous function call to completion using that loop.
This class should usually be used as a context manager in 'with' context.
"""
def __init__(self, project_dir: Path,
@ -376,8 +388,17 @@ class NominatimAPI:
This function also closes the asynchronous worker loop making
the NominatimAPI object unusable.
"""
self._loop.run_until_complete(self._async_api.close())
self._loop.close()
if not self._loop.is_closed():
self._loop.run_until_complete(self._async_api.close())
self._loop.close()
def __enter__(self) -> 'NominatimAPI':
return self
def __exit__(self, *_: Any) -> None:
self.close()
@property

View File

@ -9,6 +9,7 @@ Helper fixtures for API call tests.
"""
from pathlib import Path
import pytest
import pytest_asyncio
import time
import datetime as dt
@ -244,3 +245,9 @@ def frontend(request, event_loop, tmp_path):
for api in testapis:
api.close()
@pytest_asyncio.fixture
async def api(temp_db):
async with napi.NominatimAPIAsync(Path('/invalid')) as api:
yield api

View File

@ -40,10 +40,9 @@ async def conn(table_factory):
table_factory('word',
definition='word_id INT, word_token TEXT, type TEXT, word TEXT, info JSONB')
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn
@pytest.mark.asyncio

View File

@ -74,10 +74,9 @@ async def conn(table_factory, temp_db_cursor):
temp_db_cursor.execute("""CREATE OR REPLACE FUNCTION make_standard_name(name TEXT)
RETURNS TEXT AS $$ SELECT lower(name); $$ LANGUAGE SQL;""")
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
yield conn
await api.close()
async with NominatimAPIAsync(Path('/invalid'), {}) as api:
async with api.begin() as conn:
yield conn
@pytest.mark.asyncio

View File

@ -11,41 +11,35 @@ from pathlib import Path
import pytest
from nominatim_api import NominatimAPIAsync
from nominatim_api.search.query_analyzer_factory import make_query_analyzer
from nominatim_api.search.icu_tokenizer import ICUQueryAnalyzer
@pytest.mark.asyncio
async def test_import_icu_tokenizer(table_factory):
async def test_import_icu_tokenizer(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'icu'),
('tokenizer_import_normalisation', ':: lower();'),
('tokenizer_import_transliteration', "'1' > '/1/'; 'ä' > 'ä '")))
api = NominatimAPIAsync(Path('/invalid'), {})
async with api.begin() as conn:
ana = await make_query_analyzer(conn)
assert isinstance(ana, ICUQueryAnalyzer)
await api.close()
@pytest.mark.asyncio
async def test_import_missing_property(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_property(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')
async with api.begin() as conn:
with pytest.raises(ValueError, match='Property.*not found'):
await make_query_analyzer(conn)
await api.close()
@pytest.mark.asyncio
async def test_import_missing_module(table_factory):
api = NominatimAPIAsync(Path('/invalid'), {})
async def test_import_missing_module(table_factory, api):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('tokenizer', 'missing'),))
@ -53,5 +47,3 @@ async def test_import_missing_module(table_factory):
async with api.begin() as conn:
with pytest.raises(RuntimeError, match='Tokenizer not found'):
await make_query_analyzer(conn)
await api.close()

View File

@ -9,45 +9,34 @@ Tests for enhanced connection class for API functions.
"""
from pathlib import Path
import pytest
import pytest_asyncio
import sqlalchemy as sa
from nominatim_api import NominatimAPIAsync
@pytest_asyncio.fixture
async def apiobj(temp_db):
""" Create an asynchronous SQLAlchemy engine for the test DB.
"""
api = NominatimAPIAsync(Path('/invalid'), {})
yield api
await api.close()
@pytest.mark.asyncio
async def test_run_scalar(apiobj, table_factory):
async def test_run_scalar(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.scalar(sa.text('SELECT * FROM foo')) == 'a'
@pytest.mark.asyncio
async def test_run_execute(apiobj, table_factory):
async def test_run_execute(api, table_factory):
table_factory('foo', definition='that TEXT', content=(('a', ),))
async with apiobj.begin() as conn:
async with api.begin() as conn:
result = await conn.execute(sa.text('SELECT * FROM foo'))
assert result.fetchone()[0] == 'a'
@pytest.mark.asyncio
async def test_get_property_existing_cached(apiobj, table_factory):
async def test_get_property_existing_cached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text('TRUNCATE nominatim_properties'))
@ -56,12 +45,12 @@ async def test_get_property_existing_cached(apiobj, table_factory):
@pytest.mark.asyncio
async def test_get_property_existing_uncached(apiobj, table_factory):
async def test_get_property_existing_uncached(api, table_factory):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT',
content=(('dbv', '96723'), ))
async with apiobj.begin() as conn:
async with api.begin() as conn:
assert await conn.get_property('dbv') == '96723'
await conn.execute(sa.text("UPDATE nominatim_properties SET value = '1'"))
@ -71,23 +60,23 @@ async def test_get_property_existing_uncached(apiobj, table_factory):
@pytest.mark.asyncio
@pytest.mark.parametrize('param', ['foo', 'DB:server_version'])
async def test_get_property_missing(apiobj, table_factory, param):
async def test_get_property_missing(api, table_factory, param):
table_factory('nominatim_properties',
definition='property TEXT, value TEXT')
async with apiobj.begin() as conn:
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_property(param)
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
assert await conn.get_db_property('server_version') > 0
@pytest.mark.asyncio
async def test_get_db_property_existing(apiobj):
async with apiobj.begin() as conn:
async def test_get_db_property_existing(api):
async with api.begin() as conn:
with pytest.raises(ValueError):
await conn.get_db_property('dfkgjd.rijg')

View File

@ -11,19 +11,10 @@ import json
from pathlib import Path
import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestDeletableEndPoint:
@ -61,4 +52,3 @@ class TestDeletableEndPoint:
{'place_id': 3, 'country_code': 'cd', 'name': None,
'osm_id': 781, 'osm_type': 'R',
'class': 'landcover', 'type': 'grass'}]

View File

@ -12,19 +12,10 @@ import datetime as dt
from pathlib import Path
import pytest
import pytest_asyncio
from fake_adaptor import FakeAdaptor, FakeError, FakeResponse
import nominatim_api.v1.server_glue as glue
import nominatim_api as napi
@pytest_asyncio.fixture
async def api():
api = napi.NominatimAPIAsync(Path('/invalid'))
yield api
await api.close()
class TestPolygonsEndPoint: