plotting|util|tests: Introduce chia.util.generator_tools.list_to_batches (#9304)

* plotting|util: Introduce `chia.util.generator_tools.list_to_batches`

* tests: Test `list_to_batches` in  `test_list_to_batches.py`

* util|tests: Return an empty iterator for empty input lists

* tests: Test list sizes from 1-10 in `test_valid`
This commit is contained in:
dustinface 2021-11-24 20:27:31 +01:00 committed by GitHub
parent 280cc7f694
commit a95dfba70b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 11 deletions

View File

@ -4,7 +4,7 @@ import threading
import time import time
import traceback import traceback
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Iterator from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from blspy import G1Element from blspy import G1Element
@ -21,6 +21,7 @@ from chia.plotting.util import (
stream_plot_info_pk, stream_plot_info_pk,
stream_plot_info_ph, stream_plot_info_ph,
) )
from chia.util.generator_tools import list_to_batches
from chia.util.ints import uint16 from chia.util.ints import uint16
from chia.util.path import mkdir from chia.util.path import mkdir
from chia.util.streamable import Streamable, streamable from chia.util.streamable import Streamable, streamable
@ -258,15 +259,7 @@ class PlotManager:
for filename in filenames_to_remove: for filename in filenames_to_remove:
del self.plot_filename_paths[filename] del self.plot_filename_paths[filename]
def batches() -> Iterator[Tuple[int, List[Path]]]: for remaining, batch in list_to_batches(plot_paths, self.refresh_parameter.batch_size):
if total_size > 0:
for batch_start in range(0, total_size, self.refresh_parameter.batch_size):
batch_end = min(batch_start + self.refresh_parameter.batch_size, total_size)
yield total_size - batch_end, plot_paths[batch_start:batch_end]
else:
yield 0, []
for remaining, batch in batches():
batch_result: PlotRefreshResult = self.refresh_batch(batch, plot_directories) batch_result: PlotRefreshResult = self.refresh_batch(batch, plot_directories)
if not self._refreshing_enabled: if not self._refreshing_enabled:
self.log.debug("refresh_plots: Aborted") self.log.debug("refresh_plots: Aborted")

View File

@ -1,4 +1,4 @@
from typing import List, Tuple from typing import Any, Iterator, List, Tuple
from chiabip158 import PyBIP158 from chiabip158 import PyBIP158
from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.coin import Coin
@ -64,3 +64,14 @@ def tx_removals_and_additions(npc_list: List[NPC]) -> Tuple[List[bytes32], List[
additions.extend(additions_for_npc(npc_list)) additions.extend(additions_for_npc(npc_list))
return removals, additions return removals, additions
def list_to_batches(list_to_split: List[Any], batch_size: int) -> Iterator[Tuple[int, List[Any]]]:
if batch_size <= 0:
raise ValueError("list_to_batches: batch_size must be greater than 0.")
total_size = len(list_to_split)
if total_size == 0:
return iter(())
for batch_start in range(0, total_size, batch_size):
batch_end = min(batch_start + batch_size, total_size)
yield total_size - batch_end, list_to_split[batch_start:batch_end]

View File

@ -0,0 +1,30 @@
import pytest
from chia.util.generator_tools import list_to_batches
def test_empty_lists():
# An empty list should return an empty iterator and skip the loop's body.
for _, _ in list_to_batches([], 1):
assert False
def test_valid():
for k in range(1, 10):
test_list = [x for x in range(0, k)]
for i in range(1, len(test_list) + 1): # Test batch_size 1 to 11 (length + 1)
checked = 0
for remaining, batch in list_to_batches(test_list, i):
assert remaining == max(len(test_list) - checked - i, 0)
assert len(batch) <= i
assert batch == test_list[checked : min(checked + i, len(test_list))]
checked += len(batch)
assert checked == len(test_list)
def test_invalid_batch_sizes():
with pytest.raises(ValueError):
for _ in list_to_batches([], 0):
assert False
with pytest.raises(ValueError):
for _ in list_to_batches([], -1):
assert False