add a transactional reader as DBWrapper2.reader() (#13468)

* add a transactional reader as `DBWrapper2.reader()`

* contextlib.AbstractAsyncContextManager

* try a minimal test

* rework test to allow the read to finish so the write can finish the commit after that

* pylint disable

* WAL!!!

* more tests

* if TYPE_CHECKING for typing variable assignments

* future

* cover some more test cases
This commit is contained in:
Kyle Altendorf 2022-09-30 18:41:33 -04:00 committed by GitHub
parent 28213628ae
commit 61b497a701
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 195 additions and 14 deletions

View File

@ -259,6 +259,20 @@ class DBWrapper2:
finally:
self._current_writer = None
@contextlib.asynccontextmanager
async def reader(self) -> AsyncIterator[aiosqlite.Connection]:
async with self.reader_no_transaction() as connection:
if connection.in_transaction:
yield connection
else:
await connection.execute("BEGIN DEFERRED;")
try:
yield connection
finally:
# close the transaction with a rollback instead of commit just in
# case any modifications were submitted through this reader
await connection.rollback()
@contextlib.asynccontextmanager
async def reader_no_transaction(self) -> AsyncIterator[aiosqlite.Connection]:
# there should have been read connections added

View File

@ -2,14 +2,27 @@ from __future__ import annotations
import asyncio
import contextlib
from typing import List
from typing import TYPE_CHECKING, Callable, List
import aiosqlite
import pytest
# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia.util.db_wrapper import DBWrapper2
from tests.util.db_connection import DBConnection
if TYPE_CHECKING:
ConnectionContextManager = contextlib.AbstractAsyncContextManager[aiosqlite.core.Connection]
GetReaderMethod = Callable[[DBWrapper2], Callable[[], ConnectionContextManager]]
class UniqueError(Exception):
"""Used to uniquely trigger the exception path out of the context managers."""
pass
async def increment_counter(db_wrapper: DBWrapper2) -> None:
async with db_wrapper.writer_maybe_transaction() as connection:
@ -62,12 +75,37 @@ async def get_value(cursor: aiosqlite.Cursor) -> int:
return int(row[0])
async def query_value(connection: aiosqlite.Connection) -> int:
async with connection.execute("SELECT value FROM counter") as cursor:
return await get_value(cursor=cursor)
def _get_reader_no_transaction_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]:
return db_wrapper.reader_no_transaction
def _get_regular_reader_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]:
return db_wrapper.reader
@pytest.fixture(
name="get_reader_method",
params=[
pytest.param(_get_reader_no_transaction_method, id="reader_no_transaction"),
pytest.param(_get_regular_reader_method, id="reader"),
],
)
def get_reader_method_fixture(request: SubRequest) -> Callable[[], ConnectionContextManager]:
# https://github.com/pytest-dev/pytest/issues/8763
return request.param # type: ignore[no-any-return]
@pytest.mark.asyncio
@pytest.mark.parametrize(
argnames="acquire_outside",
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
)
async def test_concurrent_writers(acquire_outside: bool) -> None:
async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
@ -85,7 +123,7 @@ async def test_concurrent_writers(acquire_outside: bool) -> None:
await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
async with db_wrapper.reader_no_transaction() as connection:
async with get_reader_method(db_wrapper)() as connection:
async with connection.execute("SELECT value FROM counter") as cursor:
row = await cursor.fetchone()
@ -160,14 +198,14 @@ async def test_partial_failure() -> None:
@pytest.mark.asyncio
async def test_readers_nests() -> None:
async def test_readers_nests(get_reader_method: GetReaderMethod) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with db_wrapper.reader_no_transaction() as conn1:
async with db_wrapper.reader_no_transaction() as conn2:
async with get_reader_method(db_wrapper)() as conn1:
async with get_reader_method(db_wrapper)() as conn2:
assert conn1 == conn2
async with db_wrapper.reader_no_transaction() as conn3:
async with get_reader_method(db_wrapper)() as conn3:
assert conn1 == conn3
async with conn3.execute("SELECT value FROM counter") as cursor:
value = await get_value(cursor)
@ -176,12 +214,12 @@ async def test_readers_nests() -> None:
@pytest.mark.asyncio
async def test_readers_nests_writer() -> None:
async def test_readers_nests_writer(get_reader_method: GetReaderMethod) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with db_wrapper.writer_maybe_transaction() as conn1:
async with db_wrapper.reader_no_transaction() as conn2:
async with get_reader_method(db_wrapper)() as conn2:
assert conn1 == conn2
async with db_wrapper.writer_maybe_transaction() as conn3:
assert conn1 == conn3
@ -191,12 +229,103 @@ async def test_readers_nests_writer() -> None:
assert value == 0
@pytest.mark.parametrize(
argnames="transactioned",
argvalues=[
pytest.param(True, id="transaction"),
pytest.param(False, id="no transaction"),
],
)
@pytest.mark.asyncio
async def test_only_transactioned_reader_ignores_writer(transactioned: bool) -> None:
writer_committed = asyncio.Event()
reader_read = asyncio.Event()
async def write() -> None:
try:
async with db_wrapper.writer() as writer:
assert reader is not writer
await writer.execute("UPDATE counter SET value = 1")
finally:
writer_committed.set()
await reader_read.wait()
assert await query_value(connection=writer) == 1
async with DBConnection(2) as db_wrapper:
get_reader = db_wrapper.reader if transactioned else db_wrapper.reader_no_transaction
await setup_table(db_wrapper)
async with get_reader() as reader:
assert await query_value(connection=reader) == 0
task = asyncio.create_task(write())
await writer_committed.wait()
assert await query_value(connection=reader) == 0 if transactioned else 1
reader_read.set()
await task
async with get_reader() as reader:
assert await query_value(connection=reader) == 1
@pytest.mark.asyncio
async def test_reader_nests_and_ends_transaction() -> None:
async with DBConnection(2) as db_wrapper:
async with db_wrapper.reader() as reader:
assert reader.in_transaction
async with db_wrapper.reader() as inner_reader:
assert inner_reader is reader
assert reader.in_transaction
assert reader.in_transaction
assert not reader.in_transaction
@pytest.mark.asyncio
async def test_writer_in_reader_works() -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with db_wrapper.reader() as reader:
async with db_wrapper.writer() as writer:
assert writer is not reader
await writer.execute("UPDATE counter SET value = 1")
assert await query_value(connection=writer) == 1
assert await query_value(connection=reader) == 0
assert await query_value(connection=reader) == 0
@pytest.mark.asyncio
async def test_reader_transaction_is_deferred() -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with db_wrapper.reader() as reader:
async with db_wrapper.writer() as writer:
assert writer is not reader
await writer.execute("UPDATE counter SET value = 1")
assert await query_value(connection=writer) == 1
# The deferred transaction initiation results in the transaction starting
# here and thus reading the written value.
assert await query_value(connection=reader) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize(
argnames="acquire_outside",
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
)
async def test_concurrent_readers(acquire_outside: bool) -> None:
async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
@ -208,7 +337,7 @@ async def test_concurrent_readers(acquire_outside: bool) -> None:
async with contextlib.AsyncExitStack() as exit_stack:
if acquire_outside:
await exit_stack.enter_async_context(db_wrapper.reader_no_transaction())
await exit_stack.enter_async_context(get_reader_method(db_wrapper)())
tasks = []
values: List[int] = []
@ -226,7 +355,7 @@ async def test_concurrent_readers(acquire_outside: bool) -> None:
argnames="acquire_outside",
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
)
async def test_mixed_readers_writers(acquire_outside: bool) -> None:
async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
@ -238,7 +367,7 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
async with contextlib.AsyncExitStack() as exit_stack:
if acquire_outside:
await exit_stack.enter_async_context(db_wrapper.reader_no_transaction())
await exit_stack.enter_async_context(get_reader_method(db_wrapper)())
tasks = []
values: List[int] = []
@ -254,7 +383,7 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
# we increment and decrement the counter an equal number of times. It should
# end back at 1.
async with db_wrapper.reader_no_transaction() as connection:
async with get_reader_method(db_wrapper)() as connection:
async with connection.execute("SELECT value FROM counter") as cursor:
row = await cursor.fetchone()
assert row is not None
@ -265,3 +394,41 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
for v in values:
assert v > -99
assert v <= 100
@pytest.mark.parametrize(
argnames=["manager_method", "expected"],
argvalues=[
[DBWrapper2.writer, True],
[DBWrapper2.writer_maybe_transaction, True],
[DBWrapper2.reader, True],
[DBWrapper2.reader_no_transaction, False],
],
)
@pytest.mark.asyncio
async def test_in_transaction_as_expected(
manager_method: Callable[[DBWrapper2], ConnectionContextManager],
expected: bool,
) -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with manager_method(db_wrapper) as connection:
assert connection.in_transaction == expected
@pytest.mark.asyncio
async def test_cancelled_reader_does_not_cancel_writer() -> None:
async with DBConnection(2) as db_wrapper:
await setup_table(db_wrapper)
async with db_wrapper.writer() as writer:
await writer.execute("UPDATE counter SET value = 1")
with pytest.raises(UniqueError):
async with db_wrapper.reader() as _:
raise UniqueError()
assert await query_value(connection=writer) == 1
assert await query_value(connection=writer) == 1