wallet: Implement WalletCoinStore.delete_wallet (#15127)

This commit is contained in:
dustinface 2023-05-12 07:04:01 +07:00 committed by GitHub
parent e4c339e072
commit 0c4222f83e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 2 deletions

View File

@ -376,3 +376,9 @@ class WalletCoinStore:
)
).close()
self.total_count_cache.cache.clear()
async def delete_wallet(self, wallet_id: uint32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
cursor = await conn.execute("DELETE FROM coin_record WHERE wallet_id=?", (wallet_id,))
await cursor.close()
self.total_count_cache.cache.clear()

View File

@ -1,8 +1,8 @@
from __future__ import annotations
from dataclasses import replace
from dataclasses import dataclass, field, replace
from secrets import token_bytes
from typing import List, Optional, Tuple
from typing import Dict, List, Optional, Tuple
import pytest
@ -102,6 +102,28 @@ record_9 = WalletCoinRecord(
)
def get_dummy_record(wallet_id: int) -> WalletCoinRecord:
return WalletCoinRecord(
Coin(token_bytes(32), token_bytes(32), uint64(12312)),
uint32(0),
uint32(0),
False,
False,
WalletType.STANDARD_WALLET,
wallet_id,
)
@dataclass
class DummyWalletCoinRecords:
records_per_wallet: Dict[int, List[WalletCoinRecord]] = field(default_factory=dict)
def generate(self, wallet_id: int, count: int) -> None:
records = self.records_per_wallet.setdefault(wallet_id, [])
for _ in range(count):
records.append(get_dummy_record(wallet_id))
@pytest.mark.parametrize(
"invalid_record, error",
[
@ -808,6 +830,7 @@ async def test_get_coin_records_total_count_cache_reset() -> None:
store.set_spent(coin_4.name(), 10),
store.delete_coin_record(record_4.name()),
store.rollback_to_block(1000),
store.delete_wallet(uint32(record_1.wallet_id)),
]:
await test_cache()
await trigger
@ -955,3 +978,24 @@ async def test_get_coin_records_between() -> None:
records = await store.get_coin_records_between(1, 0, 4, coin_type=CoinType.CLAWBACK)
assert len(records) == 1
assert records[0] == record_8
@pytest.mark.asyncio
async def test_delete_wallet() -> None:
dummy_records = DummyWalletCoinRecords()
for i in range(5):
dummy_records.generate(i, i * 5)
async with DBConnection(1) as wrapper:
store = await WalletCoinStore.create(wrapper)
# Add the records per wallet and verify them
for wallet_id, records in dummy_records.records_per_wallet.items():
for coin_record in records:
await store.add_coin_record(coin_record)
assert set((await store.get_coin_records(wallet_id=wallet_id)).records) == set(records)
# Remove one wallet after the other and verify before and after each
for wallet_id, records in dummy_records.records_per_wallet.items():
# Assert the existence again here to make sure the previous removals did not affect other wallet_ids
assert set((await store.get_coin_records(wallet_id=wallet_id)).records) == set(records)
# Remove the wallet_id and make sure its removed fully
await store.delete_wallet(wallet_id)
assert (await store.get_coin_records(wallet_id=wallet_id)).records == []