fix type mismatch with Optional[bytes] and bytes in wallet/conditions.py (#17030)

This commit is contained in:
Arvid Norberg 2023-12-11 21:13:01 +01:00 committed by GitHub
parent 3425b8bbfa
commit d26303cb0e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 166 additions and 33 deletions

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import io
from typing import Any, Callable, Dict, Set, Tuple
from typing import Any, Callable, Dict, Optional, Set, Tuple
from chia_rs import ALLOW_BACKREFS, run_chia_program, tree_hash
from clvm import SExp
@ -177,6 +177,12 @@ class Program(SExp):
def as_int(self) -> int:
return int_from_bytes(self.as_atom())
def as_atom(self) -> bytes:
ret: Optional[bytes] = self.atom
if ret is None:
raise ValueError("expected atom")
return ret
def __deepcopy__(self, memo):
return type(self).from_bytes(bytes(self))

View File

@ -43,8 +43,8 @@ class AggSigParent(Condition):
@classmethod
def from_program(cls, program: Program, parent_id: Optional[bytes32] = None) -> AggSigParent:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
parent_id,
)
@ -64,8 +64,8 @@ class AggSigPuzzle(Condition):
@classmethod
def from_program(cls, program: Program, puzzle_hash: Optional[bytes32] = None) -> AggSigPuzzle:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
puzzle_hash,
)
@ -85,8 +85,8 @@ class AggSigAmount(Condition):
@classmethod
def from_program(cls, program: Program, amount: Optional[uint64] = None) -> AggSigAmount:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
amount,
)
@ -112,8 +112,8 @@ class AggSigPuzzleAmount(Condition):
amount: Optional[uint64] = None,
) -> AggSigPuzzleAmount:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
puzzle_hash,
amount,
)
@ -140,8 +140,8 @@ class AggSigParentAmount(Condition):
amount: Optional[uint64] = None,
) -> AggSigParentAmount:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
parent_id,
amount,
)
@ -168,8 +168,8 @@ class AggSigParentPuzzle(Condition):
puzzle_hash: Optional[bytes32] = None,
) -> AggSigParentPuzzle:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
parent_id,
puzzle_hash,
)
@ -189,8 +189,8 @@ class AggSigUnsafe(Condition):
@classmethod
def from_program(cls, program: Program) -> AggSigUnsafe:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
)
@ -215,8 +215,8 @@ class AggSigMe(Condition):
additional_data: Optional[bytes32] = None,
) -> AggSigMe:
return cls(
G1Element.from_bytes(program.at("rf").atom),
program.at("rrf").atom,
G1Element.from_bytes(program.at("rf").as_atom()),
program.at("rrf").as_atom(),
coin_id,
additional_data,
)
@ -241,9 +241,11 @@ class CreateCoin(Condition):
def from_program(cls, program: Program) -> CreateCoin:
potential_memos: Program = program.at("rrr")
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
uint64(program.at("rrf").as_int()),
None if potential_memos == Program.to(None) else [memo.atom for memo in potential_memos.at("f").as_iter()],
None
if potential_memos == Program.to(None)
else [memo.as_atom() for memo in potential_memos.at("f").as_iter()],
)
@ -302,7 +304,7 @@ class AssertCoinAnnouncement(Condition):
asserted_msg: Optional[bytes] = None,
) -> AssertCoinAnnouncement:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
asserted_id,
asserted_msg,
)
@ -328,7 +330,7 @@ class CreateCoinAnnouncement(Condition):
@classmethod
def from_program(cls, program: Program, coin_id: Optional[bytes32] = None) -> CreateCoinAnnouncement:
return cls(
program.at("rf").atom,
program.at("rf").as_atom(),
coin_id,
)
@ -371,7 +373,7 @@ class AssertPuzzleAnnouncement(Condition):
asserted_msg: Optional[bytes] = None,
) -> AssertPuzzleAnnouncement:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
asserted_ph,
asserted_msg,
)
@ -397,7 +399,7 @@ class CreatePuzzleAnnouncement(Condition):
@classmethod
def from_program(cls, program: Program, puzzle_hash: Optional[bytes32] = None) -> CreatePuzzleAnnouncement:
return cls(
program.at("rf").atom,
program.at("rf").as_atom(),
puzzle_hash,
)
@ -415,7 +417,7 @@ class AssertConcurrentSpend(Condition):
@classmethod
def from_program(cls, program: Program) -> AssertConcurrentSpend:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
)
@ -432,7 +434,7 @@ class AssertConcurrentPuzzle(Condition):
@classmethod
def from_program(cls, program: Program) -> AssertConcurrentPuzzle:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
)
@ -449,7 +451,7 @@ class AssertMyCoinID(Condition):
@classmethod
def from_program(cls, program: Program) -> AssertMyCoinID:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
)
@ -466,7 +468,7 @@ class AssertMyParentID(Condition):
@classmethod
def from_program(cls, program: Program) -> AssertMyParentID:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
)
@ -483,7 +485,7 @@ class AssertMyPuzzleHash(Condition):
@classmethod
def from_program(cls, program: Program) -> AssertMyPuzzleHash:
return cls(
bytes32(program.at("rf").atom),
bytes32(program.at("rf").as_atom()),
)
@ -761,7 +763,7 @@ class AggSig(Condition):
@classmethod
def from_program(cls, program: Program, **kwargs: Optional[Union[uint64, bytes32]]) -> AggSig:
opcode: bytes = program.at("f").atom
opcode: bytes = program.at("f").as_atom()
condition_driver: Condition = CONDITION_DRIVERS[opcode].from_program(program, **kwargs)
return cls(
# We are either parsing an agg sig condition, all of which have these, or we want to error
@ -794,7 +796,7 @@ class CreateAnnouncement(Condition):
@classmethod
def from_program(cls, program: Program, **kwargs: Optional[bytes32]) -> CreateAnnouncement:
if program.at("f").atom == ConditionOpcode.CREATE_COIN_ANNOUNCEMENT:
if program.at("f").as_atom() == ConditionOpcode.CREATE_COIN_ANNOUNCEMENT:
coin_not_puzzle: bool = True
condition: Union[CreateCoinAnnouncement, CreatePuzzleAnnouncement] = CreateCoinAnnouncement.from_program(
program, **kwargs
@ -848,7 +850,7 @@ class AssertAnnouncement(Condition):
@classmethod
def from_program(cls, program: Program, **kwargs: Optional[bytes32]) -> AssertAnnouncement:
if program.at("f").atom == ConditionOpcode.ASSERT_COIN_ANNOUNCEMENT:
if program.at("f").as_atom() == ConditionOpcode.ASSERT_COIN_ANNOUNCEMENT:
coin_not_puzzle: bool = True
condition: Union[AssertCoinAnnouncement, AssertPuzzleAnnouncement] = AssertCoinAnnouncement.from_program(
program, **kwargs
@ -1022,7 +1024,7 @@ class Timelock(Condition):
@classmethod
def from_program(cls, program: Program) -> Timelock:
opcode: bytes = program.at("f").atom
opcode: bytes = program.at("f").as_atom()
if opcode in AFTER_TIMELOCK_OPCODES:
after_not_before = True
else:
@ -1131,7 +1133,7 @@ def parse_conditions_non_consensus(
final_condition_list: List[Condition] = []
for condition in conditions:
try:
final_condition_list.append(driver_dictionary[condition.at("f").atom].from_program(condition))
final_condition_list.append(driver_dictionary[condition.at("f").as_atom()].from_program(condition))
except Exception:
final_condition_list.append(UnknownCondition.from_program(condition))

View File

@ -5,6 +5,7 @@ from typing import Any, List, Optional, Tuple, Type, Union
import pytest
from clvm.casts import int_from_bytes
from clvm.EvalError import EvalError
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
@ -13,22 +14,43 @@ from chia.util.ints import uint32, uint64
from chia.wallet.conditions import (
CONDITION_DRIVERS,
CONDITION_DRIVERS_W_ABSTRACTIONS,
AggSig,
AggSigAmount,
AggSigMe,
AggSigParent,
AggSigParentAmount,
AggSigParentPuzzle,
AggSigPuzzle,
AggSigPuzzleAmount,
AggSigUnsafe,
AssertAnnouncement,
AssertBeforeHeightAbsolute,
AssertBeforeHeightRelative,
AssertBeforeSecondsAbsolute,
AssertBeforeSecondsRelative,
AssertCoinAnnouncement,
AssertConcurrentPuzzle,
AssertConcurrentSpend,
AssertHeightAbsolute,
AssertHeightRelative,
AssertMyAmount,
AssertMyBirthHeight,
AssertMyBirthSeconds,
AssertMyCoinID,
AssertMyParentID,
AssertMyPuzzleHash,
AssertPuzzleAnnouncement,
AssertSecondsAbsolute,
AssertSecondsRelative,
Condition,
ConditionValidTimes,
CreateAnnouncement,
CreateCoin,
CreateCoinAnnouncement,
CreatePuzzleAnnouncement,
Remark,
ReserveFee,
Softfork,
Timelock,
UnknownCondition,
conditions_from_json_dicts,
@ -243,3 +265,106 @@ def test_timelock_parsing(timelock_info: TimelockInfo) -> None:
assert timelock_info.parsed_info.to_conditions() == (
timelock_info.conditions_after if timelock_info.conditions_after is not None else timelock_info.drivers
)
@pytest.mark.parametrize(
"cond",
[
AggSigParent,
AggSigPuzzle,
AggSigAmount,
AggSigPuzzleAmount,
AggSigParentAmount,
AggSigParentPuzzle,
AggSigUnsafe,
AggSigMe,
CreateCoin,
ReserveFee,
AssertCoinAnnouncement,
CreateCoinAnnouncement,
AssertPuzzleAnnouncement,
CreatePuzzleAnnouncement,
AssertConcurrentSpend,
AssertConcurrentPuzzle,
AssertMyCoinID,
AssertMyParentID,
AssertMyPuzzleHash,
AssertMyAmount,
AssertMyBirthSeconds,
AssertMyBirthHeight,
AssertSecondsRelative,
AssertSecondsAbsolute,
AssertHeightRelative,
AssertHeightAbsolute,
AssertBeforeSecondsRelative,
AssertBeforeSecondsAbsolute,
AssertBeforeHeightRelative,
AssertBeforeHeightAbsolute,
Softfork,
Remark,
UnknownCondition,
AggSig,
CreateAnnouncement,
AssertAnnouncement,
Timelock,
],
)
@pytest.mark.parametrize(
"prg",
[
bytes([0x80]),
bytes([0xFF, 0x80, 0xFF, 0xFF, 0xFF, 0x80, 0x80, 0x80, 0x80]),
bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0x80]),
bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0x80]),
bytes([0xFF, 0x80, 0xFF, 0xFF, 0x80, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0xFF, 0x80, 0x80]),
],
)
def test_invalid_condition(
cond: Type[
Union[
AggSigParent,
AggSigPuzzle,
AggSigAmount,
AggSigPuzzleAmount,
AggSigParentAmount,
AggSigParentPuzzle,
AggSigUnsafe,
AggSigMe,
CreateCoin,
ReserveFee,
AssertCoinAnnouncement,
CreateCoinAnnouncement,
AssertPuzzleAnnouncement,
CreatePuzzleAnnouncement,
AssertConcurrentSpend,
AssertConcurrentPuzzle,
AssertMyCoinID,
AssertMyParentID,
AssertMyPuzzleHash,
AssertMyAmount,
AssertMyBirthSeconds,
AssertMyBirthHeight,
AssertSecondsRelative,
AssertSecondsAbsolute,
AssertHeightRelative,
AssertHeightAbsolute,
AssertBeforeSecondsRelative,
AssertBeforeSecondsAbsolute,
AssertBeforeHeightRelative,
AssertBeforeHeightAbsolute,
Softfork,
Remark,
UnknownCondition,
AggSig,
CreateAnnouncement,
AssertAnnouncement,
Timelock,
]
],
prg: bytes,
) -> None:
if (cond == Remark or cond == UnknownCondition) and prg != b"\x80":
pytest.skip("condition takes arbitrary arguments")
with pytest.raises((ValueError, EvalError, KeyError)):
cond.from_program(Program.from_bytes(prg))