improve CoinStore test (#8495)

* improve CoinStore test by using pytest parameters (instead of for loops) and use a context manager for the DB connection

* use temporary filename for sqlite db
This commit is contained in:
Arvid Norberg 2021-09-17 23:09:02 +02:00 committed by GitHub
parent 1abfd4d3f4
commit f5dca63048
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -5,6 +5,7 @@ from typing import List, Optional, Set, Tuple
import aiosqlite
import pytest
import tempfile
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.consensus.blockchain import Blockchain, ReceiveBlockResult
@ -21,6 +22,7 @@ from chia.util.ints import uint64, uint32
from tests.wallet_tools import WalletTool
from chia.util.db_wrapper import DBWrapper
from tests.setup_nodes import bt, test_constants
from chia.types.blockchain_format.sized_bytes import bytes32
@pytest.fixture(scope="module")
@ -54,39 +56,48 @@ def get_future_reward_coins(block: FullBlock) -> Tuple[Coin, Coin]:
return pool_coin, farmer_coin
class DBConnection:
async def __aenter__(self) -> DBWrapper:
self.db_path = Path(tempfile.NamedTemporaryFile().name)
if self.db_path.exists():
self.db_path.unlink()
self.connection = await aiosqlite.connect(self.db_path)
return DBWrapper(self.connection)
async def __aexit__(self, exc_t, exc_v, exc_tb):
await self.connection.close()
self.db_path.unlink()
class TestCoinStoreWithBlocks:
@pytest.mark.asyncio
@pytest.mark.parametrize("rust_checker", [True, False])
async def test_basic_coin_store(self, rust_checker: bool):
@pytest.mark.parametrize("cache_size", [0])
async def test_basic_coin_store(self, rust_checker: bool, cache_size: uint32):
wallet_a = WALLET_A
reward_ph = wallet_a.get_new_puzzlehash()
for cache_size in [0]:
# Generate some coins
blocks = bt.get_consecutive_blocks(
10,
[],
farmer_reward_puzzle_hash=reward_ph,
pool_reward_puzzle_hash=reward_ph,
)
# Generate some coins
blocks = bt.get_consecutive_blocks(
10,
[],
farmer_reward_puzzle_hash=reward_ph,
pool_reward_puzzle_hash=reward_ph,
)
coins_to_spend: List[Coin] = []
for block in blocks:
if block.is_transaction_block():
for coin in block.get_included_reward_coins():
if coin.puzzle_hash == reward_ph:
coins_to_spend.append(coin)
coins_to_spend: List[Coin] = []
for block in blocks:
if block.is_transaction_block():
for coin in block.get_included_reward_coins():
if coin.puzzle_hash == reward_ph:
coins_to_spend.append(coin)
spend_bundle = wallet_a.generate_signed_transaction(
uint64(1000), wallet_a.get_new_puzzlehash(), coins_to_spend[0]
)
spend_bundle = wallet_a.generate_signed_transaction(
uint64(1000), wallet_a.get_new_puzzlehash(), coins_to_spend[0]
)
db_path = Path("fndb_test.db")
if db_path.exists():
db_path.unlink()
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
coin_store = await CoinStore.create(db_wrapper, cache_size=uint32(cache_size))
async with DBConnection() as db_wrapper:
coin_store = await CoinStore.create(db_wrapper, cache_size=cache_size)
blocks = bt.get_consecutive_blocks(
10,
@ -158,25 +169,19 @@ class TestCoinStoreWithBlocks:
should_be_included_prev = should_be_included.copy()
should_be_included = set()
await connection.close()
Path("fndb_test.db").unlink()
@pytest.mark.asyncio
async def test_set_spent(self):
@pytest.mark.parametrize("cache_size", [0, 10, 100000])
async def test_set_spent(self, cache_size: uint32):
blocks = bt.get_consecutive_blocks(9, [])
for cache_size in [0, 10, 100000]:
db_path = Path("fndb_test.db")
if db_path.exists():
db_path.unlink()
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
coin_store = await CoinStore.create(db_wrapper, cache_size=uint32(cache_size))
async with DBConnection() as db_wrapper:
coin_store = await CoinStore.create(db_wrapper, cache_size=cache_size)
# Save/get block
for block in blocks:
if block.is_transaction_block():
removals, additions = [], []
removals: List[bytes32] = []
additions: List[Coin] = []
if block.is_transaction_block():
assert block.foliage_transaction_block is not None
@ -201,24 +206,20 @@ class TestCoinStoreWithBlocks:
assert record.spent
assert record.spent_block_index == block.height
await connection.close()
Path("fndb_test.db").unlink()
@pytest.mark.asyncio
async def test_rollback(self):
@pytest.mark.parametrize("cache_size", [0, 10, 100000])
async def test_rollback(self, cache_size: uint32):
blocks = bt.get_consecutive_blocks(20)
for cache_size in [0, 10, 100000]:
db_path = Path("fndb_test.db")
if db_path.exists():
db_path.unlink()
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
async with DBConnection() as db_wrapper:
coin_store = await CoinStore.create(db_wrapper, cache_size=uint32(cache_size))
records: List[Optional[CoinRecord]] = []
for block in blocks:
if block.is_transaction_block():
removals, additions = [], []
removals: List[bytes32] = []
additions: List[Coin] = []
if block.is_transaction_block():
assert block.foliage_transaction_block is not None
@ -231,17 +232,15 @@ class TestCoinStoreWithBlocks:
)
coins = block.get_included_reward_coins()
records: List[Optional[CoinRecord]] = [
await coin_store.get_coin_record(coin.name()) for coin in coins
]
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
for record in records:
assert record is not None
await coin_store._set_spent(record.coin.name(), block.height)
records: List[Optional[CoinRecord]] = [
await coin_store.get_coin_record(coin.name()) for coin in coins
]
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
for record in records:
assert record is not None
assert record.spent
assert record.spent_block_index == block.height
@ -251,9 +250,7 @@ class TestCoinStoreWithBlocks:
for block in blocks:
if block.is_transaction_block():
coins = block.get_included_reward_coins()
records: List[Optional[CoinRecord]] = [
await coin_store.get_coin_record(coin.name()) for coin in coins
]
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
if block.height <= reorg_index:
for record in records:
@ -263,36 +260,33 @@ class TestCoinStoreWithBlocks:
for record in records:
assert record is None
await connection.close()
Path("fndb_test.db").unlink()
@pytest.mark.asyncio
async def test_basic_reorg(self):
for cache_size in [0, 10, 100000]:
@pytest.mark.parametrize("cache_size", [0, 10, 100000])
async def test_basic_reorg(self, cache_size: uint32):
async with DBConnection() as db_wrapper:
initial_block_count = 30
reorg_length = 15
blocks = bt.get_consecutive_blocks(initial_block_count)
db_path = Path("blockchain_test.db")
if db_path.exists():
db_path.unlink()
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
coin_store = await CoinStore.create(db_wrapper, cache_size=uint32(cache_size))
store = await BlockStore.create(db_wrapper)
b: Blockchain = await Blockchain.create(coin_store, store, test_constants)
try:
records: List[Optional[CoinRecord]] = []
for block in blocks:
await b.receive_block(block)
assert b.get_peak().height == initial_block_count - 1
peak = b.get_peak()
assert peak is not None
assert peak.height == initial_block_count - 1
for c, block in enumerate(blocks):
if block.is_transaction_block():
coins = block.get_included_reward_coins()
records: List[Optional[CoinRecord]] = [
await coin_store.get_coin_record(coin.name()) for coin in coins
]
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
for record in records:
assert record is not None
assert not record.spent
assert record.confirmed_block_index == block.height
assert record.spent_block_index == 0
@ -312,28 +306,23 @@ class TestCoinStoreWithBlocks:
assert result == ReceiveBlockResult.NEW_PEAK
if reorg_block.is_transaction_block():
coins = reorg_block.get_included_reward_coins()
records: List[Optional[CoinRecord]] = [
await coin_store.get_coin_record(coin.name()) for coin in coins
]
records = [await coin_store.get_coin_record(coin.name()) for coin in coins]
for record in records:
assert record is not None
assert not record.spent
assert record.confirmed_block_index == reorg_block.height
assert record.spent_block_index == 0
assert error_code is None
assert b.get_peak().height == initial_block_count - 10 + reorg_length - 1
except Exception as e:
await connection.close()
Path("blockchain_test.db").unlink()
peak = b.get_peak()
assert peak is not None
assert peak.height == initial_block_count - 10 + reorg_length - 1
finally:
b.shut_down()
raise e
await connection.close()
Path("blockchain_test.db").unlink()
b.shut_down()
@pytest.mark.asyncio
async def test_get_puzzle_hash(self):
for cache_size in [0, 10, 100000]:
@pytest.mark.parametrize("cache_size", [0, 10, 100000])
async def test_get_puzzle_hash(self, cache_size: uint32):
async with DBConnection() as db_wrapper:
num_blocks = 20
farmer_ph = 32 * b"0"
pool_ph = 32 * b"1"
@ -343,11 +332,6 @@ class TestCoinStoreWithBlocks:
pool_reward_puzzle_hash=pool_ph,
guarantee_transaction_block=True,
)
db_path = Path("blockchain_test.db")
if db_path.exists():
db_path.unlink()
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
coin_store = await CoinStore.create(db_wrapper, cache_size=uint32(cache_size))
store = await BlockStore.create(db_wrapper)
b: Blockchain = await Blockchain.create(coin_store, store, test_constants)
@ -355,7 +339,9 @@ class TestCoinStoreWithBlocks:
res, err, _ = await b.receive_block(block)
assert err is None
assert res == ReceiveBlockResult.NEW_PEAK
assert b.get_peak().height == num_blocks - 1
peak = b.get_peak()
assert peak is not None
assert peak.height == num_blocks - 1
coins_farmer = await coin_store.get_coin_records_by_puzzle_hash(True, pool_ph)
coins_pool = await coin_store.get_coin_records_by_puzzle_hash(True, farmer_ph)
@ -363,6 +349,4 @@ class TestCoinStoreWithBlocks:
assert len(coins_farmer) == num_blocks - 2
assert len(coins_pool) == num_blocks - 2
await connection.close()
Path("blockchain_test.db").unlink()
b.shut_down()