protocols: Give rust wallet types an alias, don't use __all__ (#15186)

* protocols: Give rust wallet types an alias, don't use `__all__`

* Refactor `types_in_module` in `test_network_protocol_test.py`
This commit is contained in:
dustinface 2023-05-05 01:16:33 +07:00 committed by GitHub
parent 5b93bd320c
commit 9a81cb36b6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 15 deletions

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from chia_rs import CoinState, RespondToPhUpdates import chia_rs
from chia.full_node.fee_estimate import FeeEstimateGroup from chia.full_node.fee_estimate import FeeEstimateGroup
from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.coin import Coin
@ -20,7 +20,8 @@ Note: When changing this file, also change protocol_message_types.py, and the pr
""" """
__all__ = ["CoinState", "RespondToPhUpdates"] CoinState = chia_rs.CoinState
RespondToPhUpdates = chia_rs.RespondToPhUpdates
@streamable @streamable

View File

@ -1,7 +1,9 @@
# flake8: noqa # flake8: noqa
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Set import ast
import inspect
from typing import Any, Set, cast
from chia.protocols import ( from chia.protocols import (
farmer_protocol, farmer_protocol,
@ -19,16 +21,16 @@ from chia.protocols import (
def types_in_module(mod: Any) -> Set[str]: def types_in_module(mod: Any) -> Set[str]:
ret: List[str] = [] parsed = ast.parse(inspect.getsource(mod))
mod_name = mod.__name__ types = set()
for sym in dir(mod): for line in parsed.body:
obj = getattr(mod, sym) if isinstance(line, ast.Assign):
if hasattr(obj, "__module__") and obj.__module__ == mod_name: name = cast(ast.Name, line.targets[0])
ret.append(sym) if inspect.isclass(getattr(mod, name.id)):
types.add(name.id)
if hasattr(mod, "__all__"): elif isinstance(line, ast.ClassDef):
ret += getattr(mod, "__all__") types.add(line.name)
return set(ret) return types
def test_missing_messages_state_machine() -> None: def test_missing_messages_state_machine() -> None:
@ -155,8 +157,6 @@ def test_missing_messages() -> None:
"PutFarmerPayload", "PutFarmerPayload",
"PutFarmerRequest", "PutFarmerRequest",
"PutFarmerResponse", "PutFarmerResponse",
"get_current_authentication_token",
"validate_authentication_token",
} }
timelord_msgs = { timelord_msgs = {