streamable: Cache stream functions (#10419)

Apply the same pattern as we have for deserialization to serialization.

This avoids all those recursive runtime lookups for "how to stream this
object" which brings a nice speedup:

```
compare: benchmark
mode         | µs/iteration old | µs/iteration new | diff %
to_bytes     | 447.57           | 193.56           | -56.75

compare: full_block
mode         | µs/iteration old | µs/iteration new | diff %
to_bytes     | 110.32           | 61.09            | -44.62
```
This commit is contained in:
dustinface 2022-02-25 17:07:24 +01:00 committed by GitHub
parent 874cc23c71
commit fdb993cb10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -124,6 +124,7 @@ def recurse_jsonify(d):
return d
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS = {}
FIELDS_FOR_STREAMABLE_CLASS = {}
@ -179,6 +180,7 @@ def streamable(cls: Any):
cls1 = strictdataclass(cls)
t = type(cls.__name__, (cls1, Streamable), {})
stream_functions = []
parse_functions = []
try:
hints = get_type_hints(t)
@ -189,8 +191,10 @@ def streamable(cls: Any):
FIELDS_FOR_STREAMABLE_CLASS[t] = fields
for _, f_type in fields.items():
stream_functions.append(cls.function_to_stream_one_item(f_type))
parse_functions.append(cls.function_to_parse_one_item(f_type))
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = stream_functions
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[t] = parse_functions
return t
@ -263,6 +267,37 @@ def parse_str(f: BinaryIO) -> str:
return bytes.decode(str_read_bytes, "utf-8")
def stream_optional(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
if item is None:
f.write(bytes([0]))
else:
f.write(bytes([1]))
stream_inner_type_func(item, f)
def stream_bytes(item: Any, f: BinaryIO) -> None:
write_uint32(f, uint32(len(item)))
f.write(item)
def stream_list(stream_inner_type_func: Callable[[Any, BinaryIO], None], item: Any, f: BinaryIO) -> None:
write_uint32(f, uint32(len(item)))
for element in item:
stream_inner_type_func(element, f)
def stream_tuple(stream_inner_type_funcs: List[Callable[[Any, BinaryIO], None]], item: Any, f: BinaryIO) -> None:
assert len(stream_inner_type_funcs) == len(item)
for i in range(len(item)):
stream_inner_type_funcs[i](item[i], f)
def stream_str(item: Any, f: BinaryIO) -> None:
str_bytes = item.encode("utf-8")
write_uint32(f, uint32(len(str_bytes)))
f.write(str_bytes)
class Streamable:
@classmethod
def function_to_parse_one_item(cls, f_type: Type) -> Callable[[BinaryIO], Any]:
@ -312,51 +347,47 @@ class Streamable:
raise ValueError("Failed to parse unknown data in Streamable object")
return obj
def stream_one_item(self, f_type: Type, item, f: BinaryIO) -> None:
@classmethod
def function_to_stream_one_item(cls, f_type: Type) -> Callable[[Any, BinaryIO], Any]:
inner_type: Type
if is_type_SpecificOptional(f_type):
inner_type = get_args(f_type)[0]
if item is None:
f.write(bytes([0]))
else:
f.write(bytes([1]))
self.stream_one_item(inner_type, item, f)
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
return lambda item, f: stream_optional(stream_inner_type_func, item, f)
elif f_type == bytes:
write_uint32(f, uint32(len(item)))
f.write(item)
return stream_bytes
elif hasattr(f_type, "stream"):
item.stream(f)
return lambda item, f: item.stream(f)
elif hasattr(f_type, "__bytes__"):
f.write(bytes(item))
return lambda item, f: f.write(bytes(item))
elif is_type_List(f_type):
assert is_type_List(type(item))
write_uint32(f, uint32(len(item)))
inner_type = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
for element in item:
self.stream_one_item(inner_type, element, f)
stream_inner_type_func = cls.function_to_stream_one_item(inner_type)
return lambda item, f: stream_list(stream_inner_type_func, item, f)
elif is_type_Tuple(f_type):
inner_types = get_args(f_type)
assert len(item) == len(inner_types)
for i in range(len(item)):
self.stream_one_item(inner_types[i], item[i], f)
stream_inner_type_funcs = []
for i in range(len(inner_types)):
stream_inner_type_funcs.append(cls.function_to_stream_one_item(inner_types[i]))
return lambda item, f: stream_tuple(stream_inner_type_funcs, item, f)
elif f_type is str:
str_bytes = item.encode("utf-8")
write_uint32(f, uint32(len(str_bytes)))
f.write(str_bytes)
return stream_str
elif f_type is bool:
f.write(int(item).to_bytes(1, "big"))
return lambda item, f: f.write(int(item).to_bytes(1, "big"))
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")
raise NotImplementedError(f"can't stream {f_type}")
def stream(self, f: BinaryIO) -> None:
self_type = type(self)
try:
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
fields = FIELDS_FOR_STREAMABLE_CLASS[self_type]
functions = STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[self_type]
except Exception:
fields = {}
for f_name, f_type in fields.items():
self.stream_one_item(f_type, getattr(self, f_name), f)
functions = []
for field, stream_func in zip(fields, functions):
stream_func(getattr(self, field), f)
def get_hash(self) -> bytes32:
return bytes32(std_hash(bytes(self)))