chia-blockchain/chia/util/streamable.py

359 lines
13 KiB
Python
Raw Normal View History

# flake8: noqa
# pylint: disable
2019-08-05 08:38:16 +03:00
from __future__ import annotations
2019-11-18 07:49:39 +03:00
import dataclasses
2019-11-18 07:49:39 +03:00
import io
import pprint
import sys
2020-03-20 08:03:51 +03:00
from enum import Enum
from typing import Any, BinaryIO, Dict, List, Tuple, Type, Callable, Optional, Iterator
2019-11-18 07:49:39 +03:00
from blspy import G1Element, G2Element, PrivateKey
2019-11-18 07:49:39 +03:00
from chia.types.blockchain_format.program import Program, SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.hash import std_hash
from chia.util.ints import int64, int512, uint32, uint64, uint128
from chia.util.type_checking import is_type_List, is_type_SpecificOptional, is_type_Tuple, strictdataclass
if sys.version_info < (3, 8):
def get_args(t: Type[Any]) -> Tuple[Any, ...]:
return getattr(t, "__args__", ())
else:
from typing import get_args
pp = pprint.PrettyPrinter(indent=1, width=120, compact=True)
2019-07-22 16:01:23 +03:00
2019-07-25 13:01:56 +03:00
# TODO: Remove hack, this allows streaming these objects from binary
size_hints = {
"PrivateKey": PrivateKey.PRIVATE_KEY_SIZE,
"G1Element": G1Element.SIZE,
"G2Element": G2Element.SIZE,
2021-01-19 09:12:30 +03:00
"ConditionOpcode": 1,
2019-07-25 13:01:56 +03:00
}
2019-11-18 07:50:31 +03:00
unhashable_types = [
PrivateKey,
G1Element,
G2Element,
2020-02-08 07:25:26 +03:00
Program,
SerializedProgram,
2019-11-18 07:50:31 +03:00
]
# JSON does not support big ints, so these types must be serialized differently in JSON
big_ints = [uint64, int64, uint128, int512]
2019-07-25 13:01:56 +03:00
def dataclass_from_dict(klass, d):
"""
Converts a dictionary based on a dataclass, into an instance of that dataclass.
Recursively goes through lists, optionals, and dictionaries.
"""
if is_type_SpecificOptional(klass):
# Type is optional, data is either None, or Any
if not d:
return None
return dataclass_from_dict(get_args(klass)[0], d)
2020-09-11 06:03:06 +03:00
elif is_type_Tuple(klass):
# Type is tuple, can have multiple different types inside
i = 0
klass_properties = []
for item in d:
klass_properties.append(dataclass_from_dict(klass.__args__[i], item))
i = i + 1
return tuple(klass_properties)
2020-09-11 06:03:06 +03:00
elif dataclasses.is_dataclass(klass):
# Type is a dataclass, data is a dictionary
fieldtypes = {f.name: f.type for f in dataclasses.fields(klass)}
return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d})
elif is_type_List(klass):
# Type is a list, data is a list
return [dataclass_from_dict(get_args(klass)[0], item) for item in d]
elif issubclass(klass, bytes):
# Type is bytes, data is a hex string
return klass(hexstr_to_bytes(d))
elif klass in unhashable_types:
# Type is unhashable (bls type), so cast from hex string
return klass.from_bytes(hexstr_to_bytes(d))
else:
# Type is a primitive, cast with correct class
return klass(d)
def recurse_jsonify(d):
"""
2021-07-22 19:00:36 +03:00
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
strings.
"""
if isinstance(d, list) or isinstance(d, tuple):
new_list = []
for item in d:
if type(item) in unhashable_types or issubclass(type(item), bytes):
2021-07-22 19:00:36 +03:00
item = f"0x{bytes(item).hex()}"
if isinstance(item, dict):
item = recurse_jsonify(item)
if isinstance(item, list):
item = recurse_jsonify(item)
if isinstance(item, tuple):
item = recurse_jsonify(item)
if isinstance(item, Enum):
item = item.name
if isinstance(item, int) and type(item) in big_ints:
item = int(item)
new_list.append(item)
d = new_list
else:
for key, value in d.items():
if type(value) in unhashable_types or issubclass(type(value), bytes):
2021-07-22 19:00:36 +03:00
d[key] = f"0x{bytes(value).hex()}"
if isinstance(value, dict):
d[key] = recurse_jsonify(value)
if isinstance(value, list):
d[key] = recurse_jsonify(value)
if isinstance(value, tuple):
d[key] = recurse_jsonify(value)
if isinstance(value, Enum):
d[key] = value.name
if isinstance(value, int) and type(value) in big_ints:
d[key] = int(value)
return d
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
2019-07-23 17:11:07 +03:00
def streamable(cls: Any):
2019-07-22 16:01:23 +03:00
"""
2019-08-05 08:38:16 +03:00
This is a decorator for class definitions. It applies the strictdataclass decorator,
which checks all types at construction. It also defines a simple serialization format,
and adds parse, from bytes, stream, and __bytes__ methods.
2019-08-05 08:38:16 +03:00
Serialization format:
- Each field is serialized in order, by calling from_bytes/__bytes__.
- For Lists, there is a 4 byte prefix for the list length.
- For Optionals, there is a one byte prefix, 1 iff object is present, 0 iff not.
2019-08-05 08:38:16 +03:00
All of the constituents must have parse/from_bytes, and stream/__bytes__ and therefore
2019-08-05 08:38:16 +03:00
be of fixed size. For example, int cannot be a constituent since it is not a fixed size,
whereas uint32 can be.
Furthermore, a get_hash() member is added, which performs a serialization and a sha256.
2019-08-05 08:38:16 +03:00
This class is used for deterministic serialization and hashing, for consensus critical
objects such as the block header.
Make sure to use the Streamable class as a parent class when using the streamable decorator,
as it will allow linters to recognize the methods that are added by the decorator. Also,
use the @dataclass(frozen=True) decorator as well, for linters to recognize constructor
arguments.
2019-07-22 16:01:23 +03:00
"""
cls1 = strictdataclass(cls)
t = type(cls.__name__, (cls1, Streamable), {})
parse_functions = []
try:
fields = cls1.__annotations__ # pylint: disable=no-member
except Exception:
fields = {}
for _, f_type in fields.items():
parse_functions.append(cls.function_to_parse_one_item(f_type))
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = parse_functions
return t
def parse_bool(f: BinaryIO) -> bool:
bool_byte = f.read(1)
assert bool_byte is not None and len(bool_byte) == 1 # Checks for EOF
if bool_byte == bytes([0]):
return False
elif bool_byte == bytes([1]):
return True
else:
raise ValueError("Bool byte must be 0 or 1")
def parse_optional(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> Optional[Any]:
is_present_bytes = f.read(1)
assert is_present_bytes is not None and len(is_present_bytes) == 1 # Checks for EOF
if is_present_bytes == bytes([0]):
return None
elif is_present_bytes == bytes([1]):
return parse_inner_type_f(f)
else:
raise ValueError("Optional must be 0 or 1")
def parse_bytes(f: BinaryIO) -> bytes:
list_size_bytes = f.read(4)
assert list_size_bytes is not None and len(list_size_bytes) == 4 # Checks for EOF
list_size: uint32 = uint32(int.from_bytes(list_size_bytes, "big"))
bytes_read = f.read(list_size)
assert bytes_read is not None and len(bytes_read) == list_size
return bytes_read
def parse_list(f: BinaryIO, parse_inner_type_f: Callable[[BinaryIO], Any]) -> List[Any]:
full_list: List = []
# wjb assert inner_type != get_args(List)[0]
list_size_bytes = f.read(4)
assert list_size_bytes is not None and len(list_size_bytes) == 4 # Checks for EOF
list_size = uint32(int.from_bytes(list_size_bytes, "big"))
for list_index in range(list_size):
full_list.append(parse_inner_type_f(f))
return full_list
def parse_tuple(f: BinaryIO, list_parse_inner_type_f: List[Callable[[BinaryIO], Any]]) -> Tuple[Any, ...]:
full_list = []
for parse_f in list_parse_inner_type_f:
full_list.append(parse_f(f))
return tuple(full_list)
def parse_size_hints(f: BinaryIO, f_type: Type, bytes_to_read: int) -> Any:
bytes_read = f.read(bytes_to_read)
assert bytes_read is not None and len(bytes_read) == bytes_to_read
return f_type.from_bytes(bytes_read)
def parse_str(f: BinaryIO) -> str:
str_size_bytes = f.read(4)
assert str_size_bytes is not None and len(str_size_bytes) == 4 # Checks for EOF
str_size: uint32 = uint32(int.from_bytes(str_size_bytes, "big"))
str_read_bytes = f.read(str_size)
assert str_read_bytes is not None and len(str_read_bytes) == str_size # Checks for EOF
return bytes.decode(str_read_bytes, "utf-8")
class Streamable:
@classmethod
def function_to_parse_one_item(cls: Type[cls.__name__], f_type: Type): # type: ignore
"""
This function returns a function taking one argument `f: BinaryIO` that parses
and returns a value of the given type.
"""
inner_type: Type
if f_type is bool:
return parse_bool
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_optional(f, parse_inner_type_f)
if hasattr(f_type, "parse"):
return f_type.parse
2020-02-06 18:12:33 +03:00
if f_type == bytes:
return parse_bytes
if is_type_List(f_type):
inner_type = get_args(f_type)[0]
parse_inner_type_f = cls.function_to_parse_one_item(inner_type)
return lambda f: parse_list(f, parse_inner_type_f)
if is_type_Tuple(f_type):
inner_types = get_args(f_type)
list_parse_inner_type_f = [cls.function_to_parse_one_item(_) for _ in inner_types]
return lambda f: parse_tuple(f, list_parse_inner_type_f)
if hasattr(f_type, "from_bytes") and f_type.__name__ in size_hints:
bytes_to_read = size_hints[f_type.__name__]
return lambda f: parse_size_hints(f, f_type, bytes_to_read)
if f_type is str:
return parse_str
raise NotImplementedError(f"Type {f_type} does not have parse")
@classmethod
def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__: # type: ignore
# Create the object without calling __init__() to avoid unnecessary post-init checks in strictdataclass
obj: Streamable = object.__new__(cls)
fields: Iterator[str] = iter(getattr(cls, "__annotations__", {}))
values: Iterator = (parse_f(f) for parse_f in PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls])
for field, value in zip(fields, values):
object.__setattr__(obj, field, value)
# Use -1 as a sentinel value as it's not currently serializable
if next(fields, -1) != -1:
raise ValueError("Failed to parse incomplete Streamable object")
if next(values, -1) != -1:
raise ValueError("Failed to parse unknown data in Streamable object")
return obj
def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
inner_type: Type
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
if item is None:
f.write(bytes([0]))
2019-07-25 13:01:56 +03:00
else:
f.write(bytes([1]))
self.stream_one_item(inner_type, item, f)
2020-02-06 18:12:33 +03:00
elif f_type == bytes:
f.write(uint32(len(item)).to_bytes(4, "big"))
f.write(item)
elif hasattr(f_type, "stream"):
item.stream(f)
2019-11-02 00:53:48 +03:00
elif hasattr(f_type, "__bytes__"):
f.write(bytes(item))
elif is_type_List(f_type):
assert is_type_List(type(item))
f.write(uint32(len(item)).to_bytes(4, "big"))
inner_type = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
for element in item:
self.stream_one_item(inner_type, element, f)
elif is_type_Tuple(f_type):
inner_types = get_args(f_type)
assert len(item) == len(inner_types)
for i in range(len(item)):
self.stream_one_item(inner_types[i], item[i], f)
elif f_type is str:
str_bytes = item.encode("utf-8")
f.write(uint32(len(str_bytes)).to_bytes(4, "big"))
f.write(str_bytes)
elif f_type is bool:
f.write(int(item).to_bytes(1, "big"))
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")
2019-07-31 19:48:30 +03:00
def stream(self, f: BinaryIO) -> None:
try:
fields = self.__annotations__ # pylint: disable=no-member
except Exception:
fields = {}
for f_name, f_type in fields.items():
self.stream_one_item(f_type, getattr(self, f_name), f)
def get_hash(self) -> bytes32:
2020-02-11 03:07:20 +03:00
return bytes32(std_hash(bytes(self)))
@classmethod
def from_bytes(cls: Any, blob: bytes) -> Any:
f = io.BytesIO(blob)
parsed = cls.parse(f)
assert f.read() == b""
return parsed
2019-11-02 00:53:48 +03:00
def __bytes__(self: Any) -> bytes:
f = io.BytesIO()
self.stream(f)
return bytes(f.getvalue())
def __str__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))
def __repr__(self: Any) -> str:
return pp.pformat(recurse_jsonify(dataclasses.asdict(self)))
def to_json_dict(self) -> Dict:
return recurse_jsonify(dataclasses.asdict(self))
@classmethod
def from_json_dict(cls: Any, json_dict: Dict) -> Any:
return dataclass_from_dict(cls, json_dict)