util: Drop chia.util.chunks and use chia.util.to_batches instead (#15418)

dustinface 2023-06-09 23:35:31 +02:00 committed by GitHub
parent 2441035afc
commit 5e0b8bb1ec
7 changed files with 41 additions and 70 deletions
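
Note on the refactor: chunks() yielded plain lists, while to_batches() (from chia.util.misc) yields batch objects whose entries field holds the items, so every call site changes from "for coins in chunks(...)" to "for batch in to_batches(...)" and reads batch.entries. The wallet call sites below also drop their list(...) conversions, so to_batches evidently accepts sets as well. A minimal sketch of the interface these call sites rely on; the real implementation lives in chia/util/misc.py, and everything beyond the (collection, batch_size) signature and the entries field is an assumption here:

from dataclasses import dataclass
from typing import Collection, Generic, Iterator, List, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class Batch(Generic[T]):
    entries: List[T]  # the only field the call sites in this commit touch


def to_batches(to_split: Collection[T], batch_size: int) -> Iterator[Batch[T]]:
    # Rejecting batch_size <= 0 is an assumption here; the deleted chunks()
    # clamped the size to at least 1 instead (see the removed file below).
    if batch_size <= 0:
        raise ValueError("batch_size must be greater than zero")
    entries = list(to_split)
    for start in range(0, len(entries), batch_size):
        yield Batch(entries[start : start + batch_size])

With this shape, [b.entries for b in to_batches(["a", "b", "c"], 2)] would give [["a", "b"], ["c"]].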

View File

@@ -20,8 +20,8 @@ from chia.types.coin_record import CoinRecord
 from chia.types.mempool_inclusion_status import MempoolInclusionStatus
 from chia.types.spend_bundle import SpendBundle
 from chia.types.spend_bundle_conditions import Spend, SpendBundleConditions
-from chia.util.chunks import chunks
 from chia.util.ints import uint32, uint64
+from chia.util.misc import to_batches
 NUM_ITERS = 200
 NUM_PEERS = 5
@@ -144,12 +144,12 @@ async def run_mempool_benchmark() -> None:
 bundles = []
 print(" large spend bundles")
-for coins in chunks(unspent, 200):
-    print(f"{len(coins)} coins")
+for batch in to_batches(unspent, 200):
+    print(f"{len(batch.entries)} coins")
     tx = SpendBundle.aggregate(
         [
             wt.generate_signed_transaction(uint64(c.amount // 2), wt.get_new_puzzlehash(), c, fee=peer + idx)
-            for c in coins
+            for c in batch.entries
         ]
     )
     bundles.append(tx)

View File

@@ -13,7 +13,6 @@ from chia.protocols.wallet_protocol import CoinState
 from chia.types.blockchain_format.coin import Coin
 from chia.types.blockchain_format.sized_bytes import bytes32
 from chia.types.coin_record import CoinRecord
-from chia.util.chunks import chunks
 from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER, DBWrapper2
 from chia.util.ints import uint32, uint64
 from chia.util.lru_cache import LRUCache
@@ -186,12 +185,12 @@ class CoinStore:
 async with self.db_wrapper.reader_no_transaction() as conn:
     cursors: List[Cursor] = []
-    for names_chunk in chunks(names, SQLITE_MAX_VARIABLE_NUMBER):
+    for batch in to_batches(names, SQLITE_MAX_VARIABLE_NUMBER):
         names_db: Tuple[Any, ...]
         if self.db_wrapper.db_version == 2:
-            names_db = tuple(names_chunk)
+            names_db = tuple(batch.entries)
         else:
-            names_db = tuple(n.hex() for n in names_chunk)
+            names_db = tuple(n.hex() for n in batch.entries)
         cursors.append(
             await conn.execute(
                 f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
@@ -408,15 +407,15 @@ class CoinStore:
 coins = set()
 async with self.db_wrapper.reader_no_transaction() as conn:
-    for ids in chunks(parent_ids, SQLITE_MAX_VARIABLE_NUMBER):
+    for batch in to_batches(parent_ids, SQLITE_MAX_VARIABLE_NUMBER):
         parent_ids_db: Tuple[Any, ...]
         if self.db_wrapper.db_version == 2:
-            parent_ids_db = tuple(ids)
+            parent_ids_db = tuple(batch.entries)
         else:
-            parent_ids_db = tuple([pid.hex() for pid in ids])
+            parent_ids_db = tuple([pid.hex() for pid in batch.entries])
         async with conn.execute(
-            f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
-            f'coin_parent, amount, timestamp FROM coin_record WHERE coin_parent in ({"?," * (len(ids) - 1)}?) '
+            f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, coin_parent, amount, timestamp "
+            f'FROM coin_record WHERE coin_parent in ({"?," * (len(batch.entries) - 1)}?) '
             f"AND confirmed_index>=? AND confirmed_index<? "
             f"{'' if include_spent_coins else 'AND spent_index=0'}",
             parent_ids_db + (start_height, end_height),
@@ -559,15 +558,15 @@ class CoinStore:
 async with self.db_wrapper.writer_maybe_transaction() as conn:
     rows_updated: int = 0
-    for coin_names_chunk in chunks(coin_names, SQLITE_MAX_VARIABLE_NUMBER):
-        name_params = ",".join(["?"] * len(coin_names_chunk))
+    for batch in to_batches(coin_names, SQLITE_MAX_VARIABLE_NUMBER):
+        name_params = ",".join(["?"] * len(batch.entries))
         if self.db_wrapper.db_version == 2:
             ret: Cursor = await conn.execute(
                 f"UPDATE coin_record INDEXED BY sqlite_autoindex_coin_record_1 "
                 f"SET spent_index={index} "
                 f"WHERE spent_index=0 "
                 f"AND coin_name IN ({name_params})",
-                coin_names_chunk,
+                batch.entries,
             )
         else:
             ret = await conn.execute(
@@ -575,7 +574,7 @@ class CoinStore:
         f"SET spent=1, spent_index={index} "
         f"WHERE spent_index=0 "
         f"AND coin_name IN ({name_params})",
-        [name.hex() for name in coin_names_chunk],
+        [name.hex() for name in batch.entries],
     )
     rows_updated += ret.rowcount
 if rows_updated != len(coin_names):

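The CoinStore hunks above all follow the same pattern: the ID list is split so that each statement binds at most SQLITE_MAX_VARIABLE_NUMBER parameters, because SQLite rejects statements with too many bound variables (999 on many builds, as the wallet comment later in this commit notes). A self-contained sketch of that IN-clause batching technique; the table and column names here are hypothetical, not Chia's schema:

import sqlite3
from typing import Any, List

SQLITE_MAX_VARIABLE_NUMBER = 999  # conservative bound-parameter limit


def fetch_rows_by_id(conn: sqlite3.Connection, ids: List[Any]) -> List[Any]:
    rows: List[Any] = []
    for start in range(0, len(ids), SQLITE_MAX_VARIABLE_NUMBER):
        batch = ids[start : start + SQLITE_MAX_VARIABLE_NUMBER]
        placeholders = ",".join("?" * len(batch))  # one "?" per bound value
        cursor = conn.execute(
            f"SELECT id, value FROM records WHERE id IN ({placeholders})", batch
        )
        rows.extend(cursor.fetchall())
    return rows

Each batch builds its own "?,...,?" placeholder string sized to the batch, so the statement always matches the number of bound values.
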
View File

@@ -19,10 +19,10 @@ from chia.types.eligible_coin_spends import EligibleCoinSpends
 from chia.types.internal_mempool_item import InternalMempoolItem
 from chia.types.mempool_item import MempoolItem
 from chia.types.spend_bundle import SpendBundle
-from chia.util.chunks import chunks
 from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER
 from chia.util.errors import Err
 from chia.util.ints import uint32, uint64
+from chia.util.misc import to_batches
 log = logging.getLogger(__name__)
@@ -175,11 +175,12 @@ class Mempool:
 def get_items_by_coin_ids(self, spent_coin_ids: List[bytes32]) -> List[MempoolItem]:
     items: List[MempoolItem] = []
-    for coin_ids in chunks(spent_coin_ids, SQLITE_MAX_VARIABLE_NUMBER):
-        args = ",".join(["?"] * len(coin_ids))
+    for batch in to_batches(spent_coin_ids, SQLITE_MAX_VARIABLE_NUMBER):
+        args = ",".join(["?"] * len(batch.entries))
         with self._db_conn:
             cursor = self._db_conn.execute(
-                f"SELECT * FROM tx WHERE name IN (SELECT tx FROM spends WHERE coin_id IN ({args}))", tuple(coin_ids)
+                f"SELECT * FROM tx WHERE name IN (SELECT tx FROM spends WHERE coin_id IN ({args}))",
+                tuple(batch.entries),
             )
             items.extend(self._row_to_item(row) for row in cursor)
     return items
@@ -237,11 +238,11 @@ class Mempool:
 removed_items: List[MempoolItemInfo] = []
 if reason != MempoolRemoveReason.BLOCK_INCLUSION:
-    for spend_bundle_ids in chunks(items, SQLITE_MAX_VARIABLE_NUMBER):
-        args = ",".join(["?"] * len(spend_bundle_ids))
+    for batch in to_batches(items, SQLITE_MAX_VARIABLE_NUMBER):
+        args = ",".join(["?"] * len(batch.entries))
         with self._db_conn:
             cursor = self._db_conn.execute(
-                f"SELECT name, cost, fee FROM tx WHERE name in ({args})", spend_bundle_ids
+                f"SELECT name, cost, fee FROM tx WHERE name in ({args})", batch.entries
             )
             for row in cursor:
                 name = bytes32(row[0])
@@ -252,11 +253,11 @@ class Mempool:
 for name in items:
     self._items.pop(name)
-for spend_bundle_ids in chunks(items, SQLITE_MAX_VARIABLE_NUMBER):
-    args = ",".join(["?"] * len(spend_bundle_ids))
+for batch in to_batches(items, SQLITE_MAX_VARIABLE_NUMBER):
+    args = ",".join(["?"] * len(batch.entries))
     with self._db_conn:
-        self._db_conn.execute(f"DELETE FROM tx WHERE name in ({args})", spend_bundle_ids)
-        self._db_conn.execute(f"DELETE FROM spends WHERE tx in ({args})", spend_bundle_ids)
+        self._db_conn.execute(f"DELETE FROM tx WHERE name in ({args})", batch.entries)
+        self._db_conn.execute(f"DELETE FROM spends WHERE tx in ({args})", batch.entries)
 if reason != MempoolRemoveReason.BLOCK_INCLUSION:
     info = FeeMempoolInfo(

View File

@@ -41,9 +41,9 @@ from chia.types.weight_proof import (
     WeightProof,
 )
 from chia.util.block_cache import BlockCache
-from chia.util.chunks import chunks
 from chia.util.hash import std_hash
 from chia.util.ints import uint8, uint32, uint64, uint128
+from chia.util.misc import to_batches
 from chia.util.setproctitle import getproctitle, setproctitle
 log = logging.getLogger(__name__)
@@ -1670,11 +1670,10 @@ async def validate_weight_proof_inner(
 if vdfs_to_validate is None:
     return False, []
-vdf_chunks = chunks(vdfs_to_validate, num_processes)
 vdf_tasks = []
-for chunk in vdf_chunks:
+for batch in to_batches(vdfs_to_validate, num_processes):
     byte_chunks = []
-    for vdf_proof, classgroup, vdf_info in chunk:
+    for vdf_proof, classgroup, vdf_info in batch.entries:
         byte_chunks.append((bytes(vdf_proof), bytes(classgroup), bytes(vdf_info)))
     vdf_task = asyncio.get_running_loop().run_in_executor(
         executor,

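Here the batch size is num_processes, so each executor task receives up to num_processes serialized (proof, classgroup, info) triples. A hedged sketch of that fan-out shape; _validate_vdf_batch is a stand-in name for the worker function the real code submits, not Chia's actual helper:

import asyncio
from concurrent.futures import ProcessPoolExecutor
from typing import List, Tuple


def _validate_vdf_batch(byte_chunks: List[Tuple[bytes, bytes, bytes]]) -> bool:
    # Placeholder for the real per-batch VDF validation work.
    return all(proof and group and info for proof, group, info in byte_chunks)


async def validate_all(vdfs: List[Tuple[bytes, bytes, bytes]], num_processes: int) -> bool:
    loop = asyncio.get_running_loop()
    with ProcessPoolExecutor(max_workers=num_processes) as executor:
        tasks = [
            # one task per batch; each batch holds up to num_processes triples
            loop.run_in_executor(executor, _validate_vdf_batch, vdfs[i : i + num_processes])
            for i in range(0, len(vdfs), num_processes)
        ]
        return all(await asyncio.gather(*tasks))
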
View File

@@ -1,11 +0,0 @@
-from __future__ import annotations
-
-from typing import Iterator, List, TypeVar
-
-T = TypeVar("T")
-
-
-def chunks(in_list: List[T], size: int) -> Iterator[List[T]]:
-    size = max(1, size)
-    for i in range(0, len(in_list), size):
-        yield in_list[i : i + size]

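One behavioral detail disappears with this file: chunks() clamped non-positive sizes up to 1, and every surviving call site passes a positive size, so nothing in this commit depends on the clamp. For positive sizes the two helpers line up directly; chunks_via_batches below is a hypothetical illustration, not part of the codebase:

from typing import Iterator, List, TypeVar

from chia.util.misc import to_batches

T = TypeVar("T")


def chunks_via_batches(in_list: List[T], size: int) -> Iterator[List[T]]:
    # max(1, size) reproduces the deleted clamp; no surviving call site needs it
    for batch in to_batches(in_list, max(1, size)):
        yield batch.entries
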
View File

@@ -45,7 +45,6 @@ from chia.types.header_block import HeaderBlock
 from chia.types.mempool_inclusion_status import MempoolInclusionStatus
 from chia.types.spend_bundle import SpendBundle
 from chia.types.weight_proof import WeightProof
-from chia.util.chunks import chunks
 from chia.util.config import (
     WALLET_PEERS_PATH_KEY_DEPRECATED,
     lock_and_load_config,
@@ -56,6 +55,7 @@ from chia.util.db_wrapper import manage_connection
 from chia.util.errors import KeychainIsEmpty, KeychainIsLocked, KeychainKeyNotFound, KeychainProxyConnectionFailure
 from chia.util.ints import uint16, uint32, uint64, uint128
 from chia.util.keychain import Keychain
+from chia.util.misc import to_batches
 from chia.util.path import path_from_root
 from chia.util.profiler import mem_profile_task, profile_task
 from chia.util.streamable import Streamable, streamable
@@ -780,8 +780,8 @@ class WalletNode:
 not_checked_puzzle_hashes = set(all_puzzle_hashes) - already_checked_ph
 if not_checked_puzzle_hashes == set():
     break
-for chunk in chunks(list(not_checked_puzzle_hashes), 1000):
-    ph_update_res: List[CoinState] = await subscribe_to_phs(chunk, full_node, 0)
+for batch in to_batches(not_checked_puzzle_hashes, 1000):
+    ph_update_res: List[CoinState] = await subscribe_to_phs(batch.entries, full_node, 0)
     ph_update_res = list(filter(is_new_state_update, ph_update_res))
     if not await self.add_states_from_peer(ph_update_res, full_node):
         # If something goes wrong, abort sync
@@ -798,8 +798,8 @@ class WalletNode:
 not_checked_coin_ids = set(all_coin_ids) - already_checked_coin_ids
 if not_checked_coin_ids == set():
     break
-for chunk in chunks(list(not_checked_coin_ids), 1000):
-    c_update_res: List[CoinState] = await subscribe_to_coin_updates(chunk, full_node, 0)
+for batch in to_batches(not_checked_coin_ids, 1000):
+    c_update_res: List[CoinState] = await subscribe_to_coin_updates(batch.entries, full_node, 0)
     if not await self.add_states_from_peer(c_update_res, full_node):
         # If something goes wrong, abort sync
@@ -901,7 +901,7 @@ class WalletNode:
 # Keep chunk size below 1000 just in case, windows has sqlite limits of 999 per query
 # Untrusted has a smaller batch size since validation has to happen which takes a while
 chunk_size: int = 900 if trusted else 10
-for states in chunks(items, chunk_size):
+for batch in to_batches(items, chunk_size):
     if self._server is None:
         self.log.error("No server")
         await asyncio.gather(*all_tasks)
@@ -912,8 +912,8 @@ class WalletNode:
         return False
     if trusted:
         async with self.wallet_state_manager.db_wrapper.writer():
-            self.log.info(f"new coin state received ({idx}-{idx + len(states) - 1}/ {len(items)})")
-            if not await self.wallet_state_manager.add_coin_states(states, peer, fork_height):
+            self.log.info(f"new coin state received ({idx}-{idx + len(batch.entries) - 1}/ {len(items)})")
+            if not await self.wallet_state_manager.add_coin_states(batch.entries, peer, fork_height):
                 return False
     else:
         while len(all_tasks) >= target_concurrent_tasks:
@@ -923,8 +923,8 @@ class WalletNode:
             self.log.info("Terminating receipt and validation due to shut down request")
             await asyncio.gather(*all_tasks)
             return False
-    all_tasks.append(asyncio.create_task(validate_and_add(states, idx)))
-    idx += len(states)
+    all_tasks.append(asyncio.create_task(validate_and_add(batch.entries, idx)))
+    idx += len(batch.entries)
 still_connected = self._server is not None and peer.peer_node_id in self.server.all_connections
 await asyncio.gather(*all_tasks)

View File

@@ -1,17 +0,0 @@
-from __future__ import annotations
-
-from chia.util.chunks import chunks
-
-
-def test_chunks() -> None:
-    assert list(chunks([], 0)) == []
-    assert list(chunks(["a"], 0)) == [["a"]]
-    assert list(chunks(["a", "b"], 0)) == [["a"], ["b"]]
-    assert list(chunks(["a", "b", "c", "d"], -1)) == [["a"], ["b"], ["c"], ["d"]]
-    assert list(chunks(["a", "b", "c", "d"], 0)) == [["a"], ["b"], ["c"], ["d"]]
-    assert list(chunks(["a", "b", "c", "d"], 1)) == [["a"], ["b"], ["c"], ["d"]]
-    assert list(chunks(["a", "b", "c", "d"], 2)) == [["a", "b"], ["c", "d"]]
-    assert list(chunks(["a", "b", "c", "d"], 3)) == [["a", "b", "c"], ["d"]]
-    assert list(chunks(["a", "b", "c", "d"], 4)) == [["a", "b", "c", "d"]]
-    assert list(chunks(["a", "b", "c", "d"], 200)) == [["a", "b", "c", "d"]]
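
The deleted test pinned down the size clamp and the splitting behavior of chunks(). A hedged sketch of an equivalent test for to_batches, restricted to the positive sizes the production call sites actually use (to_batches' behavior for sizes <= 0 is not shown in this commit, so it is not asserted here):

from chia.util.misc import to_batches


def test_to_batches_positive_sizes() -> None:
    data = ["a", "b", "c", "d"]
    # an empty input should yield no batches at all
    assert list(to_batches([], 1)) == []
    assert [b.entries for b in to_batches(data, 1)] == [["a"], ["b"], ["c"], ["d"]]
    assert [b.entries for b in to_batches(data, 2)] == [["a", "b"], ["c", "d"]]
    assert [b.entries for b in to_batches(data, 3)] == [["a", "b", "c"], ["d"]]
    assert [b.entries for b in to_batches(data, 4)] == [["a", "b", "c", "d"]]
    assert [b.entries for b in to_batches(data, 200)] == [["a", "b", "c", "d"]]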