Merge commit '618f93b4c42b176659cc74c02a4dd711adc62052' into checkpoint/main_from_release_1.6.2_618f93b4c42b176659cc74c02a4dd711adc62052

This commit is contained in:
Amine Khaldi 2022-12-14 19:24:34 +01:00
commit 7defb1ebfc
No known key found for this signature in database
GPG Key ID: B1C074FFC904E2D9
4 changed files with 140 additions and 5 deletions

View File

@ -0,0 +1,26 @@
from __future__ import annotations
from typing import Iterable, List, Set, Tuple
from chia.protocols.shared_protocol import Capability
from chia.util.ints import uint16
_capability_values = {int(capability) for capability in Capability}
def known_active_capabilities(values: Iterable[Tuple[uint16, str]]) -> List[Capability]:
# NOTE: order is not guaranteed
# TODO: what if there's a claim for both supporting and not?
# presently it considers it supported
filtered: Set[uint16] = set()
for value, state in values:
if state != "1":
continue
if value not in _capability_values:
continue
filtered.add(value)
# TODO: consider changing all uses to sets instead of lists
return [Capability(value) for value in filtered]

View File

@ -19,6 +19,7 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.protocol_state_machine import message_requires_reply, message_response_ok
from chia.protocols.protocol_timing import API_EXCEPTION_BAN_SECONDS, INTERNAL_PROTOCOL_ERROR_BAN_SECONDS
from chia.protocols.shared_protocol import Capability, Handshake
from chia.server.capabilities import known_active_capabilities
from chia.server.outbound_message import Message, NodeType, make_msg
from chia.server.rate_limits import RateLimiter
from chia.types.blockchain_format.sized_bytes import bytes32
@ -146,7 +147,7 @@ class WSChiaConnection:
local_type=local_type,
local_port=server_port,
local_capabilities_for_handshake=local_capabilities_for_handshake,
local_capabilities=[Capability(x[0]) for x in local_capabilities_for_handshake if x[1] == "1"],
local_capabilities=known_active_capabilities(local_capabilities_for_handshake),
peer_host=peer_host,
peer_port=peername[1],
peer_node_id=peer_id,
@ -206,7 +207,7 @@ class WSChiaConnection:
self.peer_server_port = inbound_handshake.server_port
self.connection_type = NodeType(inbound_handshake.node_type)
# "1" means capability is enabled
self.peer_capabilities = [Capability(x[0]) for x in inbound_handshake.capabilities if x[1] == "1"]
self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities)
else:
try:
message = await self._read_one_message()
@ -232,7 +233,7 @@ class WSChiaConnection:
self.peer_server_port = inbound_handshake.server_port
self.connection_type = NodeType(inbound_handshake.node_type)
# "1" means capability is enabled
self.peer_capabilities = [Capability(x[0]) for x in inbound_handshake.capabilities if x[1] == "1"]
self.peer_capabilities = known_active_capabilities(inbound_handshake.capabilities)
self.outbound_task = asyncio.create_task(self.outbound_handler())
self.inbound_task = asyncio.create_task(self.inbound_handler())

View File

@ -5,7 +5,7 @@ import dataclasses
import random
import time
from secrets import token_bytes
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
import pytest
from blspy import AugSchemeMPL, G2Element, PrivateKey
@ -20,10 +20,12 @@ from chia.protocols import full_node_protocol as fnp
from chia.protocols import timelord_protocol, wallet_protocol
from chia.protocols.full_node_protocol import RespondTransaction
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability, capabilities
from chia.protocols.wallet_protocol import SendTransaction, TransactionAck
from chia.server.address_manager import AddressManager
from chia.server.outbound_message import Message, NodeType
from chia.simulator.block_tools import get_signage_point, test_constants
from chia.server.server import ChiaServer
from chia.simulator.block_tools import BlockTools, get_signage_point, test_constants
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.simulator.time_out_assert import time_out_assert, time_out_assert_custom_interval, time_out_messages
from chia.types.blockchain_format.classgroup import ClassgroupElement
@ -1951,3 +1953,41 @@ class TestFullNodeProtocol:
if block.challenge_chain_sp_proof is not None:
assert not block.challenge_chain_sp_proof.normalized_to_identity
assert not block.challenge_chain_ip_proof.normalized_to_identity
@pytest.mark.parametrize(
argnames=["custom_capabilities", "expect_success"],
argvalues=[
# standard
[capabilities, True],
# an additional enabled but unknown capability
[[*capabilities, (uint16(max(Capability) + 1), "1")], True],
# no capability, not even Chia mainnet
# TODO: shouldn't we fail without Capability.BASE?
[[], True],
# only an unknown capability
# TODO: shouldn't we fail without Capability.BASE?
[[(uint16(max(Capability) + 1), "1")], True],
],
)
@pytest.mark.asyncio
async def test_invalid_capability_can_connect(
self,
two_nodes: Tuple[FullNodeAPI, FullNodeAPI, ChiaServer, ChiaServer, BlockTools],
self_hostname: str,
custom_capabilities: List[Tuple[uint16, str]],
expect_success: bool,
) -> None:
# TODO: consider not testing this against both DB v1 and v2?
[
initiating_full_node_api,
listening_full_node_api,
initiating_server,
listening_server,
bt,
] = two_nodes
initiating_server._local_capabilities_for_handshake = custom_capabilities
connected = await initiating_server.start_client(PeerInfo(self_hostname, uint16(listening_server._port)), None)
assert connected == expect_success, custom_capabilities

View File

@ -0,0 +1,68 @@
from __future__ import annotations
from typing import List, Tuple
import pytest
from chia.protocols.shared_protocol import Capability
from chia.server.capabilities import known_active_capabilities
from chia.util.ints import uint16
@pytest.mark.parametrize(
argnames=["values", "expected"],
argvalues=[
# nothing, not even Chia mainnet...
[[], []],
# single valid
[[(uint16(Capability.BASE), "1")], [Capability.BASE]],
# all capabilities
[[(uint16(capability), "1") for capability in Capability], list(Capability)],
# all capabilities plus some invalid
[
[
*[(uint16(capability), "1") for capability in Capability],
*[(uint16(max(Capability) + n), "1") for n in range(1, 10)],
],
list(Capability),
],
# all possible values
[
[(uint16(n), "1") for n in range(2**16)],
list(Capability),
],
# all possible invalid values
[
[(uint16(n), "1") for n in set(range(2**16)) - set(Capability)],
[],
],
# single invalid
[[(uint16(max(Capability) + 1), "1")], []],
# a few invalid
[[(uint16(max(Capability) + n), "1") for n in range(1, 10)], []],
],
)
@pytest.mark.parametrize(
argnames="duplicated",
argvalues=[False, True],
ids=lambda value: "duplicated" if value else "as-is",
)
@pytest.mark.parametrize(
argnames="disabled",
argvalues=[False, True],
ids=lambda value: "disabled" if value else "enabled",
)
def test_known_active_capabilities_filter(
values: List[Tuple[uint16, str]],
expected: List[Capability],
duplicated: bool,
disabled: bool,
) -> None:
if duplicated:
values = values * 2
if disabled:
values = [(value, "0") for value, state in values]
expected = []
assert known_active_capabilities(values=values) == expected