Add optional outbound parameter to get_connections (#13879)

This commit is contained in:
dustinface 2022-11-16 00:40:16 +01:00 committed by GitHub
parent f30f0b3512
commit 8e2e51a8c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 15 additions and 16 deletions

View File

@ -184,7 +184,7 @@ class FullNodeDiscovery:
def _num_needed_peers(self) -> int:
target = self.target_outbound_count
outgoing = len(self.server.get_full_node_outgoing_connections())
outgoing = len(self.server.get_connections(NodeType.FULL_NODE, outbound=True))
return max(0, target - outgoing)
"""
@ -320,7 +320,7 @@ class FullNodeDiscovery:
# Only connect out to one peer per network group (/16 for IPv4).
groups = set()
full_node_connected = self.server.get_full_node_outgoing_connections()
full_node_connected = self.server.get_connections(NodeType.FULL_NODE, outbound=True)
connected = [c.get_peer_info() for c in full_node_connected]
connected = [c for c in connected if c is not None]
for conn in full_node_connected:

View File

@ -714,18 +714,14 @@ class ChiaServer:
for message in messages:
await connection.send_message(message)
def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
result = []
connections = self.get_connections(NodeType.FULL_NODE)
for connection in connections:
if connection.is_outbound:
result.append(connection)
return result
def get_connections(self, node_type: Optional[NodeType] = None) -> List[WSChiaConnection]:
def get_connections(
self, node_type: Optional[NodeType] = None, *, outbound: Optional[bool] = None
) -> List[WSChiaConnection]:
result = []
for _, connection in self.all_connections.items():
if node_type is None or connection.connection_type == node_type:
node_type_match = node_type is None or connection.connection_type == node_type
outbound_match = outbound is None or connection.is_outbound == outbound
if node_type_match and outbound_match:
result.append(connection)
return result
@ -802,7 +798,7 @@ class ChiaServer:
def accept_inbound_connections(self, node_type: NodeType) -> bool:
if not self._local_type == NodeType.FULL_NODE:
return True
inbound_count = len([conn for conn in self.get_connections(node_type) if not conn.is_outbound])
inbound_count = len(self.get_connections(node_type, outbound=False))
if node_type == NodeType.FULL_NODE:
return inbound_count < self.config["target_peer_count"] - self.config["target_outbound_peer_count"]
if node_type == NodeType.WALLET:

View File

@ -8,6 +8,7 @@ import pytest_asyncio
from chia.cmds.units import units
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.full_node.full_node import FullNode
from chia.server.outbound_message import NodeType
from chia.server.server import ChiaServer
from chia.server.start_service import Service
from chia.simulator.block_tools import BlockTools, create_block_tools_async
@ -84,7 +85,7 @@ class TestSimulation:
# Connect node 1 to node 2
connected: bool = await server1.start_client(PeerInfo(self_hostname, node2_port))
assert connected, f"node1 was unable to connect to node2 on port {node2_port}"
assert len(server1.get_full_node_outgoing_connections()) >= 1
assert len(server1.get_connections(NodeType.FULL_NODE, outbound=True)) >= 1
# Connect node3 to node1 and node2 - checks come later
node3: Service[FullNode] = extra_node
@ -93,7 +94,7 @@ class TestSimulation:
assert connected, f"server3 was unable to connect to node1 on port {node1_port}"
connected = await server3.start_client(PeerInfo(self_hostname, node2_port))
assert connected, f"server3 was unable to connect to node2 on port {node2_port}"
assert len(server3.get_full_node_outgoing_connections()) >= 2
assert len(server3.get_connections(NodeType.FULL_NODE, outbound=True)) >= 2
# wait up to 10 mins for node2 to sync the chain to height 7
await time_out_assert(600, node2.full_node.blockchain.get_peak_height, 7)

View File

@ -71,7 +71,9 @@ class FakeServer:
async def get_peer_info(self) -> Optional[PeerInfo]:
return None
def get_full_node_outgoing_connections(self) -> List[WSChiaConnection]:
def get_connections(
self, node_type: Optional[NodeType] = None, *, outbound: Optional[bool] = False
) -> List[WSChiaConnection]:
return []
def is_duplicate_or_self_connection(self, target_node: PeerInfo) -> bool: