From 84169a881dc34014512e5284717ac08a1f624f67 Mon Sep 17 00:00:00 2001 From: Arvid Norberg Date: Sat, 9 Dec 2023 01:29:53 +0100 Subject: [PATCH] add SerializedProgram.to() to simplify some code (#17022) --- .../blockchain_format/serialized_program.py | 7 +++ tests/blockchain/test_blockchain.py | 2 +- tests/core/mempool/test_mempool.py | 6 ++- tests/core/mempool/test_mempool_manager.py | 16 +++--- tests/core/test_program.py | 49 +++++++++++++++++++ 5 files changed, 69 insertions(+), 11 deletions(-) create mode 100644 tests/core/test_program.py diff --git a/chia/types/blockchain_format/serialized_program.py b/chia/types/blockchain_format/serialized_program.py index 3df972e33a54..977423d12a00 100644 --- a/chia/types/blockchain_format/serialized_program.py +++ b/chia/types/blockchain_format/serialized_program.py @@ -5,6 +5,7 @@ from typing import Tuple, Type from chia_rs import MEMPOOL_MODE, run_chia_program, serialized_length, tree_hash from clvm import SExp +from clvm.SExp import CastableType from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 @@ -60,6 +61,12 @@ class SerializedProgram: ret._buf = bytes(p) return ret + @staticmethod + def to(o: CastableType) -> SerializedProgram: + ret = SerializedProgram() + ret._buf = Program.to(o).as_bin() + return ret + def to_program(self) -> Program: return Program.from_bytes(self._buf) diff --git a/tests/blockchain/test_blockchain.py b/tests/blockchain/test_blockchain.py index 374310c647e7..633593638500 100644 --- a/tests/blockchain/test_blockchain.py +++ b/tests/blockchain/test_blockchain.py @@ -2194,7 +2194,7 @@ class TestBodyValidation: blocks = bt.get_consecutive_blocks(1, block_list_input=blocks) original_block: FullBlock = blocks[-1] - block = recursive_replace(original_block, "transactions_generator", SerializedProgram()) + block = recursive_replace(original_block, "transactions_generator", SerializedProgram.to(None)) await _validate_and_add_block( empty_blockchain, block, expected_error=Err.NOT_BLOCK_BUT_HAS_DATA, skip_prevalidation=True ) diff --git a/tests/core/mempool/test_mempool.py b/tests/core/mempool/test_mempool.py index f306e58aec83..a82cdf6138fe 100644 --- a/tests/core/mempool/test_mempool.py +++ b/tests/core/mempool/test_mempool.py @@ -2958,6 +2958,8 @@ def test_aggregating_on_a_solution_then_a_more_cost_saving_one_appears() -> None def test_get_puzzle_and_solution_for_coin_failure(): with pytest.raises( - ValueError, match=f"Failed to get puzzle and solution for coin {TEST_COIN}, error: failed to fill whole buffer" + ValueError, match=f"Failed to get puzzle and solution for coin {TEST_COIN}, error: \\('coin not found', '80'\\)" ): - get_puzzle_and_solution_for_coin(BlockGenerator(SerializedProgram(), [], []), TEST_COIN, 0, test_constants) + get_puzzle_and_solution_for_coin( + BlockGenerator(SerializedProgram.to(None), [], []), TEST_COIN, 0, test_constants + ) diff --git a/tests/core/mempool/test_mempool_manager.py b/tests/core/mempool/test_mempool_manager.py index 348e0d208131..9d044b73f5d2 100644 --- a/tests/core/mempool/test_mempool_manager.py +++ b/tests/core/mempool/test_mempool_manager.py @@ -48,7 +48,7 @@ from chia.wallet.wallet_coin_record import WalletCoinRecord from chia.wallet.wallet_node import WalletNode from tests.util.setup_nodes import SimulatorsAndWallets -IDENTITY_PUZZLE = SerializedProgram.from_program(Program.to(1)) +IDENTITY_PUZZLE = SerializedProgram.to(1) IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash() TEST_TIMESTAMP = uint64(10040) @@ -665,7 +665,7 @@ def mk_item( ) -> MempoolItem: # we don't actually care about the puzzle and solutions for the purpose of # can_replace() - spends = [CoinSpend(c, SerializedProgram(), SerializedProgram()) for c in coins] + spends = [CoinSpend(c, SerializedProgram.to(None), SerializedProgram.to(None)) for c in coins] spend_bundle = SpendBundle(spends, G2Element()) npc_result = NPCResult(None, make_test_conds(cost=cost, spend_ids=[c.name() for c in coins]), uint64(cost)) return MempoolItem( @@ -1234,7 +1234,7 @@ def test_dedup_info_eligible_1st_time() -> None: sb = spend_bundle_from_conditions(conditions, TEST_COIN) mempool_item = mempool_item_from_spendbundle(sb) eligible_coin_spends = EligibleCoinSpends() - solution = SerializedProgram.from_program(Program.to(conditions)) + solution = SerializedProgram.to(conditions) unique_coin_spends, cost_saving, unique_additions = eligible_coin_spends.get_deduplication_info( bundle_coin_spends=mempool_item.bundle_coin_spends, max_cost=mempool_item.npc_result.cost ) @@ -1253,7 +1253,7 @@ def test_dedup_info_eligible_but_different_solution() -> None: [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2], ] - initial_solution = SerializedProgram.from_program(Program.to(initial_conditions)) + initial_solution = SerializedProgram.to(initial_conditions) eligible_coin_spends = EligibleCoinSpends({TEST_COIN_ID: DedupCoinSpend(solution=initial_solution, cost=None)}) conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2]] sb = spend_bundle_from_conditions(conditions, TEST_COIN) @@ -1270,11 +1270,11 @@ def test_dedup_info_eligible_2nd_time_and_another_1st_time() -> None: [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2], ] - initial_solution = SerializedProgram.from_program(Program.to(initial_conditions)) + initial_solution = SerializedProgram.to(initial_conditions) eligible_coin_spends = EligibleCoinSpends({TEST_COIN_ID: DedupCoinSpend(solution=initial_solution, cost=None)}) sb1 = spend_bundle_from_conditions(initial_conditions, TEST_COIN) second_conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 3]] - second_solution = SerializedProgram.from_program(Program.to(second_conditions)) + second_solution = SerializedProgram.to(second_conditions) sb2 = spend_bundle_from_conditions(second_conditions, TEST_COIN2) sb = SpendBundle.aggregate([sb1, sb2]) mempool_item = mempool_item_from_spendbundle(sb) @@ -1303,9 +1303,9 @@ def test_dedup_info_eligible_3rd_time_another_2nd_time_and_one_non_eligible() -> [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 1], [ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 2], ] - initial_solution = SerializedProgram.from_program(Program.to(initial_conditions)) + initial_solution = SerializedProgram.to(initial_conditions) second_conditions = [[ConditionOpcode.CREATE_COIN, IDENTITY_PUZZLE_HASH, 3]] - second_solution = SerializedProgram.from_program(Program.to(second_conditions)) + second_solution = SerializedProgram.to(second_conditions) saved_cost = uint64(3600044) eligible_coin_spends = EligibleCoinSpends( { diff --git a/tests/core/test_program.py b/tests/core/test_program.py new file mode 100644 index 000000000000..39fb2a6d38ca --- /dev/null +++ b/tests/core/test_program.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import List + +from clvm.SExp import CastableType +from clvm_tools import binutils + +from chia.types.blockchain_format.program import Program +from chia.types.blockchain_format.serialized_program import SerializedProgram +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.ints import uint32, uint64 + + +def program_roundtrip(o: CastableType) -> None: + prg1 = Program.to(o) + prg2 = SerializedProgram.to(o) + prg3 = SerializedProgram.from_program(prg1) + prg4 = SerializedProgram.from_bytes(prg1.as_bin()) + prg5 = prg2.to_program() + + assert bytes(prg1) == bytes(prg2) + assert bytes(prg1) == bytes(prg3) + assert bytes(prg1) == bytes(prg4) + assert bytes(prg1) == bytes(prg5) + + +def test_serialized_program_to() -> None: + prg = "(q ((0x0101010101010101010101010101010101010101010101010101010101010101 80 123 (() (q . ())))))" # noqa + tests: List[CastableType] = [ + 0, + 1, + (1, 2), + [0, 1, 2], + Program.to([1, 2, 3]), + SerializedProgram.to([1, 2, 3]), + b"123", + binutils.assemble(prg), # type: ignore[no-untyped-call] + [b"1", b"2", b"3"], + (b"1", (b"2", b"3")), + None, + -24, + bytes32.fromhex("0" * 64), + bytes.fromhex("0" * 6), + uint32(123), + uint64(123123), + ] + + for t in tests: + program_roundtrip(t)