Implement get_items_by_coin_ids() for the mempool (#15069)

Implement get_items_by_coin_ids() for the mempool.
This commit is contained in:
Amine Khaldi 2023-04-27 15:55:01 +01:00 committed by GitHub
parent 1748e35526
commit de6fb526b5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 0 deletions

View File

@ -180,6 +180,17 @@ class Mempool:
)
return [self._row_to_item(row) for row in cursor]
def get_items_by_coin_ids(self, spent_coin_ids: List[bytes32]) -> List[MempoolItem]:
items: List[MempoolItem] = []
for coin_ids in chunks(spent_coin_ids, SQLITE_MAX_VARIABLE_NUMBER):
args = ",".join(["?"] * len(coin_ids))
with self._db_conn:
cursor = self._db_conn.execute(
f"SELECT * FROM tx WHERE name IN (SELECT tx FROM spends WHERE coin_id IN ({args}))", tuple(coin_ids)
)
items.extend(self._row_to_item(row) for row in cursor)
return items
def get_min_fee_rate(self, cost: int) -> float:
"""
Gets the minimum fpc rate that a transaction with specified cost will need in order to get included.

View File

@ -2833,3 +2833,40 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L
print(f"- cost: {item.cost} TTL: {ttl}")
assert mempool.total_mempool_cost() > 90
@pytest.mark.parametrize(
"items,coin_ids,expected",
[
# None of these spend those coins
(
[mk_item(coins[0:1]), mk_item(coins[1:2]), mk_item(coins[2:3])],
[coins[3].name(), coins[4].name()],
[],
),
# One of these spends one of the coins
(
[mk_item(coins[0:1]), mk_item(coins[1:2]), mk_item(coins[2:3])],
[coins[1].name(), coins[3].name()],
[mk_item(coins[1:2])],
),
# One of these spends one another spends two
(
[mk_item(coins[0:1]), mk_item(coins[1:3]), mk_item(coins[2:4]), mk_item(coins[3:4])],
[coins[2].name(), coins[3].name()],
[mk_item(coins[1:3]), mk_item(coins[2:4]), mk_item(coins[3:4])],
),
],
)
def test_get_items_by_coin_ids(items: List[MempoolItem], coin_ids: List[bytes32], expected: List[MempoolItem]) -> None:
fee_estimator = create_bitcoin_fee_estimator(uint64(11000000000))
mempool_info = MempoolInfo(
CLVMCost(uint64(11000000000 * 3)),
FeeRate(uint64(1000000)),
CLVMCost(uint64(11000000000)),
)
mempool = Mempool(mempool_info, fee_estimator)
for i in items:
mempool.add_to_pool(i)
result = mempool.get_items_by_coin_ids(coin_ids)
assert set(result) == set(expected)