chia-blockchain/src/util/streamable.py

176 lines
6.2 KiB
Python
Raw Normal View History

# flake8: noqa
2019-08-05 08:38:16 +03:00
from __future__ import annotations
2019-11-18 07:49:39 +03:00
import dataclasses
2019-11-18 07:49:39 +03:00
import io
import pprint
from hashlib import sha256
2019-11-18 07:49:39 +03:00
from typing import Any, BinaryIO, List, Type, get_type_hints
2019-11-18 07:50:31 +03:00
from blspy import (
ChainCode,
ExtendedPrivateKey,
ExtendedPublicKey,
InsecureSignature,
PrependSignature,
PrivateKey,
PublicKey,
Signature,
)
2019-11-18 07:49:39 +03:00
from src.types.sized_bytes import bytes32
2019-08-05 08:38:16 +03:00
from src.util.ints import uint32
2019-11-18 07:50:31 +03:00
from src.util.type_checking import (
is_type_List,
is_type_SpecificOptional,
strictdataclass,
)
2019-07-22 16:01:23 +03:00
# Pretty-printer shared by Streamable.__str__ and Streamable.__repr__.
pp = pprint.PrettyPrinter(indent=1, width=120, compact=True)

# TODO: Remove hack, this allows streaming these objects from binary
# Class name -> number of bytes handed to that class's from_bytes() when a field
# of this type is parsed from a stream.
size_hints = dict(
    PrivateKey=PrivateKey.PRIVATE_KEY_SIZE,
    PublicKey=PublicKey.PUBLIC_KEY_SIZE,
    Signature=Signature.SIGNATURE_SIZE,
    InsecureSignature=InsecureSignature.SIGNATURE_SIZE,
    PrependSignature=PrependSignature.SIGNATURE_SIZE,
    ExtendedPublicKey=ExtendedPublicKey.EXTENDED_PUBLIC_KEY_SIZE,
    ExtendedPrivateKey=ExtendedPrivateKey.EXTENDED_PRIVATE_KEY_SIZE,
    ChainCode=ChainCode.CHAIN_CODE_KEY_SIZE,
)

# blspy types whose values are replaced by str(value) before pretty-printing.
unhashable_types = [PrivateKey, PublicKey, Signature, PrependSignature,
                    InsecureSignature, ExtendedPublicKey, ExtendedPrivateKey, ChainCode]
def streamable(cls: Any):
    """
    Class decorator that makes ``cls`` a strictly-typed, serializable dataclass.

    It applies the ``strictdataclass`` decorator, which checks all types at
    construction, and mixes in ``Streamable``, which adds ``parse``, ``from_bytes``,
    ``stream`` and ``__bytes__`` plus ``get_hash()`` (serialize, then sha256).

    Serialization format:
    - Each field is serialized in order.
    - For Lists, there is a 4-byte big-endian prefix holding the list length.
    - For Optionals, there is a one-byte prefix: 1 iff the object is present, 0 iff not.
    - Strings are UTF-8 encoded and prefixed with a 4-byte big-endian length.
    - Any other constituent must provide parse/stream, or from_bytes/__bytes__ with a
      fixed size registered in ``size_hints``. For example, ``int`` cannot be a
      constituent since it has no fixed size, whereas ``uint32`` can be.

    This is used for deterministic serialization and hashing of consensus-critical
    objects such as the block header.

    Make sure to also use ``Streamable`` as a parent class when applying this
    decorator, so linters recognize the added methods. Also apply
    ``@dataclass(frozen=True)`` so linters recognize the constructor arguments.
    """
    cls1 = strictdataclass(cls)
    # A bare three-argument type() call takes __module__ from *this* module and
    # drops the original qualname/docstring. That makes the generated class claim
    # to live in src.util.streamable, which breaks pickling (lookup is by
    # module + qualname) and introspection. Carry the original identity over.
    namespace = {
        "__module__": cls.__module__,
        "__qualname__": cls.__qualname__,
        "__doc__": cls.__doc__,
    }
    return type(cls.__name__, (cls1, Streamable), namespace)
class Streamable:
    """
    Mixin providing deterministic binary (de)serialization and hashing.

    Combined with a dataclass by the ``streamable`` decorator. Fields are processed
    in the order returned by ``get_type_hints(cls)``:

    - ``List[T]``: 4-byte big-endian element count, then each element.
    - ``Optional[T]``: one byte, 0 (absent) or 1 (present), then the value if present.
    - ``str``: 4-byte big-endian length of the UTF-8 encoding, then the encoded bytes.
    - Types with ``parse`` / ``stream`` delegate to those methods.
    - Types with ``from_bytes`` are parsed by reading the fixed byte count registered
      for their class name in ``size_hints``; types with ``__bytes__`` are streamed
      via ``bytes(item)``.

    NOTE: do not declare annotated class attributes on this mixin. get_type_hints()
    walks the MRO, so they would become serialized fields of every subclass.
    """

    @staticmethod
    def _read_exact(f: BinaryIO, size: int) -> bytes:
        """
        Read exactly ``size`` bytes from ``f``.

        Raises:
            ValueError: if the stream ends before ``size`` bytes are available. This
                rejects truncated input instead of decoding a short read.
        """
        data = f.read(size)
        if data is None or len(data) != size:
            got = 0 if data is None else len(data)
            raise ValueError(f"Unexpected end of stream: wanted {size} bytes, got {got}")
        return data

    @classmethod
    def parse_one_item(cls: Type[cls.__name__], f_type: Type, f: BinaryIO):  # type: ignore
        """
        Parse a single value of type ``f_type`` from the stream ``f``.

        Raises:
            ValueError: on truncated input, or an Optional flag byte that is not 0/1.
            RuntimeError: if ``f_type`` has no supported parse method.
        """
        inner_type: Type
        if is_type_List(f_type):
            inner_type = f_type.__args__[0]
            full_list: List[inner_type] = []  # type: ignore
            # A bare ``List`` (no element type) cannot be parsed.
            assert inner_type != List.__args__[0]  # type: ignore
            list_size: uint32 = uint32(int.from_bytes(cls._read_exact(f, 4), "big"))
            for _ in range(list_size):
                full_list.append(cls.parse_one_item(inner_type, f))  # type: ignore
            return full_list
        if is_type_SpecificOptional(f_type):
            inner_type = f_type.__args__[0]
            flag = cls._read_exact(f, 1)
            if flag == bytes([0]):
                return None
            if flag == bytes([1]):
                return cls.parse_one_item(inner_type, f)  # type: ignore
            # Only 0/1 are canonical; accepting other bytes would let distinct
            # encodings decode to the same object.
            raise ValueError(f"Optional presence flag must be 0 or 1, got {flag[0]}")
        if hasattr(f_type, "parse"):
            return f_type.parse(f)
        # .get(): types that merely have from_bytes (e.g. int) but no registered
        # size fall through to the RuntimeError below instead of raising KeyError.
        if hasattr(f_type, "from_bytes") and size_hints.get(f_type.__name__):
            return f_type.from_bytes(cls._read_exact(f, size_hints[f_type.__name__]))
        if f_type is str:
            str_size: uint32 = uint32(int.from_bytes(cls._read_exact(f, 4), "big"))
            return bytes.decode(cls._read_exact(f, str_size), "utf-8")
        raise RuntimeError(f"Type {f_type} does not have parse")

    @classmethod
    def parse(cls: Type[cls.__name__], f: BinaryIO) -> cls.__name__:  # type: ignore
        """Parse every field of ``cls`` from ``f`` in order and construct an instance."""
        values = [cls.parse_one_item(f_type, f) for f_type in get_type_hints(cls).values()]  # type: ignore
        return cls(*values)

    def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
        """
        Write ``item``, declared as ``f_type``, to ``f`` in the streamable format.

        Raises:
            NotImplementedError: if ``f_type`` has no supported serialization.
        """
        inner_type: Type
        if is_type_List(f_type):
            assert is_type_List(type(item))
            f.write(uint32(len(item)).to_bytes(4, "big"))
            inner_type = f_type.__args__[0]
            assert inner_type != List.__args__[0]  # type: ignore
            for element in item:
                self.stream_one_item(inner_type, element, f)
        elif is_type_SpecificOptional(f_type):
            inner_type = f_type.__args__[0]
            if item is None:
                f.write(bytes([0]))
            else:
                f.write(bytes([1]))
                self.stream_one_item(inner_type, item, f)
        elif hasattr(f_type, "stream"):
            item.stream(f)
        elif hasattr(f_type, "__bytes__"):
            f.write(bytes(item))
        elif f_type is str:
            # The prefix must be the *byte* length of the encoding, because
            # parse_one_item reads that many bytes. Using len(item) (a character
            # count) corrupts any non-ASCII string.
            str_bytes = item.encode("utf-8")
            f.write(uint32(len(str_bytes)).to_bytes(4, "big"))
            f.write(str_bytes)
        else:
            raise NotImplementedError(f"can't stream {item}, {f_type}")

    def stream(self, f: BinaryIO) -> None:
        """Write every field of this object to ``f`` in declaration order."""
        # Use the class (not the instance) so the field list is exactly the one
        # parse() reads: get_type_hints on a class merges annotations across the
        # MRO and resolves them in each defining module's namespace.
        for f_name, f_type in get_type_hints(type(self)).items():  # type: ignore
            self.stream_one_item(f_type, getattr(self, f_name), f)

    def get_hash(self) -> bytes32:
        """Return the sha256 digest of this object's serialized bytes."""
        return bytes32(sha256(bytes(self)).digest())

    @classmethod
    def from_bytes(cls: Any, blob: bytes) -> Any:
        """Parse an instance of ``cls`` from an in-memory byte string."""
        f = io.BytesIO(blob)
        return cls.parse(f)

    def __bytes__(self: Any) -> bytes:
        """Serialize this object into a byte string."""
        f = io.BytesIO()
        self.stream(f)
        return f.getvalue()

    def __str__(self: Any) -> str:
        return pp.pformat(self.recurse_str(dataclasses.asdict(self)))

    def __repr__(self: Any) -> str:
        return pp.pformat(self.recurse_str(dataclasses.asdict(self)))

    def recurse_str(self, d):
        """
        Replace values whose type is in ``unhashable_types`` with ``str(value)``.

        Mutates ``d`` in place, descending into nested dicts and lists (as produced
        by ``dataclasses.asdict`` for nested/List fields), and returns ``d``.
        """
        for key, value in d.items():
            # Reassigning existing keys does not change the dict size, so this is
            # safe during iteration.
            d[key] = self._str_unhashable(value)
        return d

    def _str_unhashable(self, value):
        """Return ``value`` with any unhashable_types instances converted via str()."""
        if type(value) in unhashable_types:
            return str(value)
        if isinstance(value, dict):
            return self.recurse_str(value)
        if isinstance(value, list):
            value[:] = [self._str_unhashable(element) for element in value]
        return value