diff --git a/chia/daemon/server.py b/chia/daemon/server.py index 7d309927b6304..9772b9da797dc 100644 --- a/chia/daemon/server.py +++ b/chia/daemon/server.py @@ -875,6 +875,8 @@ class WebSocketServer: device_index = request.get("device", None) t1 = request.get("t", None) # Temp directory t2 = request.get("t2", None) # Temp2 directory + disk_128 = request.get("disk_128", False) + disk_16 = request.get("disk_16", False) if device_index is not None and str(device_index).isdigit(): command_args.append("--device") @@ -885,6 +887,10 @@ class WebSocketServer: if t2 is not None: command_args.append("-2") command_args.append(t2) + if disk_128: + command_args.append("--disk-128") + if disk_16: + command_args.append("--disk-16") return command_args # if plot_type == "diskplot" diff --git a/chia/plotters/bladebit.py b/chia/plotters/bladebit.py index 752d91a257ab3..cd23d78e700fb 100644 --- a/chia/plotters/bladebit.py +++ b/chia/plotters/bladebit.py @@ -365,6 +365,10 @@ def plot_bladebit(args, chia_root_path, root_path): if "device" in args and str(args.device).isdigit(): call_args.append("--device") call_args.append(str(args.device)) + if "disk_128" in args and args.disk_128: + call_args.append("--disk-128") + if "disk_16" in args and args.disk_16: + call_args.append("--disk-16") call_args.append(args.finaldir) diff --git a/chia/plotters/plotters.py b/chia/plotters/plotters.py index 5325ee87ba759..2d3337b37a5fe 100644 --- a/chia/plotters/plotters.py +++ b/chia/plotters/plotters.py @@ -52,6 +52,8 @@ class Options(Enum): COMPRESSION = 37 BLADEBIT_DEVICE_INDEX = 38 CUDA_TMP_DIR = 40 + BLADEBIT_HYBRID_128_MODE = 41 + BLADEBIT_HYBRID_16_MODE = 42 chia_plotter_options = [ @@ -112,6 +114,8 @@ bladebit_cuda_plotter_options = [ Options.FINAL_DIR, Options.COMPRESSION, Options.BLADEBIT_DEVICE_INDEX, + Options.BLADEBIT_HYBRID_128_MODE, + Options.BLADEBIT_HYBRID_16_MODE, ] bladebit_ram_plotter_options = [ @@ -464,6 +468,20 @@ def build_parser(subparsers, root_path, option_list, name, plotter_desc): help="The CUDA device index", default=0, ) + if option is Options.BLADEBIT_HYBRID_128_MODE: + parser.add_argument( + "--disk-128", + action="store_true", + help="Enable hybrid disk plotting for 128G system RAM", + default=False, + ) + if option is Options.BLADEBIT_HYBRID_16_MODE: + parser.add_argument( + "--disk-16", + action="store_true", + help="Enable hybrid disk plotting for 16G system RAM", + default=False, + ) def call_plotters(root_path: Path, args): diff --git a/setup.py b/setup.py index a0b7a9c54fd48..3482aa9454023 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,7 @@ dev_dependencies = [ "pytest==7.4.0", "pytest-asyncio==0.21.1", "pytest-cov==4.1.0", + "pytest-mock==3.11.1", "pytest-monitor==1.6.6; sys_platform == 'linux'", "pytest-xdist==3.3.1", "twine==4.0.2", diff --git a/tests/core/daemon/test_daemon.py b/tests/core/daemon/test_daemon.py index 5cc22a04177ba..d58899bb3d3ef 100644 --- a/tests/core/daemon/test_daemon.py +++ b/tests/core/daemon/test_daemon.py @@ -4,12 +4,14 @@ import asyncio import json import logging from dataclasses import dataclass, field, replace +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast import aiohttp import pkg_resources import pytest from aiohttp.web_ws import WebSocketResponse +from pytest_mock import MockerFixture from chia.daemon.client import connect_to_daemon from chia.daemon.keychain_server import ( @@ -22,6 +24,7 @@ from chia.daemon.keychain_server import ( SetLabelRequest, ) from chia.daemon.server import WebSocketServer, plotter_log_path, service_plotter +from chia.plotters.plotters import call_plotters from chia.server.outbound_message import NodeType from chia.simulator.block_tools import BlockTools from chia.simulator.keyring import TempKeyring @@ -70,6 +73,63 @@ class KeysForPlotCase: marks: Marks = () +@dataclass +class ChiaPlottersBladebitArgsCase: + case_id: str + plot_type: str + count: int = 1 + threads: int = 0 + pool_contract: str = "txch1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + compress: int = 1 + device: int = 0 + hybrid_disk_mode: Optional[int] = None + farmer_pk: str = "" + final_dir: str = "" + marks: Marks = () + + @property + def id(self) -> str: + return self.case_id + + def to_command_array(self) -> List[str]: + command: List[str] = ["bladebit", self.plot_type] + command += ["-r", str(self.threads)] + command += ["-n", str(self.count)] + command += ["-c", self.pool_contract] + command += ["-f", self.farmer_pk] + command += ["--compress", str(self.compress)] + if self.plot_type == "cudaplot": + command += ["--device", str(self.device)] + if self.hybrid_disk_mode is not None: + command += [f"--disk-{self.hybrid_disk_mode}"] + command += ["-d", str(self.final_dir)] + + return command + + def expected_raw_command_args(self): + raw_args = [] + raw_args += [ + "--threads", + str(self.threads), + "--count", + str(self.count), + "--farmer-key", + str(self.farmer_pk), + "--pool-contract", + str(self.pool_contract), + ] + # --compress is "1" by default + raw_args += ["--compress", str(self.compress) if self.compress is not None else "1"] + raw_args += [self.plot_type] + if self.plot_type == "cudaplot": + # --device is "0" by default + raw_args += ["--device", str(self.device) if self.device is not None else "0"] + if self.hybrid_disk_mode is not None: + raw_args += [f"--disk-{self.hybrid_disk_mode}"] + raw_args += [str(self.final_dir)] + return raw_args + + # Simple class that responds to a poll() call used by WebSocketServer.is_running() @dataclass class Service: @@ -1632,6 +1692,42 @@ async def test_plotter_errors( "success": True, }, ), + RouteCase( + route="start_plotting", + description="bladebit - cudaplot - hybrid 128 mode", + request={ + **plotter_request_ref, + "plotter": "bladebit", + "plot_type": "cudaplot", + "w": True, + "m": True, + "no_cpu_affinity": True, + "e": False, + "compress": 1, + "disk_128": True, + }, + response={ + "success": True, + }, + ), + RouteCase( + route="start_plotting", + description="bladebit - cudaplot - hybrid 16 mode", + request={ + **plotter_request_ref, + "plotter": "bladebit", + "plot_type": "cudaplot", + "w": True, + "m": True, + "no_cpu_affinity": True, + "e": False, + "compress": 1, + "disk_16": True, + }, + response={ + "success": True, + }, + ), RouteCase( route="start_plotting", description="madmax", @@ -1857,3 +1953,37 @@ async def test_plotter_stop_plotting( # 5) Finally, get the "ack" for the stop_plotting payload response = await ws.receive() assert_response(response, {"success": True}, stop_plotting_request_id) + + +@datacases( + ChiaPlottersBladebitArgsCase(case_id="1", plot_type="cudaplot"), + ChiaPlottersBladebitArgsCase(case_id="2", plot_type="cudaplot", hybrid_disk_mode=16), + ChiaPlottersBladebitArgsCase(case_id="3", plot_type="cudaplot", hybrid_disk_mode=128), +) +def test_run_plotter_bladebit( + mocker: MockerFixture, + mock_daemon_with_config_and_keys, + bt: BlockTools, + case: ChiaPlottersBladebitArgsCase, +) -> None: + root_path = bt.root_path + + case.farmer_pk = bytes(bt.farmer_pk).hex() + case.final_dir = str(bt.plot_dir) + + def bladebit_exists(x: Path) -> bool: + return True if isinstance(x, Path) and x.parent == root_path / "plotters" else mocker.DEFAULT + + def get_bladebit_version(_: Path) -> Tuple[bool, List[str]]: + return True, ["3", "0", "0"] + + mocker.patch("os.path.exists", side_effect=bladebit_exists) + mocker.patch("chia.plotters.bladebit.get_bladebit_version", side_effect=get_bladebit_version) + mock_run_plotter = mocker.patch("chia.plotters.bladebit.run_plotter") + + call_plotters(root_path, case.to_command_array()) + + assert mock_run_plotter.call_args.args[0] == root_path + assert mock_run_plotter.call_args.args[1] == "bladebit" + assert mock_run_plotter.call_args.args[2][1:] == case.expected_raw_command_args() + mock_run_plotter.assert_called_once()