mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-13 03:12:24 +03:00
add a transactional reader as DBWrapper2.reader()
(#13468)
* add a transactional reader as `DBWrapper2.reader()` * contextlib.AbstractAsyncContextManager * try a minimal test * rework test to allow the read to finish so the write can finish the commit after that * pylint disable * WAL!!! * more tests * if TYPE_CHECKING for typing variable assignments * future * cover some more test cases
This commit is contained in:
parent
28213628ae
commit
61b497a701
@ -259,6 +259,20 @@ class DBWrapper2:
|
||||
finally:
|
||||
self._current_writer = None
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def reader(self) -> AsyncIterator[aiosqlite.Connection]:
|
||||
async with self.reader_no_transaction() as connection:
|
||||
if connection.in_transaction:
|
||||
yield connection
|
||||
else:
|
||||
await connection.execute("BEGIN DEFERRED;")
|
||||
try:
|
||||
yield connection
|
||||
finally:
|
||||
# close the transaction with a rollback instead of commit just in
|
||||
# case any modifications were submitted through this reader
|
||||
await connection.rollback()
|
||||
|
||||
@contextlib.asynccontextmanager
|
||||
async def reader_no_transaction(self) -> AsyncIterator[aiosqlite.Connection]:
|
||||
# there should have been read connections added
|
||||
|
@ -2,14 +2,27 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import List
|
||||
from typing import TYPE_CHECKING, Callable, List
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
|
||||
from _pytest.fixtures import SubRequest
|
||||
|
||||
from chia.util.db_wrapper import DBWrapper2
|
||||
from tests.util.db_connection import DBConnection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
ConnectionContextManager = contextlib.AbstractAsyncContextManager[aiosqlite.core.Connection]
|
||||
GetReaderMethod = Callable[[DBWrapper2], Callable[[], ConnectionContextManager]]
|
||||
|
||||
|
||||
class UniqueError(Exception):
|
||||
"""Used to uniquely trigger the exception path out of the context managers."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def increment_counter(db_wrapper: DBWrapper2) -> None:
|
||||
async with db_wrapper.writer_maybe_transaction() as connection:
|
||||
@ -62,12 +75,37 @@ async def get_value(cursor: aiosqlite.Cursor) -> int:
|
||||
return int(row[0])
|
||||
|
||||
|
||||
async def query_value(connection: aiosqlite.Connection) -> int:
|
||||
async with connection.execute("SELECT value FROM counter") as cursor:
|
||||
return await get_value(cursor=cursor)
|
||||
|
||||
|
||||
def _get_reader_no_transaction_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]:
|
||||
return db_wrapper.reader_no_transaction
|
||||
|
||||
|
||||
def _get_regular_reader_method(db_wrapper: DBWrapper2) -> Callable[[], ConnectionContextManager]:
|
||||
return db_wrapper.reader
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
name="get_reader_method",
|
||||
params=[
|
||||
pytest.param(_get_reader_no_transaction_method, id="reader_no_transaction"),
|
||||
pytest.param(_get_regular_reader_method, id="reader"),
|
||||
],
|
||||
)
|
||||
def get_reader_method_fixture(request: SubRequest) -> Callable[[], ConnectionContextManager]:
|
||||
# https://github.com/pytest-dev/pytest/issues/8763
|
||||
return request.param # type: ignore[no-any-return]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
argnames="acquire_outside",
|
||||
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
|
||||
)
|
||||
async def test_concurrent_writers(acquire_outside: bool) -> None:
|
||||
async def test_concurrent_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
|
||||
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
@ -85,7 +123,7 @@ async def test_concurrent_writers(acquire_outside: bool) -> None:
|
||||
|
||||
await asyncio.wait_for(asyncio.gather(*tasks), timeout=None)
|
||||
|
||||
async with db_wrapper.reader_no_transaction() as connection:
|
||||
async with get_reader_method(db_wrapper)() as connection:
|
||||
async with connection.execute("SELECT value FROM counter") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
|
||||
@ -160,14 +198,14 @@ async def test_partial_failure() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readers_nests() -> None:
|
||||
async def test_readers_nests(get_reader_method: GetReaderMethod) -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with db_wrapper.reader_no_transaction() as conn1:
|
||||
async with db_wrapper.reader_no_transaction() as conn2:
|
||||
async with get_reader_method(db_wrapper)() as conn1:
|
||||
async with get_reader_method(db_wrapper)() as conn2:
|
||||
assert conn1 == conn2
|
||||
async with db_wrapper.reader_no_transaction() as conn3:
|
||||
async with get_reader_method(db_wrapper)() as conn3:
|
||||
assert conn1 == conn3
|
||||
async with conn3.execute("SELECT value FROM counter") as cursor:
|
||||
value = await get_value(cursor)
|
||||
@ -176,12 +214,12 @@ async def test_readers_nests() -> None:
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readers_nests_writer() -> None:
|
||||
async def test_readers_nests_writer(get_reader_method: GetReaderMethod) -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with db_wrapper.writer_maybe_transaction() as conn1:
|
||||
async with db_wrapper.reader_no_transaction() as conn2:
|
||||
async with get_reader_method(db_wrapper)() as conn2:
|
||||
assert conn1 == conn2
|
||||
async with db_wrapper.writer_maybe_transaction() as conn3:
|
||||
assert conn1 == conn3
|
||||
@ -191,12 +229,103 @@ async def test_readers_nests_writer() -> None:
|
||||
assert value == 0
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
argnames="transactioned",
|
||||
argvalues=[
|
||||
pytest.param(True, id="transaction"),
|
||||
pytest.param(False, id="no transaction"),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_transactioned_reader_ignores_writer(transactioned: bool) -> None:
|
||||
writer_committed = asyncio.Event()
|
||||
reader_read = asyncio.Event()
|
||||
|
||||
async def write() -> None:
|
||||
try:
|
||||
async with db_wrapper.writer() as writer:
|
||||
assert reader is not writer
|
||||
|
||||
await writer.execute("UPDATE counter SET value = 1")
|
||||
finally:
|
||||
writer_committed.set()
|
||||
|
||||
await reader_read.wait()
|
||||
|
||||
assert await query_value(connection=writer) == 1
|
||||
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
get_reader = db_wrapper.reader if transactioned else db_wrapper.reader_no_transaction
|
||||
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with get_reader() as reader:
|
||||
assert await query_value(connection=reader) == 0
|
||||
|
||||
task = asyncio.create_task(write())
|
||||
await writer_committed.wait()
|
||||
|
||||
assert await query_value(connection=reader) == 0 if transactioned else 1
|
||||
reader_read.set()
|
||||
|
||||
await task
|
||||
|
||||
async with get_reader() as reader:
|
||||
assert await query_value(connection=reader) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reader_nests_and_ends_transaction() -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
async with db_wrapper.reader() as reader:
|
||||
assert reader.in_transaction
|
||||
|
||||
async with db_wrapper.reader() as inner_reader:
|
||||
assert inner_reader is reader
|
||||
assert reader.in_transaction
|
||||
|
||||
assert reader.in_transaction
|
||||
|
||||
assert not reader.in_transaction
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_writer_in_reader_works() -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with db_wrapper.reader() as reader:
|
||||
async with db_wrapper.writer() as writer:
|
||||
assert writer is not reader
|
||||
await writer.execute("UPDATE counter SET value = 1")
|
||||
assert await query_value(connection=writer) == 1
|
||||
assert await query_value(connection=reader) == 0
|
||||
|
||||
assert await query_value(connection=reader) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reader_transaction_is_deferred() -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with db_wrapper.reader() as reader:
|
||||
async with db_wrapper.writer() as writer:
|
||||
assert writer is not reader
|
||||
await writer.execute("UPDATE counter SET value = 1")
|
||||
assert await query_value(connection=writer) == 1
|
||||
|
||||
# The deferred transaction initiation results in the transaction starting
|
||||
# here and thus reading the written value.
|
||||
assert await query_value(connection=reader) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
argnames="acquire_outside",
|
||||
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
|
||||
)
|
||||
async def test_concurrent_readers(acquire_outside: bool) -> None:
|
||||
async def test_concurrent_readers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
|
||||
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
@ -208,7 +337,7 @@ async def test_concurrent_readers(acquire_outside: bool) -> None:
|
||||
|
||||
async with contextlib.AsyncExitStack() as exit_stack:
|
||||
if acquire_outside:
|
||||
await exit_stack.enter_async_context(db_wrapper.reader_no_transaction())
|
||||
await exit_stack.enter_async_context(get_reader_method(db_wrapper)())
|
||||
|
||||
tasks = []
|
||||
values: List[int] = []
|
||||
@ -226,7 +355,7 @@ async def test_concurrent_readers(acquire_outside: bool) -> None:
|
||||
argnames="acquire_outside",
|
||||
argvalues=[pytest.param(False, id="not acquired outside"), pytest.param(True, id="acquired outside")],
|
||||
)
|
||||
async def test_mixed_readers_writers(acquire_outside: bool) -> None:
|
||||
async def test_mixed_readers_writers(acquire_outside: bool, get_reader_method: GetReaderMethod) -> None:
|
||||
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
@ -238,7 +367,7 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
|
||||
|
||||
async with contextlib.AsyncExitStack() as exit_stack:
|
||||
if acquire_outside:
|
||||
await exit_stack.enter_async_context(db_wrapper.reader_no_transaction())
|
||||
await exit_stack.enter_async_context(get_reader_method(db_wrapper)())
|
||||
|
||||
tasks = []
|
||||
values: List[int] = []
|
||||
@ -254,7 +383,7 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
|
||||
|
||||
# we increment and decrement the counter an equal number of times. It should
|
||||
# end back at 1.
|
||||
async with db_wrapper.reader_no_transaction() as connection:
|
||||
async with get_reader_method(db_wrapper)() as connection:
|
||||
async with connection.execute("SELECT value FROM counter") as cursor:
|
||||
row = await cursor.fetchone()
|
||||
assert row is not None
|
||||
@ -265,3 +394,41 @@ async def test_mixed_readers_writers(acquire_outside: bool) -> None:
|
||||
for v in values:
|
||||
assert v > -99
|
||||
assert v <= 100
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
argnames=["manager_method", "expected"],
|
||||
argvalues=[
|
||||
[DBWrapper2.writer, True],
|
||||
[DBWrapper2.writer_maybe_transaction, True],
|
||||
[DBWrapper2.reader, True],
|
||||
[DBWrapper2.reader_no_transaction, False],
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_in_transaction_as_expected(
|
||||
manager_method: Callable[[DBWrapper2], ConnectionContextManager],
|
||||
expected: bool,
|
||||
) -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with manager_method(db_wrapper) as connection:
|
||||
assert connection.in_transaction == expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancelled_reader_does_not_cancel_writer() -> None:
|
||||
async with DBConnection(2) as db_wrapper:
|
||||
await setup_table(db_wrapper)
|
||||
|
||||
async with db_wrapper.writer() as writer:
|
||||
await writer.execute("UPDATE counter SET value = 1")
|
||||
|
||||
with pytest.raises(UniqueError):
|
||||
async with db_wrapper.reader() as _:
|
||||
raise UniqueError()
|
||||
|
||||
assert await query_value(connection=writer) == 1
|
||||
|
||||
assert await query_value(connection=writer) == 1
|
||||
|
Loading…
Reference in New Issue
Block a user