From fdb993cb1049dacfa9db1edefd66a3d4dbba73c4 Mon Sep 17 00:00:00 2001 From: dustinface <35775977+xdustinface@users.noreply.github.com> Date: Fri, 25 Feb 2022 17:07:24 +0100 Subject: [PATCH] streamable: Cache stream functions (#10419) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Apply the same pattern as we have for deserialization to serialization. This avoids all those recursive runtime lookups for "how to stream this object" which brings a nice speedup: ``` compare: benchmark mode | µs/iteration old | µs/iteration new | diff % to_bytes | 447.57 | 193.56 | -56.75 compare: full_block mode | µs/iteration old | µs/iteration new | diff % to_bytes | 110.32 | 61.09 | -44.62 ``` --- chia/util/streamable.py | 85 ++++++++++++++++++++++++++++------------- 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index baedc9c16461..73584bac7ed1 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -124,6 +124,7 @@ def recurse_jsonify(d): return d +STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS = {} PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {} FIELDS_FOR_STREAMABLE_CLASS = {} @@ -179,6 +180,7 @@ def streamable(cls: Any): cls1 = strictdataclass(cls) t = type(cls.__name__, (cls1, Streamable), {}) + stream_functions = [] parse_functions = [] try: hints = get_type_hints(t) @@ -189,8 +191,10 @@ def streamable(cls: Any): FIELDS_FOR_STREAMABLE_CLASS[t] = fields for _, f_type in fields.items(): + stream_functions.append(cls.function_to_stream_one_item(f_type)) parse_functions.append(cls.function_to_parse_one_item(f_type)) + STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = stream_functions PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = parse_functions return t @@ -263,6 +267,37 @@ def parse_str(f: BinaryIO) -> str: return bytes.decode(str_read_bytes, "utf-8") +def stream_optional(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None: + if item is None: + f.write(bytes([0])) + else: + f.write(bytes([1])) + stream_inner_type_func(item, f) + + +def stream_bytes(item: Any, f: BinaryIO) -> None: + write_uint32(f, uint32(len(item))) + f.write(item) + + +def stream_list(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None: + write_uint32(f, uint32(len(item))) + for element in item: + stream_inner_type_func(element, f) + + +def stream_tuple(stream_inner_type_funcs: List[Callable[[Any, BinaryIO], None]], item: Any, f: BinaryIO) -> None: + assert len(stream_inner_type_funcs) == len(item) + for i in range(len(item)): + stream_inner_type_funcs[i](item[i], f) + + +def stream_str(item: Any, f: BinaryIO) -> None: + str_bytes = item.encode("utf-8") + write_uint32(f, uint32(len(str_bytes))) + f.write(str_bytes) + + class Streamable: @classmethod def function_to_parse_one_item(cls, f_type: Type) -> Callable[[BinaryIO], Any]: @@ -312,51 +347,47 @@ class Streamable: raise ValueError("Failed to parse unknown data in Streamable object") return obj - def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None: + @classmethod + def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO], Any]: inner_type: Type if is_type_SpecificOptional(f_type): inner_type = get_args(f_type)[0] - if item is None: - f.write(bytes([0])) - else: - f.write(bytes([1])) - self.stream_one_item(inner_type, item, f) + stream_inner_type_func = cls.function_to_stream_one_item(inner_type) + return lambda item, f: stream_optional(stream_inner_type_func, item, f) elif f_type == bytes: - write_uint32(f, uint32(len(item))) - f.write(item) + return stream_bytes elif hasattr(f_type, "stream"): - item.stream(f) + return lambda item, f: item.stream(f) elif hasattr(f_type, "__bytes__"): - f.write(bytes(item)) + return lambda item, f: f.write(bytes(item)) elif is_type_List(f_type): - assert is_type_List(type(item)) - write_uint32(f, uint32(len(item))) inner_type = get_args(f_type)[0] - # wjb assert inner_type != get_args(List)[0] # type: ignore - for element in item: - self.stream_one_item(inner_type, element, f) + stream_inner_type_func = cls.function_to_stream_one_item(inner_type) + return lambda item, f: stream_list(stream_inner_type_func, item, f) elif is_type_Tuple(f_type): inner_types = get_args(f_type) - assert len(item) == len(inner_types) - for i in range(len(item)): - self.stream_one_item(inner_types[i], item[i], f) - + stream_inner_type_funcs = [] + for i in range(len(inner_types)): + stream_inner_type_funcs.append(cls.function_to_stream_one_item(inner_types[i])) + return lambda item, f: stream_tuple(stream_inner_type_funcs, item, f) elif f_type is str: - str_bytes = item.encode("utf-8") - write_uint32(f, uint32(len(str_bytes))) - f.write(str_bytes) + return stream_str elif f_type is bool: - f.write(int(item).to_bytes(1, "big")) + return lambda item, f: f.write(int(item).to_bytes(1, "big")) else: - raise NotImplementedError(f"can't stream {item}, {f_type}") + raise NotImplementedError(f"can't stream {f_type}") def stream(self, f: BinaryIO) -> None: + self_type = type(self) try: - fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)] + fields = FIELDS_FOR_STREAMABLE_CLASS[self_type] + functions = STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[self_type] except Exception: fields = {} - for f_name, f_type in fields.items(): - self.stream_one_item(f_type, getattr(self, f_name), f) + functions = [] + + for field, stream_func in zip(fields, functions): + stream_func(getattr(self, field), f) def get_hash(self) -> bytes32: return bytes32(std_hash(bytes(self)))