assert_before_height, assert_before_seconds fields in MempoolItem (#14931)

fix issue where assert_before_height and assert_before_seconds fields in MempoolItem would not be populated
This commit is contained in:
Arvid Norberg 2023-03-30 19:06:22 +02:00 committed by GitHub
parent 1577f4aa33
commit 20bfba0fa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 39 additions and 18 deletions

View File

@ -96,12 +96,21 @@ class Mempool:
def _row_to_item(self, row: sqlite3.Row) -> MempoolItem:
name = bytes32(row[0])
fee = int(row[1])
assert_height = row[2]
fee = int(row[2])
assert_height = row[3]
assert_before_height = row[4]
assert_before_seconds = row[5]
item = self._items[name]
return MempoolItem(
item.spend_bundle, uint64(fee), item.npc_result, name, uint32(item.height_added_to_mempool), assert_height
item.spend_bundle,
uint64(fee),
item.npc_result,
name,
uint32(item.height_added_to_mempool),
assert_height,
assert_before_height,
assert_before_seconds,
)
def total_mempool_fees(self) -> int:
@ -118,7 +127,7 @@ class Mempool:
def all_spends(self) -> Iterator[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT name, fee, assert_height FROM tx")
cursor = self._db_conn.execute("SELECT * FROM tx")
for row in cursor:
yield self._row_to_item(row)
@ -131,7 +140,7 @@ class Mempool:
# bit more efficiently
def spends_by_feerate(self) -> Iterator[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT name, fee, assert_height FROM tx ORDER BY fee_per_cost DESC")
cursor = self._db_conn.execute("SELECT * FROM tx ORDER BY fee_per_cost DESC")
for row in cursor:
yield self._row_to_item(row)
@ -143,7 +152,7 @@ class Mempool:
def get_spend_by_id(self, spend_bundle_id: bytes32) -> Optional[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT name, fee, assert_height FROM tx WHERE name=?", (spend_bundle_id,))
cursor = self._db_conn.execute("SELECT * FROM tx WHERE name=?", (spend_bundle_id,))
row = cursor.fetchone()
return None if row is None else self._row_to_item(row)
@ -151,7 +160,7 @@ class Mempool:
def get_spends_by_coin_id(self, spent_coin_id: bytes32) -> List[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute(
"SELECT name, fee, assert_height FROM tx WHERE name in (SELECT tx FROM spends WHERE coin_id=?)",
"SELECT * FROM tx WHERE name in (SELECT tx FROM spends WHERE coin_id=?)",
(spent_coin_id,),
)
return [self._row_to_item(row) for row in cursor]

View File

@ -928,24 +928,26 @@ async def test_create_bundle_from_mempool_on_max_cost() -> None:
@pytest.mark.parametrize(
"opcode,arg,expect_eviction",
"opcode,arg,expect_eviction, expect_limit",
[
# current height: 10 current_time: 10000
# we step the chain forward 1 block and 19 seconds
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10001, True),
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10019, True),
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10020, False),
(co.ASSERT_BEFORE_HEIGHT_ABSOLUTE, 11, True),
(co.ASSERT_BEFORE_HEIGHT_ABSOLUTE, 12, False),
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10001, True, None),
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10019, True, None),
(co.ASSERT_BEFORE_SECONDS_ABSOLUTE, 10020, False, 10020),
(co.ASSERT_BEFORE_HEIGHT_ABSOLUTE, 11, True, None),
(co.ASSERT_BEFORE_HEIGHT_ABSOLUTE, 12, False, 12),
# the coin was created at height: 5 timestamp: 9900
(co.ASSERT_BEFORE_HEIGHT_RELATIVE, 6, True),
(co.ASSERT_BEFORE_HEIGHT_RELATIVE, 7, False),
(co.ASSERT_BEFORE_SECONDS_RELATIVE, 119, True),
(co.ASSERT_BEFORE_SECONDS_RELATIVE, 120, False),
(co.ASSERT_BEFORE_HEIGHT_RELATIVE, 6, True, None),
(co.ASSERT_BEFORE_HEIGHT_RELATIVE, 7, False, 5 + 7),
(co.ASSERT_BEFORE_SECONDS_RELATIVE, 119, True, None),
(co.ASSERT_BEFORE_SECONDS_RELATIVE, 120, False, 9900 + 120),
],
)
@pytest.mark.asyncio
async def test_assert_before_expiration(opcode: ConditionOpcode, arg: int, expect_eviction: bool) -> None:
async def test_assert_before_expiration(
opcode: ConditionOpcode, arg: int, expect_eviction: bool, expect_limit: Optional[int]
) -> None:
async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return {TEST_COIN.name(): CoinRecord(TEST_COIN, uint32(5), uint32(0), False, uint64(9900))}.get(coin_id)
@ -973,6 +975,16 @@ async def test_assert_before_expiration(opcode: ConditionOpcode, arg: int, expec
still_in_pool = mempool_manager.get_spendbundle(bundle_name) == bundle
assert still_in_pool != expect_eviction
if still_in_pool:
assert expect_limit is not None
item = mempool_manager.get_mempool_item(bundle_name)
assert item is not None
if opcode in [co.ASSERT_BEFORE_SECONDS_ABSOLUTE, co.ASSERT_BEFORE_SECONDS_RELATIVE]:
assert item.assert_before_seconds == expect_limit
elif opcode in [co.ASSERT_BEFORE_HEIGHT_ABSOLUTE, co.ASSERT_BEFORE_HEIGHT_RELATIVE]:
assert item.assert_before_height == expect_limit
else:
assert False
def make_test_spendbundle(coin: Coin, *, fee: int = 0) -> SpendBundle: