diff --git a/chia/types/blockchain_format/program.py b/chia/types/blockchain_format/program.py index 2f94f65d00549..904654eb38f4d 100644 --- a/chia/types/blockchain_format/program.py +++ b/chia/types/blockchain_format/program.py @@ -1,7 +1,7 @@ from __future__ import annotations import io -from typing import Any, Callable, Dict, Set, Tuple +from typing import Any, Callable, Dict, Optional, Set, Tuple from chia_rs import ALLOW_BACKREFS, run_chia_program, tree_hash from clvm import SExp @@ -177,6 +177,12 @@ class Program(SExp): def as_int(self) -> int: return int_from_bytes(self.as_atom()) + def as_atom(self) -> bytes: + ret: Optional[bytes] = self.atom + if ret is None: + raise ValueError("expected atom") + return ret + def __deepcopy__(self, memo): return type(self).from_bytes(bytes(self)) diff --git a/chia/wallet/conditions.py b/chia/wallet/conditions.py index d9b61b635ecaf..3b11d76ff8c40 100644 --- a/chia/wallet/conditions.py +++ b/chia/wallet/conditions.py @@ -43,8 +43,8 @@ class AggSigParent(Condition): @classmethod def from_program(cls, program: Program, parent_id: Optional[bytes32] = None) -> AggSigParent: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), parent_id, ) @@ -64,8 +64,8 @@ class AggSigPuzzle(Condition): @classmethod def from_program(cls, program: Program, puzzle_hash: Optional[bytes32] = None) -> AggSigPuzzle: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), puzzle_hash, ) @@ -85,8 +85,8 @@ class AggSigAmount(Condition): @classmethod def from_program(cls, program: Program, amount: Optional[uint64] = None) -> AggSigAmount: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), amount, ) @@ -112,8 +112,8 @@ class AggSigPuzzleAmount(Condition): amount: Optional[uint64] = None, ) -> AggSigPuzzleAmount: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), puzzle_hash, amount, ) @@ -140,8 +140,8 @@ class AggSigParentAmount(Condition): amount: Optional[uint64] = None, ) -> AggSigParentAmount: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), parent_id, amount, ) @@ -168,8 +168,8 @@ class AggSigParentPuzzle(Condition): puzzle_hash: Optional[bytes32] = None, ) -> AggSigParentPuzzle: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), parent_id, puzzle_hash, ) @@ -189,8 +189,8 @@ class AggSigUnsafe(Condition): @classmethod def from_program(cls, program: Program) -> AggSigUnsafe: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), ) @@ -215,8 +215,8 @@ class AggSigMe(Condition): additional_data: Optional[bytes32] = None, ) -> AggSigMe: return cls( - G1Element.from_bytes(program.at("rf").atom), - program.at("rrf").atom, + G1Element.from_bytes(program.at("rf").as_atom()), + program.at("rrf").as_atom(), coin_id, additional_data, ) @@ -241,9 +241,11 @@ class CreateCoin(Condition): def from_program(cls, program: Program) -> CreateCoin: potential_memos: Program = program.at("rrr") return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), uint64(program.at("rrf").as_int()), - None if potential_memos == Program.to(None) else [memo.atom for memo in potential_memos.at("f").as_iter()], + None + if potential_memos == Program.to(None) + else [memo.as_atom() for memo in potential_memos.at("f").as_iter()], ) @@ -302,7 +304,7 @@ class AssertCoinAnnouncement(Condition): asserted_msg: Optional[bytes] = None, ) -> AssertCoinAnnouncement: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), asserted_id, asserted_msg, ) @@ -328,7 +330,7 @@ class CreateCoinAnnouncement(Condition): @classmethod def from_program(cls, program: Program, coin_id: Optional[bytes32] = None) -> CreateCoinAnnouncement: return cls( - program.at("rf").atom, + program.at("rf").as_atom(), coin_id, ) @@ -371,7 +373,7 @@ class AssertPuzzleAnnouncement(Condition): asserted_msg: Optional[bytes] = None, ) -> AssertPuzzleAnnouncement: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), asserted_ph, asserted_msg, ) @@ -397,7 +399,7 @@ class CreatePuzzleAnnouncement(Condition): @classmethod def from_program(cls, program: Program, puzzle_hash: Optional[bytes32] = None) -> CreatePuzzleAnnouncement: return cls( - program.at("rf").atom, + program.at("rf").as_atom(), puzzle_hash, ) @@ -415,7 +417,7 @@ class AssertConcurrentSpend(Condition): @classmethod def from_program(cls, program: Program) -> AssertConcurrentSpend: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), ) @@ -432,7 +434,7 @@ class AssertConcurrentPuzzle(Condition): @classmethod def from_program(cls, program: Program) -> AssertConcurrentPuzzle: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), ) @@ -449,7 +451,7 @@ class AssertMyCoinID(Condition): @classmethod def from_program(cls, program: Program) -> AssertMyCoinID: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), ) @@ -466,7 +468,7 @@ class AssertMyParentID(Condition): @classmethod def from_program(cls, program: Program) -> AssertMyParentID: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), ) @@ -483,7 +485,7 @@ class AssertMyPuzzleHash(Condition): @classmethod def from_program(cls, program: Program) -> AssertMyPuzzleHash: return cls( - bytes32(program.at("rf").atom), + bytes32(program.at("rf").as_atom()), ) @@ -761,7 +763,7 @@ class AggSig(Condition): @classmethod def from_program(cls, program: Program, **kwargs: Optional[Union[uint64, bytes32]]) -> AggSig: - opcode: bytes = program.at("f").atom + opcode: bytes = program.at("f").as_atom() condition_driver: Condition = CONDITION_DRIVERS[opcode].from_program(program, **kwargs) return cls( # We are either parsing an agg sig condition, all of which have these, or we want to error @@ -794,7 +796,7 @@ class CreateAnnouncement(Condition): @classmethod def from_program(cls, program: Program, **kwargs: Optional[bytes32]) -> CreateAnnouncement: - if program.at("f").atom == ConditionOpcode.CREATE_COIN_ANNOUNCEMENT: + if program.at("f").as_atom() == ConditionOpcode.CREATE_COIN_ANNOUNCEMENT: coin_not_puzzle: bool = True condition: Union[CreateCoinAnnouncement, CreatePuzzleAnnouncement] = CreateCoinAnnouncement.from_program( program, **kwargs @@ -848,7 +850,7 @@ class AssertAnnouncement(Condition): @classmethod def from_program(cls, program: Program, **kwargs: Optional[bytes32]) -> AssertAnnouncement: - if program.at("f").atom == ConditionOpcode.ASSERT_COIN_ANNOUNCEMENT: + if program.at("f").as_atom() == ConditionOpcode.ASSERT_COIN_ANNOUNCEMENT: coin_not_puzzle: bool = True condition: Union[AssertCoinAnnouncement, AssertPuzzleAnnouncement] = AssertCoinAnnouncement.from_program( program, **kwargs @@ -1022,7 +1024,7 @@ class Timelock(Condition): @classmethod def from_program(cls, program: Program) -> Timelock: - opcode: bytes = program.at("f").atom + opcode: bytes = program.at("f").as_atom() if opcode in AFTER_TIMELOCK_OPCODES: after_not_before = True else: @@ -1131,7 +1133,7 @@ def parse_conditions_non_consensus( final_condition_list: List[Condition] = [] for condition in conditions: try: - final_condition_list.append(driver_dictionary[condition.at("f").atom].from_program(condition)) + final_condition_list.append(driver_dictionary[condition.at("f").as_atom()].from_program(condition)) except Exception: final_condition_list.append(UnknownCondition.from_program(condition)) diff --git a/tests/wallet/test_conditions.py b/tests/wallet/test_conditions.py index 9357ae82125ea..8292a5b31b4c3 100644 --- a/tests/wallet/test_conditions.py +++ b/tests/wallet/test_conditions.py @@ -5,6 +5,7 @@ from typing import Any, List, Optional, Tuple, Type, Union import pytest from clvm.casts import int_from_bytes +from clvm.EvalError import EvalError from chia.types.blockchain_format.program import Program from chia.types.blockchain_format.sized_bytes import bytes32 @@ -13,22 +14,43 @@ from chia.util.ints import uint32, uint64 from chia.wallet.conditions import ( CONDITION_DRIVERS, CONDITION_DRIVERS_W_ABSTRACTIONS, + AggSig, + AggSigAmount, + AggSigMe, + AggSigParent, + AggSigParentAmount, + AggSigParentPuzzle, + AggSigPuzzle, + AggSigPuzzleAmount, + AggSigUnsafe, AssertAnnouncement, AssertBeforeHeightAbsolute, AssertBeforeHeightRelative, AssertBeforeSecondsAbsolute, AssertBeforeSecondsRelative, AssertCoinAnnouncement, + AssertConcurrentPuzzle, + AssertConcurrentSpend, AssertHeightAbsolute, AssertHeightRelative, + AssertMyAmount, + AssertMyBirthHeight, + AssertMyBirthSeconds, + AssertMyCoinID, + AssertMyParentID, + AssertMyPuzzleHash, AssertPuzzleAnnouncement, AssertSecondsAbsolute, AssertSecondsRelative, Condition, ConditionValidTimes, CreateAnnouncement, + CreateCoin, CreateCoinAnnouncement, CreatePuzzleAnnouncement, + Remark, + ReserveFee, + Softfork, Timelock, UnknownCondition, conditions_from_json_dicts, @@ -243,3 +265,106 @@ def test_timelock_parsing(timelock_info: TimelockInfo) -> None: assert timelock_info.parsed_info.to_conditions() == ( timelock_info.conditions_after if timelock_info.conditions_after is not None else timelock_info.drivers ) + + +@pytest.mark.parametrize( + "cond", + [ + AggSigParent, + AggSigPuzzle, + AggSigAmount, + AggSigPuzzleAmount, + AggSigParentAmount, + AggSigParentPuzzle, + AggSigUnsafe, + AggSigMe, + CreateCoin, + ReserveFee, + AssertCoinAnnouncement, + CreateCoinAnnouncement, + AssertPuzzleAnnouncement, + CreatePuzzleAnnouncement, + AssertConcurrentSpend, + AssertConcurrentPuzzle, + AssertMyCoinID, + AssertMyParentID, + AssertMyPuzzleHash, + AssertMyAmount, + AssertMyBirthSeconds, + AssertMyBirthHeight, + AssertSecondsRelative, + AssertSecondsAbsolute, + AssertHeightRelative, + AssertHeightAbsolute, + AssertBeforeSecondsRelative, + AssertBeforeSecondsAbsolute, + AssertBeforeHeightRelative, + AssertBeforeHeightAbsolute, + Softfork, + Remark, + UnknownCondition, + AggSig, + CreateAnnouncement, + AssertAnnouncement, + Timelock, + ], +) +@pytest.mark.parametrize( + "prg", + [ + bytes([0x80]), + bytes([0xFF, 0x80, 0xFF, 0xFF, 0xFF, 0x80, 0x80, 0x80, 0x80]), + bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0x80]), + bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0x80]), + bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0x80]), + ], +) +def test_invalid_condition( + cond: Type[ + Union[ + AggSigParent, + AggSigPuzzle, + AggSigAmount, + AggSigPuzzleAmount, + AggSigParentAmount, + AggSigParentPuzzle, + AggSigUnsafe, + AggSigMe, + CreateCoin, + ReserveFee, + AssertCoinAnnouncement, + CreateCoinAnnouncement, + AssertPuzzleAnnouncement, + CreatePuzzleAnnouncement, + AssertConcurrentSpend, + AssertConcurrentPuzzle, + AssertMyCoinID, + AssertMyParentID, + AssertMyPuzzleHash, + AssertMyAmount, + AssertMyBirthSeconds, + AssertMyBirthHeight, + AssertSecondsRelative, + AssertSecondsAbsolute, + AssertHeightRelative, + AssertHeightAbsolute, + AssertBeforeSecondsRelative, + AssertBeforeSecondsAbsolute, + AssertBeforeHeightRelative, + AssertBeforeHeightAbsolute, + Softfork, + Remark, + UnknownCondition, + AggSig, + CreateAnnouncement, + AssertAnnouncement, + Timelock, + ] + ], + prg: bytes, +) -> None: + if (cond == Remark or cond == UnknownCondition) and prg != b"\x80": + pytest.skip("condition takes arbitrary arguments") + + with pytest.raises((ValueError, EvalError, KeyError)): + cond.from_program(Program.from_bytes(prg))