Peer gossip. (#414)

* Initial commit.

* mypy

* Fix start service logic.

* Fix AddressManager tests.

* Experimentally increase timeout.

* Attempt to fix test.

* Flake8 typo

* Print traceback for CI build.

* Revert exception catching to gain more logs.

* Add serialization.

* Try to hack simulation test.

* Add debug info. Connect peers more often.

* Try to fix full node gossip.

* Remove introducer protocol from full node.

* Update request_peers test.

* Flake8 the test.

* Add sqlite.

* Address some review comments.

* Try to remove Peers() everywhere but introducer.

* flake8

* More linting.

* Limit other types of inbound connections.

* Initial commit new code.

* AddressManager tests working.

* flake8, mypy, some tests.

* More flake8.

* Tweak gossip protocol.

* Catch more connection failures.

* First attempt wallet gossip.

* Add wallet gossip test.

* Add back global signal handler.

* Resolve some SuperLinter issues.

* Fix some bugs.

* More linting.

* Lint.

* Attempt to improve network connectivity.

* Small fixes.

* Lint.

* Lint.

* Checkpoint address relay.

* Some fixes.

* Fix db path.

* Lint.

* Small fixes.

* Fix bugs.

* flake8, improve speed test simulation.

* py.black

* LGTM, other small fixes.

* Don't self connect.

* py.black

* Punish self connections even more.

* One more attempt to fix self connection.

* Don't connect to the introducer if I have all peers.

* Catch bugs.

* Fix typo.

* Type checking.

* Lint, catch timestamp bug.
Florin Chirica 2020-09-23 20:12:31 +03:00 committed by Gene Hoffman
parent b37e033d56
commit 5b17e1cf23
21 changed files with 2435 additions and 303 deletions

View File

@ -6,7 +6,6 @@ import time
import random
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Callable
import aiosqlite
from chiabip158 import PyBIP158
from chiapos import Verifier
@ -55,6 +54,7 @@ from src.util.hash import std_hash
from src.util.ints import uint32, uint64, uint128
from src.util.merkle_set import MerkleSet
from src.util.path import mkdir, path_from_root
from src.server.node_discovery import FullNodePeers
from src.types.peer_info import PeerInfo
OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
@ -63,6 +63,7 @@ OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
class FullNode:
block_store: BlockStore
full_node_store: FullNodeStore
full_node_peers: FullNodePeers
sync_store: SyncStore
coin_store: CoinStore
mempool_manager: MempoolManager
@ -108,6 +109,19 @@ class FullNode:
self.full_node_store = await FullNodeStore.create(self.connection)
self.sync_store = await SyncStore.create()
self.coin_store = await CoinStore.create(self.connection)
self.full_node_peers = FullNodePeers(
self.server,
self.root_path,
self.global_connections,
self.config["target_peer_count"]
- self.config["target_outbound_peer_count"],
self.config["target_outbound_peer_count"],
self.config["peer_db_path"],
self.config["introducer_peer"],
self.config["peer_connect_interval"],
self.log,
)
await self.full_node_peers.start()
self.log.info("Initializing blockchain from disk")
self.blockchain = await Blockchain.create(
@ -288,6 +302,31 @@ class FullNode:
async for msg in self._send_tips_to_farmers(Delivery.RESPOND):
yield msg
@api_request
async def request_peers_with_peer_info(
self,
request: full_node_protocol.RequestPeers,
peer_info: PeerInfo,
):
async for msg in self.full_node_peers.request_peers(peer_info):
yield msg
@api_request
async def respond_peers_with_peer_info(
self,
request: introducer_protocol.RespondPeers,
peer_info: PeerInfo,
):
await self.full_node_peers.respond_peers(request, peer_info, False)
@api_request
async def respond_peers_full_node_with_peer_info(
self,
request: full_node_protocol.RespondPeers,
peer_info: PeerInfo,
):
await self.full_node_peers.respond_peers(request, peer_info, True)
def _num_needed_peers(self) -> int:
assert self.global_connections is not None
diff = self.config["target_peer_count"] - len(
@ -298,6 +337,7 @@ class FullNode:
def _close(self):
self._shut_down = True
self.blockchain.shut_down()
asyncio.create_task(self.full_node_peers.close())
async def _await_closed(self):
await self.connection.close()
@ -1741,52 +1781,6 @@ class FullNode:
for _ in []:
yield _
@api_request
async def request_peers(
self, request: full_node_protocol.RequestPeers
) -> OutboundMessageGenerator:
if self.global_connections is None:
return
connected_peers = self.global_connections.get_full_node_peerinfos()
unconnected_peers = self.global_connections.peers.get_peers(
recent_threshold=24 * 60 * 60
)
peers = list(set(connected_peers + unconnected_peers))
yield OutboundMessage(
NodeType.FULL_NODE,
Message("respond_peers_full_node", full_node_protocol.RespondPeers(peers)),
Delivery.RESPOND,
)
@api_request
async def respond_peers(
self, request: introducer_protocol.RespondPeers
) -> OutboundMessageGenerator:
if self.server is None or self.global_connections is None:
return
conns = self.global_connections
for peer in request.peer_list:
conns.peers.add(PeerInfo(peer.host, peer.port))
# Pseudo-message to close the connection
yield OutboundMessage(NodeType.INTRODUCER, Message("", None), Delivery.CLOSE)
unconnected = conns.get_unconnected_peers(
recent_threshold=self.config["recent_peer_threshold"]
)
to_connect = unconnected[: self._num_needed_peers()]
if not len(to_connect):
return
self.log.info(f"Trying to connect to peers: {to_connect}")
for target in to_connect:
asyncio.create_task(self.server.start_client(target, self._on_connect))
@api_request
async def respond_peers_full_node(self, request: full_node_protocol.RespondPeers):
pass
@api_request
async def request_mempool_transactions(
self, request: full_node_protocol.RequestMempoolTransactions

View File

@ -1,13 +1,15 @@
import asyncio
import logging
import time
from typing import AsyncGenerator, Dict, Optional
from src.types.sized_bytes import bytes32
from src.protocols.introducer_protocol import RespondPeers, RequestPeers
from src.server.connection import PeerConnections
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.sized_bytes import bytes32
from src.server.server import ChiaServer
from src.util.api_decorators import api_request
from src.types.peer_info import PeerInfo, TimestampedPeerInfo
from src.util.ints import uint64
log = logging.getLogger(__name__)
@ -15,6 +17,7 @@ log = logging.getLogger(__name__)
class Introducer:
def __init__(self, max_peers_to_send: int, recent_peer_threshold: int):
self.vetted: Dict[bytes32, bool] = {}
self.vetted_timestamps: Dict[bytes32, int] = {}
self.max_peers_to_send = max_peers_to_send
self.recent_peer_threshold = recent_peer_threshold
self._shut_down = False
@ -38,14 +41,24 @@ class Introducer:
return
try:
log.info("Vetting random peers.")
rawpeers = self.global_connections.peers.get_peers(
100, True, self.recent_peer_threshold
if self.global_connections.introducer_peers is None:
await asyncio.sleep(3)
continue
rawpeers = self.global_connections.introducer_peers.get_peers(
100, True, 3 * self.recent_peer_threshold
)
for peer in rawpeers:
if self._shut_down:
return
if peer.get_hash() not in self.vetted:
if peer.get_hash() in self.vetted_timestamps:
if time.time() > self.vetted_timestamps[peer.get_hash()] + 3600:
if peer.get_hash() in self.vetted:
self.vetted[peer.get_hash()] = False
if (
peer.get_hash() not in self.vetted
or not self.vetted[peer.get_hash()]
):
try:
log.info(f"Vetting peer {peer.host} {peer.port}")
r, w = await asyncio.wait_for(
@ -60,6 +73,7 @@ class Introducer:
log.info(f"Have vetted {peer} successfully!")
self.vetted[peer.get_hash()] = True
self.vetted_timestamps[peer.get_hash()] = int(time.time())
except Exception as e:
log.error(e)
for i in range(30):
@ -71,21 +85,31 @@ class Introducer:
self.global_connections: PeerConnections = global_connections
@api_request
async def request_peers(
self, request: RequestPeers
async def request_peers_with_peer_info(
self,
request: RequestPeers,
peer_info: PeerInfo,
) -> AsyncGenerator[OutboundMessage, None]:
max_peers = self.max_peers_to_send
rawpeers = self.global_connections.peers.get_peers(
if self.global_connections.introducer_peers is None:
return
rawpeers = self.global_connections.introducer_peers.get_peers(
max_peers * 5, True, self.recent_peer_threshold
)
peers = []
for peer in rawpeers:
if peer.get_hash() not in self.vetted:
continue
if self.vetted[peer.get_hash()]:
peers.append(peer)
if peer.host == peer_info.host and peer.port == peer_info.port:
continue
peer_without_timestamp = TimestampedPeerInfo(
peer.host,
peer.port,
uint64(0),
)
peers.append(peer_without_timestamp)
if len(peers) >= max_peers:
break

View File

@ -10,7 +10,6 @@ from src.util.cbor_message import cbor_message
from src.util.ints import uint8, uint32, uint64, uint128
from src.types.peer_info import TimestampedPeerInfo
"""
Protocol between full nodes.
"""

View File

@ -0,0 +1,626 @@
from src.util.hash import std_hash
from secrets import randbits
from random import randrange, choice
from src.types.peer_info import PeerInfo, TimestampedPeerInfo
from src.util.ints import uint16, uint64
from typing import Dict, List, Optional, Tuple
from asyncio import Lock
import time
import math
TRIED_BUCKETS_PER_GROUP = 8
NEW_BUCKETS_PER_SOURCE_GROUP = 64
TRIED_BUCKET_COUNT = 256
NEW_BUCKET_COUNT = 1024
BUCKET_SIZE = 64
TRIED_COLLISION_SIZE = 10
NEW_BUCKETS_PER_ADDRESS = 8
LOG_TRIED_BUCKET_COUNT = 3
LOG_NEW_BUCKET_COUNT = 10
LOG_BUCKET_SIZE = 6
HORIZON_DAYS = 30
MAX_RETRIES = 3
MIN_FAIL_DAYS = 7
MAX_FAILURES = 10
# This is a Python port from 'CAddrInfo' class from Bitcoin core code.
class ExtendedPeerInfo:
def __init__(
self,
addr: TimestampedPeerInfo,
src_peer: Optional[PeerInfo],
):
self.peer_info: PeerInfo = PeerInfo(
addr.host,
addr.port,
)
self.timestamp: int = addr.timestamp
self.src: Optional[PeerInfo] = src_peer
if src_peer is None:
self.src = self.peer_info
self.random_pos: Optional[int] = None
self.is_tried: bool = False
self.ref_count: int = 0
self.last_success: int = 0
self.last_try: int = 0
self.num_attempts: int = 0
self.last_count_attempt: int = 0
def to_string(self) -> str:
assert self.src is not None
out = (
self.peer_info.host
+ " "
+ str(int(self.peer_info.port))
+ " "
+ self.src.host
+ " "
+ str(int(self.src.port))
)
return out
@classmethod
def from_string(cls, peer_str: str):
blobs = peer_str.split(" ")
assert len(blobs) == 4
peer_info = TimestampedPeerInfo(blobs[0], uint16(int(blobs[1])), uint64(0))
src_peer = PeerInfo(blobs[2], uint16(int(blobs[3])))
return cls(peer_info, src_peer)
def get_tried_bucket(self, key: int) -> int:
hash1 = int.from_bytes(
bytes(
std_hash(key.to_bytes(32, byteorder="big") + self.peer_info.get_key())[
:8
]
),
byteorder="big",
)
hash1 = hash1 % TRIED_BUCKETS_PER_GROUP
hash2 = int.from_bytes(
bytes(
std_hash(
key.to_bytes(32, byteorder="big")
+ self.peer_info.get_group()
+ bytes([hash1])
)[:8]
),
byteorder="big",
)
return hash2 % TRIED_BUCKET_COUNT
def get_new_bucket(self, key: int, src_peer: Optional[PeerInfo] = None) -> int:
if src_peer is None:
src_peer = self.src
assert src_peer is not None
hash1 = int.from_bytes(
bytes(
std_hash(
key.to_bytes(32, byteorder="big")
+ self.peer_info.get_group()
+ src_peer.get_group()
)[:8]
),
byteorder="big",
)
hash1 = hash1 % NEW_BUCKETS_PER_SOURCE_GROUP
hash2 = int.from_bytes(
bytes(
std_hash(
key.to_bytes(32, byteorder="big")
+ src_peer.get_group()
+ bytes([hash1])
)[:8]
),
byteorder="big",
)
return hash2 % NEW_BUCKET_COUNT
def get_bucket_position(self, key: int, is_new: bool, nBucket: int) -> int:
ch = "N" if is_new else "K"
hash1 = int.from_bytes(
bytes(
std_hash(
key.to_bytes(32, byteorder="big")
+ ch.encode()
+ nBucket.to_bytes(3, byteorder="big")
+ self.peer_info.get_key()
)[:8]
),
byteorder="big",
)
return hash1 % BUCKET_SIZE
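# The three hashing methods above implement Bitcoin-style bucket placement. Roughly,
# with H = std_hash truncated to 8 bytes and || = concatenation:
#   tried bucket   = H(key || group(addr) || (H(key || key(addr)) % 8)) % 256
#   new bucket     = H(key || group(src)  || (H(key || group(addr) || group(src)) % 64)) % 1024
#   slot in bucket = H(key || "K"/"N" || bucket || key(addr)) % 64
# Each address therefore lands in a keyed, deterministic bucket: a single network group
# can occupy at most 8 of the 256 tried buckets, and a single source group at most 64 of
# the 1024 new buckets, which bounds how much of the table one attacker-controlled /16 can fill.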
def is_terrible(self, now: Optional[int] = None) -> bool:
if now is None:
now = int(math.floor(time.time()))
# never remove things tried in the last minute
if self.last_try > 0 and self.last_try >= now - 60:
return False
# came in a flying DeLorean
if self.timestamp > now + 10 * 60:
return True
# not seen in recent history
if self.timestamp == 0 or now - self.timestamp > HORIZON_DAYS * 24 * 60 * 60:
return True
# tried N times and never a success
if self.last_success == 0 and self.num_attempts >= MAX_RETRIES:
return True
# N successive failures in the last week
if (
now - self.last_success > MIN_FAIL_DAYS * 24 * 60 * 60
and self.num_attempts >= MAX_FAILURES
):
return True
return False
def get_selection_chance(self, now: Optional[int] = None):
if now is None:
now = int(math.floor(time.time()))
chance = 1.0
since_last_try = max(now - self.last_try, 0)
# deprioritize very recent attempts
if since_last_try < 60 * 10:
chance *= 0.01
# deprioritize 66% after each failed attempt,
# but at most 1/28th to avoid the search taking forever or overly penalizing outages.
chance *= pow(0.66, min(self.num_attempts, 8))
return chance
# This is a Python port from 'CAddrMan' class from Bitcoin core code.
class AddressManager:
id_count: int
key: int
random_pos: List[int]
tried_matrix: List[List[int]]
new_matrix: List[List[int]]
tried_count: int
new_count: int
map_addr: Dict[str, int]
map_info: Dict[int, ExtendedPeerInfo]
last_good: int
tried_collisions: List[int]
def __init__(self):
self.clear()
self.lock: Lock = Lock()
def clear(self):
self.id_count = 0
self.key = randbits(256)
self.random_pos = []
self.tried_matrix = [
[-1 for x in range(BUCKET_SIZE)] for y in range(TRIED_BUCKET_COUNT)
]
self.new_matrix = [
[-1 for x in range(BUCKET_SIZE)] for y in range(NEW_BUCKET_COUNT)
]
self.tried_count = 0
self.new_count = 0
self.map_addr = {}
self.map_info = {}
self.last_good = 1
self.tried_collisions = []
def create_(
self, addr: TimestampedPeerInfo, addr_src: Optional[PeerInfo]
) -> Tuple[ExtendedPeerInfo, int]:
self.id_count += 1
node_id = self.id_count
self.map_info[node_id] = ExtendedPeerInfo(addr, addr_src)
self.map_addr[addr.host] = node_id
self.map_info[node_id].random_pos = len(self.random_pos)
self.random_pos.append(node_id)
return (self.map_info[node_id], node_id)
def find_(self, addr: PeerInfo) -> Tuple[Optional[ExtendedPeerInfo], Optional[int]]:
if addr.host not in self.map_addr:
return (None, None)
node_id = self.map_addr[addr.host]
if node_id not in self.map_info:
return (None, node_id)
return (self.map_info[node_id], node_id)
def swap_random_(self, rand_pos_1: int, rand_pos_2: int):
if rand_pos_1 == rand_pos_2:
return
assert rand_pos_1 < len(self.random_pos) and rand_pos_2 < len(self.random_pos)
node_id_1 = self.random_pos[rand_pos_1]
node_id_2 = self.random_pos[rand_pos_2]
self.map_info[node_id_1].random_pos = rand_pos_2
self.map_info[node_id_2].random_pos = rand_pos_1
self.random_pos[rand_pos_1] = node_id_2
self.random_pos[rand_pos_2] = node_id_1
def make_tried_(self, info: ExtendedPeerInfo, node_id: int):
for bucket in range(NEW_BUCKET_COUNT):
pos = info.get_bucket_position(self.key, True, bucket)
if self.new_matrix[bucket][pos] == node_id:
self.new_matrix[bucket][pos] = -1
info.ref_count -= 1
assert info.ref_count == 0
self.new_count -= 1
cur_bucket = info.get_tried_bucket(self.key)
cur_bucket_pos = info.get_bucket_position(self.key, False, cur_bucket)
if self.tried_matrix[cur_bucket][cur_bucket_pos] != -1:
# Evict the old node from the tried table.
node_id_evict = self.tried_matrix[cur_bucket][cur_bucket_pos]
assert node_id_evict in self.map_info
old_info = self.map_info[node_id_evict]
old_info.is_tried = False
self.tried_matrix[cur_bucket][cur_bucket_pos] = -1
self.tried_count -= 1
# Find its position into new table.
new_bucket = old_info.get_new_bucket(self.key)
new_bucket_pos = old_info.get_bucket_position(self.key, True, new_bucket)
self.clear_new_(new_bucket, new_bucket_pos)
old_info.ref_count = 1
self.new_matrix[new_bucket][new_bucket_pos] = node_id_evict
self.new_count += 1
self.tried_matrix[cur_bucket][cur_bucket_pos] = node_id
self.tried_count += 1
info.is_tried = True
def clear_new_(self, bucket: int, pos: int):
if self.new_matrix[bucket][pos] != -1:
delete_id = self.new_matrix[bucket][pos]
delete_info = self.map_info[delete_id]
assert delete_info.ref_count > 0
delete_info.ref_count -= 1
self.new_matrix[bucket][pos] = -1
if delete_info.ref_count == 0:
self.delete_new_entry_(delete_id)
def mark_good_(self, addr: PeerInfo, test_before_evict: bool, timestamp: int):
self.last_good = timestamp
(info, node_id) = self.find_(addr)
if info is None:
return
if node_id is None:
return
if not (info.peer_info.host == addr.host and info.peer_info.port == addr.port):
return
# update info
info.last_success = timestamp
info.last_try = timestamp
info.num_attempts = 0
# timestamp is not updated here, to avoid leaking information about
# currently-connected peers.
# if it is already in the tried set, don't do anything else
if info.is_tried:
return
# find a bucket it is in now
bucket_rand = randrange(NEW_BUCKET_COUNT)
new_bucket = -1
for n in range(NEW_BUCKET_COUNT):
cur_new_bucket = (n + bucket_rand) % NEW_BUCKET_COUNT
cur_new_bucket_pos = info.get_bucket_position(
self.key, True, cur_new_bucket
)
if self.new_matrix[cur_new_bucket][cur_new_bucket_pos] == node_id:
new_bucket = cur_new_bucket
break
# if no bucket is found, something bad happened;
if new_bucket == -1:
return
# NOTE(Florin): Double check this. It's not used anywhere else.
# which tried bucket to move the entry to
tried_bucket = info.get_tried_bucket(self.key)
tried_bucket_pos = info.get_bucket_position(self.key, False, tried_bucket)
# Will moving this address into tried evict another entry?
if (
test_before_evict
and self.tried_matrix[tried_bucket][tried_bucket_pos] != -1
):
if len(self.tried_collisions) < TRIED_COLLISION_SIZE:
if node_id not in self.tried_collisions:
self.tried_collisions.append(node_id)
else:
self.make_tried_(info, node_id)
def delete_new_entry_(self, node_id: int):
info = self.map_info[node_id]
if info is None or info.random_pos is None:
return
self.swap_random_(info.random_pos, len(self.random_pos) - 1)
self.random_pos = self.random_pos[:-1]
del self.map_addr[info.peer_info.host]
del self.map_info[node_id]
self.new_count -= 1
def add_to_new_table_(
self, addr: TimestampedPeerInfo, source: Optional[PeerInfo], penalty: int
) -> bool:
is_unique = False
peer_info = PeerInfo(
addr.host,
addr.port,
)
(info, node_id) = self.find_(peer_info)
if (
info is not None
and info.peer_info.host == addr.host
and info.peer_info.port == addr.port
):
penalty = 0
if info is not None:
# periodically update timestamp
currently_online = time.time() - addr.timestamp < 24 * 60 * 60
update_interval = 60 * 60 if currently_online else 24 * 60 * 60
if addr.timestamp > 0 and (
info.timestamp > 0
or info.timestamp < addr.timestamp - update_interval - penalty
):
info.timestamp = max(0, addr.timestamp - penalty)
# do not update if no new information is present
if addr.timestamp == 0 or (
info.timestamp > 0 and addr.timestamp <= info.timestamp
):
return False
# do not update if the entry was already in the "tried" table
if info.is_tried:
return False
# do not update if the max reference count is reached
if info.ref_count == NEW_BUCKETS_PER_ADDRESS:
return False
# stochastic test: previous ref_count == N: 2^N times harder to increase it
factor = 1 << info.ref_count
if factor > 1 and randrange(factor) != 0:
return False
else:
(info, node_id) = self.create_(addr, source)
info.timestamp = max(0, info.timestamp - penalty)
self.new_count += 1
is_unique = True
new_bucket = info.get_new_bucket(self.key, source)
new_bucket_pos = info.get_bucket_position(self.key, True, new_bucket)
if self.new_matrix[new_bucket][new_bucket_pos] != node_id:
add_to_new = self.new_matrix[new_bucket][new_bucket_pos] == -1
if not add_to_new:
info_existing = self.map_info[
self.new_matrix[new_bucket][new_bucket_pos]
]
if info_existing.is_terrible() or (
info_existing.ref_count > 1 and info.ref_count == 0
):
add_to_new = True
if add_to_new:
self.clear_new_(new_bucket, new_bucket_pos)
info.ref_count += 1
if node_id is not None:
self.new_matrix[new_bucket][new_bucket_pos] = node_id
else:
if info.ref_count == 0:
if node_id is not None:
self.delete_new_entry_(node_id)
return is_unique
def attempt_(self, addr: PeerInfo, count_failures: bool, timestamp: int):
info, _ = self.find_(addr)
if info is None:
return
if not (info.peer_info.host == addr.host and info.peer_info.port == addr.port):
return
info.last_try = timestamp
if count_failures and info.last_count_attempt < self.last_good:
info.last_count_attempt = timestamp
info.num_attempts += 1
def select_peer_(self, new_only: bool) -> Optional[ExtendedPeerInfo]:
if len(self.random_pos) == 0:
return None
if new_only and self.new_count == 0:
return None
# Use a 50% chance for choosing between tried and new table entries.
if (
not new_only
and self.tried_count > 0
and (self.new_count == 0 or randrange(2) == 0)
):
chance = 1.0
while True:
tried_bucket = randrange(TRIED_BUCKET_COUNT)
tried_bucket_pos = randrange(BUCKET_SIZE)
while self.tried_matrix[tried_bucket][tried_bucket_pos] == -1:
tried_bucket = (
tried_bucket + randbits(LOG_TRIED_BUCKET_COUNT)
) % TRIED_BUCKET_COUNT
tried_bucket_pos = (
tried_bucket_pos + randbits(LOG_BUCKET_SIZE)
) % BUCKET_SIZE
node_id = self.tried_matrix[tried_bucket][tried_bucket_pos]
info = self.map_info[node_id]
if randbits(30) < (chance * info.get_selection_chance() * (1 << 30)):
return info
chance *= 1.2
else:
chance = 1.0
while True:
new_bucket = randrange(NEW_BUCKET_COUNT)
new_bucket_pos = randrange(BUCKET_SIZE)
while self.new_matrix[new_bucket][new_bucket_pos] == -1:
new_bucket = (
new_bucket + randbits(LOG_NEW_BUCKET_COUNT)
) % NEW_BUCKET_COUNT
new_bucket_pos = (
new_bucket_pos + randbits(LOG_BUCKET_SIZE)
) % BUCKET_SIZE
node_id = self.new_matrix[new_bucket][new_bucket_pos]
info = self.map_info[node_id]
if randbits(30) < chance * info.get_selection_chance() * (1 << 30):
return info
chance *= 1.2
def resolve_tried_collisions_(self):
for node_id in self.tried_collisions[:]:
resolved = False
if node_id not in self.map_info:
resolved = True
else:
info = self.map_info[node_id]
peer = info.peer_info
tried_bucket = info.get_tried_bucket(self.key)
tried_bucket_pos = info.get_bucket_position(
self.key, False, tried_bucket
)
if self.tried_matrix[tried_bucket][tried_bucket_pos] != -1:
old_id = self.tried_matrix[tried_bucket][tried_bucket_pos]
old_info = self.map_info[old_id]
if time.time() - old_info.last_success < 4 * 60 * 60:
resolved = True
elif time.time() - old_info.last_try < 4 * 60 * 60:
if time.time() - old_info.last_try > 60:
self.mark_good_(peer, False, time.time())
resolved = True
elif time.time() - info.last_success > 40 * 60:
self.mark_good_(peer, False, time.time())
resolved = True
else:
self.mark_good_(peer, False, time.time())
resolved = True
if resolved:
self.tried_collisions.remove(node_id)
def select_tried_collision_(self) -> Optional[ExtendedPeerInfo]:
if len(self.tried_collisions) == 0:
return None
new_id = choice(self.tried_collisions)
if new_id not in self.map_info:
self.tried_collisions.remove(new_id)
return None
new_info = self.map_info[new_id]
tried_bucket = new_info.get_tried_bucket(self.key)
tried_bucket_pos = new_info.get_bucket_position(self.key, False, tried_bucket)
old_id = self.tried_matrix[tried_bucket][tried_bucket_pos]
return self.map_info[old_id]
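# get_peers_ below returns a random sample of roughly 23% of all known addresses (at most
# 1000), skipping entries flagged by is_terrible(); this sample is what gets gossiped out
# through request_peers / respond_peers.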
def get_peers_(self) -> List[TimestampedPeerInfo]:
addr: List[TimestampedPeerInfo] = []
num_nodes = math.ceil(23 * len(self.random_pos) / 100)
if num_nodes > 1000:
num_nodes = 1000
for n in range(len(self.random_pos)):
if len(addr) >= num_nodes:
return addr
rand_pos = randrange(len(self.random_pos) - n) + n
self.swap_random_(n, rand_pos)
info = self.map_info[self.random_pos[n]]
if not info.is_terrible():
cur_peer_info = TimestampedPeerInfo(
info.peer_info.host,
uint16(info.peer_info.port),
uint64(info.timestamp),
)
addr.append(cur_peer_info)
return addr
def connect_(self, addr: PeerInfo, timestamp: int):
info, _ = self.find_(addr)
if info is None:
return
# check whether we are talking about the exact same peer
if not (info.peer_info.host == addr.host and info.peer_info.port == addr.port):
return
update_interval = 20 * 60
if timestamp - info.timestamp > update_interval:
info.timestamp = timestamp
async def size(self) -> int:
async with self.lock:
return len(self.random_pos)
async def add_to_new_table(
self,
addresses: List[TimestampedPeerInfo],
source: Optional[PeerInfo] = None,
penalty: int = 0,
) -> bool:
is_added = False
async with self.lock:
for addr in addresses:
cur_peer_added = self.add_to_new_table_(addr, source, penalty)
is_added = is_added or cur_peer_added
return is_added
# Mark an entry as accessible.
async def mark_good(
self,
addr: PeerInfo,
test_before_evict: bool = True,
timestamp: int = -1,
):
if timestamp == -1:
timestamp = math.floor(time.time())
async with self.lock:
self.mark_good_(addr, test_before_evict, timestamp)
# Mark an entry as connection attempted to.
async def attempt(
self,
addr: PeerInfo,
count_failures: bool,
timestamp: int = -1,
):
if timestamp == -1:
timestamp = math.floor(time.time())
async with self.lock:
self.attempt_(addr, count_failures, timestamp)
# See if any to-be-evicted tried table entries have been tested and if so resolve the collisions.
async def resolve_tried_collisions(self):
async with self.lock:
self.resolve_tried_collisions_()
# Randomly select an address in tried that another address is attempting to evict.
async def select_tried_collision(self) -> Optional[ExtendedPeerInfo]:
async with self.lock:
return self.select_tried_collision_()
# Choose an address to connect to.
async def select_peer(self, new_only: bool = False) -> Optional[ExtendedPeerInfo]:
async with self.lock:
return self.select_peer_(new_only)
# Return a bunch of addresses, selected at random.
async def get_peers(self) -> List[TimestampedPeerInfo]:
async with self.lock:
return self.get_peers_()
async def connect(self, addr: PeerInfo, timestamp: int = -1):
if timestamp == -1:
timestamp = math.floor(time.time())
async with self.lock:
return self.connect_(addr, timestamp)

View File

@ -0,0 +1,220 @@
import logging
import aiosqlite
from src.server.address_manager import (
AddressManager,
ExtendedPeerInfo,
NEW_BUCKET_COUNT,
BUCKET_SIZE,
NEW_BUCKETS_PER_ADDRESS,
)
from typing import Dict, List, Tuple
log = logging.getLogger(__name__)
class AddressManagerStore:
"""
Metadata table:
- private key
- new table count
- tried table count
Nodes table:
* Maps entries from new/tried table to unique node ids.
- node_id
- IP, port, together with the IP, port of the source peer.
New table:
* Stores a (node_id, bucket) pair for each occurrence of an entry in the new table.
* Once we know the buckets, we can also deduce the bucket positions.
All other information, such as tried_matrix, map_addr, map_info and random_pos,
is not stored explicitly; it is recalculated (deduced) from the tables above.
"""
db: aiosqlite.Connection
@classmethod
async def create(cls, connection):
self = cls()
self.db = connection
await self.db.commit()
await self.db.execute(
"CREATE TABLE IF NOT EXISTS peer_metadata(key text,value text)"
)
await self.db.commit()
await self.db.execute(
"CREATE TABLE IF NOT EXISTS peer_nodes(node_id int,value text)"
)
await self.db.commit()
await self.db.execute(
"CREATE TABLE IF NOT EXISTS peer_new_table(node_id int,bucket int)"
)
await self.db.commit()
return self
async def clear(self):
cursor = await self.db.execute("DELETE from peer_metadata")
await cursor.close()
cursor = await self.db.execute("DELETE from peer_nodes")
await cursor.close()
cursor = await self.db.execute("DELETE from peer_new_table")
await cursor.close()
await self.db.commit()
async def get_metadata(self) -> Dict[str, str]:
cursor = await self.db.execute("SELECT key, value from peer_metadata")
metadata = await cursor.fetchall()
await cursor.close()
return {key: value for key, value in metadata}
async def is_empty(self) -> bool:
metadata = await self.get_metadata()
if "key" not in metadata:
return True
if int(metadata.get("new_count", 0)) > 0:
return False
if int(metadata.get("tried_count", 0)) > 0:
return False
return True
async def get_nodes(self) -> List[Tuple[int, ExtendedPeerInfo]]:
cursor = await self.db.execute("SELECT node_id, value from peer_nodes")
nodes_id = await cursor.fetchall()
await cursor.close()
return [
(node_id, ExtendedPeerInfo.from_string(info_str))
for node_id, info_str in nodes_id
]
async def get_new_table(self) -> List[Tuple[int, int]]:
cursor = await self.db.execute("SELECT node_id, bucket from peer_new_table")
entries = await cursor.fetchall()
await cursor.close()
return [(node_id, bucket) for node_id, bucket in entries]
async def set_metadata(self, metadata):
for key, value in metadata:
cursor = await self.db.execute(
"INSERT OR REPLACE INTO peer_metadata VALUES(?, ?)",
(key, value),
)
await cursor.close()
await self.db.commit()
async def set_nodes(self, node_list):
for node_id, peer_info in node_list:
cursor = await self.db.execute(
"INSERT OR REPLACE INTO peer_nodes VALUES(?, ?)",
(node_id, peer_info.to_string()),
)
await cursor.close()
await self.db.commit()
async def set_new_table(self, entries):
for node_id, bucket in entries:
cursor = await self.db.execute(
"INSERT OR REPLACE INTO peer_new_table VALUES(?, ?)",
(node_id, bucket),
)
await cursor.close()
await self.db.commit()
async def serialize(self, address_manager: AddressManager):
metadata = []
nodes = []
new_table_entries = []
metadata.append(("key", str(address_manager.key)))
metadata.append(("new_count", str(address_manager.new_count)))
metadata.append(("tried_count", str(address_manager.tried_count)))
unique_ids = {}
count_ids = 0
for node_id, info in address_manager.map_info.items():
unique_ids[node_id] = count_ids
if info.ref_count > 0:
assert count_ids != address_manager.new_count
nodes.append((count_ids, info))
count_ids += 1
tried_ids = 0
for node_id, info in address_manager.map_info.items():
if info.is_tried:
assert info is not None
assert tried_ids != address_manager.tried_count
nodes.append((count_ids, info))
count_ids += 1
tried_ids += 1
for bucket in range(NEW_BUCKET_COUNT):
for i in range(BUCKET_SIZE):
if address_manager.new_matrix[bucket][i] != -1:
index = unique_ids[address_manager.new_matrix[bucket][i]]
new_table_entries.append((index, bucket))
await self.clear()
await self.set_metadata(metadata)
await self.set_nodes(nodes)
await self.set_new_table(new_table_entries)
async def deserialize(self) -> AddressManager:
address_manager = AddressManager()
metadata = await self.get_metadata()
nodes = await self.get_nodes()
new_table_entries = await self.get_new_table()
address_manager.clear()
address_manager.key = int(metadata["key"])
address_manager.new_count = int(metadata["new_count"])
address_manager.tried_count = int(metadata["tried_count"])
new_table_nodes = [
(node_id, info)
for node_id, info in nodes
if node_id < address_manager.new_count
]
for n, info in new_table_nodes:
address_manager.map_addr[info.peer_info.host] = n
address_manager.map_info[n] = info
info.random_pos = len(address_manager.random_pos)
address_manager.random_pos.append(n)
tried_table_nodes = [
(node_id, info)
for node_id, info in nodes
if node_id >= address_manager.new_count
]
lost_count = 0
for node_id, info in tried_table_nodes:
tried_bucket = info.get_tried_bucket(address_manager.key)
tried_bucket_pos = info.get_bucket_position(
address_manager.key, False, tried_bucket
)
if address_manager.tried_matrix[tried_bucket][tried_bucket_pos] == -1:
info.random_pos = len(address_manager.random_pos)
info.is_tried = True
address_manager.random_pos.append(node_id)
address_manager.map_info[node_id] = info
address_manager.map_addr[info.peer_info.host] = node_id
address_manager.tried_matrix[tried_bucket][tried_bucket_pos] = node_id
else:
lost_count += 1
address_manager.tried_count -= lost_count
for node_id, bucket in new_table_entries:
if node_id >= 0 and node_id < address_manager.new_count:
info = address_manager.map_info[node_id]
bucket_pos = info.get_bucket_position(address_manager.key, True, bucket)
if (
address_manager.new_matrix[bucket][bucket_pos] == -1
and info.ref_count < NEW_BUCKETS_PER_ADDRESS
):
info.ref_count += 1
address_manager.new_matrix[bucket][bucket_pos] = node_id
for node_id, info in address_manager.map_info.items():
if not info.is_tried and info.ref_count == 0:
address_manager.delete_new_entry_(node_id)
return address_manager
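# A minimal usage sketch (assumes `connection` is an open aiosqlite connection and
# `address_manager` is a populated AddressManager, as in FullNodeDiscovery below):
#   store = await AddressManagerStore.create(connection)
#   await store.serialize(address_manager)
#   restored = await store.deserialize()  # rebuilds the tried/new matrices from the tables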

View File

@ -1,14 +1,15 @@
import logging
import random
import time
import asyncio
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional
import socket
from typing import Any, AsyncGenerator, Callable, List, Optional
from src.server.outbound_message import Message, NodeType, OutboundMessage
from src.types.peer_info import PeerInfo, TimestampedPeerInfo
from src.types.sized_bytes import bytes32
from src.types.peer_info import PeerInfo
from src.util import cbor
from src.util.ints import uint16, uint64
from src.util.ints import uint16
from src.server.introducer_peers import IntroducerPeers
from src.util.errors import Err, ProtocolError
# Each message is prepended with LENGTH_BYTES bytes specifying the length
LENGTH_BYTES: int = 4
@ -33,6 +34,9 @@ class ChiaConnection:
server_port: int,
on_connect: OnConnectFunc,
log: logging.Logger,
is_outbound: bool,
# A special type of connection that disconnects after the handshake.
is_feeler: bool,
):
self.local_type = local_type
self.connection_type = connection_type
@ -47,6 +51,8 @@ class ChiaConnection:
self.node_id = None
self.on_connect = on_connect
self.log = log
self.is_outbound = is_outbound
self.is_feeler = is_feeler
# ChiaConnection metrics
self.creation_time = time.time()
@ -120,31 +126,53 @@ class ChiaConnection:
class PeerConnections:
def __init__(self, all_connections: List[ChiaConnection] = []):
def __init__(
self, local_type: NodeType, all_connections: List[ChiaConnection] = []
):
self._all_connections = all_connections
# Only full node peers are added to `peers`
self.peers = Peers()
for c in all_connections:
if c.connection_type == NodeType.FULL_NODE:
self.peers.add(c.get_peer_info())
self.local_type = local_type
self.introducer_peers = None
self.connection = None
if local_type == NodeType.INTRODUCER:
self.introducer_peers = IntroducerPeers()
for c in all_connections:
if c.connection_type == NodeType.FULL_NODE:
self.introducer_peers.add(c.get_peer_info())
self.state_changed_callback: Optional[Callable] = None
self.full_node_peers_callback: Optional[Callable] = None
self.wallet_callback: Optional[Callable] = None
self.max_inbound_count = 0
def set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
def set_full_node_peers_callback(self, callback: Callable):
self.full_node_peers_callback = callback
def set_wallet_callback(self, callback: Callable):
self.wallet_callback = callback
def _state_changed(self, state: str):
if self.state_changed_callback is not None:
self.state_changed_callback(state)
def add(self, connection: ChiaConnection) -> bool:
if not connection.is_outbound:
if (
connection.connection_type is not None
and not self.accept_inbound_connections(connection.connection_type)
):
raise ProtocolError(Err.MAX_INBOUND_CONNECTIONS_REACHED)
for c in self._all_connections:
if c.node_id == connection.node_id:
return False
raise ProtocolError(Err.DUPLICATE_CONNECTION, [False])
self._all_connections.append(connection)
if connection.connection_type == NodeType.FULL_NODE:
self._state_changed("add_connection")
return self.peers.add(connection.get_peer_info())
if self.introducer_peers is not None:
return self.introducer_peers.add(connection.get_peer_info())
self._state_changed("add_connection")
return True
@ -155,14 +183,37 @@ class PeerConnections:
connection.close()
self._state_changed("close_connection")
if not keep_peer:
self.peers.remove(info)
if self.introducer_peers is not None:
self.introducer_peers.remove(info)
def close_all_connections(self):
for connection in self._all_connections:
connection.close()
self._state_changed("close_connection")
self._all_connections = []
self.peers = Peers()
if self.local_type == NodeType.INTRODUCER:
self.introducer_peers = IntroducerPeers()
def get_local_peerinfo(self) -> Optional[PeerInfo]:
ip = None
port = None
for c in self._all_connections:
if c.connection_type == NodeType.FULL_NODE:
port = c.local_port
break
if port is None:
return None
# https://stackoverflow.com/a/28950776
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
try:
s.connect(("introducer.beta.chia.net", 8444))
ip = s.getsockname()[0]
except Exception:
ip = None
if ip is None:
return None
return PeerInfo(str(ip), uint16(port))
def get_connections(self):
return self._all_connections
@ -175,52 +226,98 @@ class PeerConnections:
filter(None, map(ChiaConnection.get_peer_info, self._all_connections))
)
def get_unconnected_peers(self, max_peers=0, recent_threshold=9999999):
connected = self.get_full_node_peerinfos()
peers = self.peers.get_peers(recent_threshold=recent_threshold)
unconnected = list(filter(lambda peer: peer not in connected, peers))
if not max_peers:
max_peers = len(unconnected)
return unconnected[:max_peers]
async def successful_handshake(self, connection):
if connection.connection_type == NodeType.FULL_NODE:
if connection.is_outbound:
if self.full_node_peers_callback is not None:
self.full_node_peers_callback(
"mark_tried",
connection.get_peer_info(),
)
if self.wallet_callback is not None:
self.wallet_callback(
"make_tried",
connection.get_peer_info(),
)
if connection.is_feeler:
connection.close()
self.close(connection)
return
# Request peers after handshake.
if connection.local_type == NodeType.FULL_NODE:
await connection.send(Message("request_peers", ""))
else:
if self.full_node_peers_callback is not None:
self.full_node_peers_callback(
"new_inbound_connection",
connection.get_peer_info(),
)
yield connection
def failed_handshake(self, connection, e):
if connection.connection_type == NodeType.FULL_NODE and connection.is_outbound:
if isinstance(e, ProtocolError) and e.code == Err.DUPLICATE_CONNECTION:
return
class Peers:
"""
Has the list of known full node peers that are already connected or may be
connected to, and the time that they were last added.
"""
if self.full_node_peers_callback is not None:
self.full_node_peers_callback(
"mark_attempted",
connection.get_peer_info(),
)
if self.wallet_callback is not None:
self.wallet_callback(
"mark_attempted",
connection.get_peer_info(),
)
def __init__(self):
self._peers: List[PeerInfo] = []
self.time_added: Dict[bytes32, uint64] = {}
def failed_connection(self, peer_info):
if self.full_node_peers_callback is not None:
self.full_node_peers_callback(
"mark_attempted",
peer_info,
)
if self.wallet_callback is not None:
self.wallet_callback(
"mark_attempted",
peer_info,
)
def add(self, peer: Optional[PeerInfo]) -> bool:
if peer is None or not peer.port:
return False
if peer not in self._peers:
self._peers.append(peer)
self.time_added[peer.get_hash()] = uint64(int(time.time()))
return True
def update_connection_time(self, connection):
if connection.connection_type == NodeType.FULL_NODE and connection.is_outbound:
if self.full_node_peers_callback is not None:
self.full_node_peers_callback(
"update_connection_time",
connection.get_peer_info(),
)
if self.wallet_callback is not None:
self.wallet_callback(
"update_connection_time",
connection.get_peer_info(),
)
def remove(self, peer: Optional[PeerInfo]) -> bool:
if peer is None or not peer.port:
return False
try:
self._peers.remove(peer)
return True
except ValueError:
return False
# Functions related to outbound and inbound connections for the full node.
def count_outbound_connections(self):
return len(self.get_outbound_connections())
def get_peers(
self, max_peers: int = 0, randomize: bool = False, recent_threshold=9999999
) -> List[TimestampedPeerInfo]:
target_peers = [
TimestampedPeerInfo(peer.host, uint16(peer.port), uint64(0))
for peer in self._peers
if time.time() - self.time_added[peer.get_hash()] < recent_threshold
def get_outbound_connections(self):
return [
conn
for conn in self._all_connections
if conn.is_outbound and conn.connection_type == NodeType.FULL_NODE
]
if not max_peers or max_peers > len(target_peers):
max_peers = len(target_peers)
if randomize:
random.shuffle(target_peers)
return target_peers[:max_peers]
def accept_inbound_connections(self, node_type: NodeType):
if not self.local_type == NodeType.FULL_NODE:
return True
inbound_count = len(
[
conn
for conn in self._all_connections
if not conn.is_outbound and conn.connection_type == node_type
]
)
if node_type == NodeType.FULL_NODE:
return inbound_count < self.max_inbound_count
if node_type == NodeType.WALLET:
return inbound_count < 20
return inbound_count < 10

View File

@ -0,0 +1,48 @@
import time
import random
from src.types.sized_bytes import bytes32
from typing import List, Dict, Optional
from src.util.ints import uint64
from src.types.peer_info import PeerInfo
class IntroducerPeers:
"""
Has the list of known full node peers that are already connected or may be
connected to, and the time that they were last added.
"""
def __init__(self):
self._peers: List[PeerInfo] = []
self.time_added: Dict[bytes32, uint64] = {}
def add(self, peer: Optional[PeerInfo]) -> bool:
if peer is None or not peer.port:
return False
if peer not in self._peers:
self._peers.append(peer)
self.time_added[peer.get_hash()] = uint64(int(time.time()))
return True
def remove(self, peer: Optional[PeerInfo]) -> bool:
if peer is None or not peer.port:
return False
try:
self._peers.remove(peer)
return True
except ValueError:
return False
def get_peers(
self, max_peers: int = 0, randomize: bool = False, recent_threshold=9999999
) -> List[PeerInfo]:
target_peers = [
peer
for peer in self._peers
if time.time() - self.time_added[peer.get_hash()] < recent_threshold
]
if not max_peers or max_peers > len(target_peers):
max_peers = len(target_peers)
if randomize:
random.shuffle(target_peers)
return target_peers[:max_peers]

View File

@ -0,0 +1,509 @@
import asyncio
import time
import math
import aiosqlite
import traceback
from random import Random
from src.types.peer_info import PeerInfo, TimestampedPeerInfo
from src.util.path import path_from_root, mkdir
from src.server.outbound_message import (
Delivery,
OutboundMessage,
Message,
NodeType,
)
from src.server.address_manager import ExtendedPeerInfo, AddressManager
from src.server.address_manager_store import AddressManagerStore
from src.protocols import (
introducer_protocol,
full_node_protocol,
)
from secrets import randbits
from src.util.hash import std_hash
from typing import Dict, Optional, AsyncGenerator
from src.util.ints import uint64
OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
class FullNodeDiscovery:
def __init__(
self,
server,
root_path,
global_connections,
target_outbound_count,
peer_db_path,
introducer_info,
peer_connect_interval,
log,
):
self.server = server
assert self.server is not None
self.message_queue = asyncio.Queue()
self.is_closed = False
self.global_connections = global_connections
self.target_outbound_count = target_outbound_count
self.peer_db_path = path_from_root(root_path, peer_db_path)
self.introducer_info = PeerInfo(
introducer_info["host"],
introducer_info["port"],
)
self.peer_connect_interval = peer_connect_interval
self.log = log
self.relay_queue = None
async def initialize_address_manager(self):
mkdir(self.peer_db_path.parent)
self.connection = await aiosqlite.connect(self.peer_db_path)
self.address_manager_store = await AddressManagerStore.create(self.connection)
if not await self.address_manager_store.is_empty():
self.address_manager = await self.address_manager_store.deserialize()
else:
await self.address_manager_store.clear()
self.address_manager = AddressManager()
async def start_tasks(self):
self.process_messages_task = asyncio.create_task(self._process_messages())
random = Random()
self.connect_peers_task = asyncio.create_task(self._connect_to_peers(random))
self.serialize_task = asyncio.create_task(self._periodically_serialize(random))
async def _close_common(self):
self.is_closed = True
self.connect_peers_task.cancel()
self.process_messages_task.cancel()
self.serialize_task.cancel()
await self.connection.close()
def add_message(self, message, data):
self.message_queue.put_nowait((message, data))
async def _process_messages(self):
connection_time_pretest: Dict = {}
while not self.is_closed:
try:
message, peer_info = await self.message_queue.get()
if peer_info is None or not peer_info.port:
continue
if message == "make_tried":
await self.address_manager.mark_good(peer_info, True)
await self.address_manager.connect(peer_info)
elif message == "mark_attempted":
await self.address_manager.attempt(peer_info, True)
elif message == "update_connection_time":
if peer_info.host not in connection_time_pretest:
connection_time_pretest[peer_info.host] = time.time()
if time.time() - connection_time_pretest[peer_info.host] > 60:
await self.address_manager.connect(peer_info)
connection_time_pretest[peer_info.host] = time.time()
elif message == "new_inbound_connection":
timestamped_peer_info = TimestampedPeerInfo(
peer_info.host,
peer_info.port,
uint64(int(time.time())),
)
await self.address_manager.add_to_new_table(
[timestamped_peer_info], peer_info, 0
)
await self.address_manager.mark_good(peer_info, True)
if self.relay_queue is not None:
self.relay_queue.put_nowait((timestamped_peer_info, 1))
except Exception as e:
self.log.error(f"Exception in process message: {e}")
def _num_needed_peers(self) -> int:
diff = self.target_outbound_count
diff -= self.global_connections.count_outbound_connections()
return diff if diff >= 0 else 0
"""
Uses the Poisson distribution to determine the next time
when we'll initiate a feeler connection.
(https://en.wikipedia.org/wiki/Poisson_distribution)
"""
def _poisson_next_send(self, now, avg_interval_seconds, random):
return now + (
math.log(random.randrange(1 << 48) * -0.0000000000000035527136788 + 1)
* avg_interval_seconds
* -1000000.0
+ 0.5
)
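# The magic constant in the expression above is approximately -1 / 2**48, so
# randrange(1 << 48) * -3.55e-15 + 1 is a uniform draw in (0, 1]; taking log() and
# scaling by -avg_interval_seconds * 1e6 yields an exponentially distributed wait
# (mean avg_interval_seconds) in microseconds, i.e. inverse-CDF sampling of the
# inter-arrival time of a Poisson process. A readable, purely illustrative
# equivalent (not called anywhere in this module):
def _poisson_next_send_sketch(now_us: float, avg_interval_seconds: float, rng: Random) -> float:
    u = rng.randrange(1 << 48) / float(1 << 48)  # uniform in [0, 1)
    return now_us + -math.log(1.0 - u) * avg_interval_seconds * 1_000_000 + 0.5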
async def _introducer_client(self):
async def on_connect() -> OutboundMessageGenerator:
msg = Message("request_peers", introducer_protocol.RequestPeers())
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
await self.server.start_client(self.introducer_info, on_connect)
# If we are still connected to introducer, disconnect
for connection in self.global_connections.get_connections():
if connection.connection_type == NodeType.INTRODUCER:
self.global_connections.close(connection)
async def _connect_to_peers(self, random):
next_feeler = self._poisson_next_send(time.time() * 1000 * 1000, 240, random)
empty_tables = False
local_peerinfo: Optional[
PeerInfo
] = self.global_connections.get_local_peerinfo()
last_timestamp_local_info: uint64 = uint64(int(time.time()))
while not self.is_closed:
# We don't know any address, connect to the introducer to get some.
size = await self.address_manager.size()
if size == 0 or empty_tables:
await self._introducer_client()
await asyncio.sleep(min(15, self.peer_connect_interval))
empty_tables = False
continue
# Only connect out to one peer per network group (/16 for IPv4).
groups = []
connected = self.global_connections.get_full_node_peerinfos()
for conn in self.global_connections.get_outbound_connections():
peer = conn.get_peer_info()
group = peer.get_group()
if group not in groups:
groups.append(group)
# Feeler Connections
#
# Design goals:
# * Increase the number of connectable addresses in the tried table.
#
# Method:
# * Choose a random address from new and attempt to connect to it; if we can connect
# successfully, it is added to tried.
# * Start attempting feeler connections only after node finishes making outbound
# connections.
# * Only make a feeler connection once every few minutes.
is_feeler = False
has_collision = False
if self._num_needed_peers() == 0:
if time.time() * 1000 * 1000 > next_feeler:
next_feeler = self._poisson_next_send(
time.time() * 1000 * 1000, 240, random
)
is_feeler = True
await self.address_manager.resolve_tried_collisions()
tries = 0
now = time.time()
got_peer = False
addr: Optional[PeerInfo] = None
max_tries = 50 if len(groups) >= 3 else 10
while not got_peer and not self.is_closed:
await asyncio.sleep(min(15, self.peer_connect_interval))
tries += 1
if tries > max_tries:
addr = None
empty_tables = True
break
info: Optional[
ExtendedPeerInfo
] = await self.address_manager.select_tried_collision()
if info is None:
info = await self.address_manager.select_peer(is_feeler)
else:
has_collision = True
if info is None:
if not is_feeler:
empty_tables = True
break
# Require outbound connections, other than feelers, to be to distinct network groups.
addr = info.peer_info
if not is_feeler and addr.get_group() in groups:
addr = None
continue
if addr in connected:
addr = None
continue
# only consider very recently tried nodes after 30 failed attempts
if now - info.last_try < 1800 and tries < 30:
continue
if (
time.time() - last_timestamp_local_info > 1800
or local_peerinfo is None
):
local_peerinfo = self.global_connections.get_local_peerinfo()
last_timestamp_local_info = uint64(int(time.time()))
if local_peerinfo is not None and addr == local_peerinfo:
continue
got_peer = True
disconnect_after_handshake = is_feeler
if self._num_needed_peers() == 0:
disconnect_after_handshake = True
empty_tables = False
initiate_connection = (
self._num_needed_peers() > 0 or has_collision or is_feeler
)
if addr is not None and initiate_connection:
asyncio.create_task(
self.server.start_client(
addr, None, None, disconnect_after_handshake
)
)
sleep_interval = 5 + len(connected) * 10
sleep_interval = min(sleep_interval, self.peer_connect_interval)
await asyncio.sleep(sleep_interval)
async def _periodically_serialize(self, random: Random):
while not self.is_closed:
serialize_interval = random.randint(15 * 60, 30 * 60)
await asyncio.sleep(serialize_interval)
async with self.address_manager.lock:
await self.address_manager_store.serialize(self.address_manager)
async def _respond_peers_common(self, request, peer_src, is_full_node):
# Check if we got the peers from a full node or from the introducer.
peers_adjusted_timestamp = []
for peer in request.peer_list:
if peer.timestamp < 100000000 or peer.timestamp > time.time() + 10 * 60:
# Invalid timestamp; replace it with a stale one (five days ago).
current_peer = TimestampedPeerInfo(
peer.host,
peer.port,
uint64(int(time.time() - 5 * 24 * 60 * 60)),
)
else:
current_peer = peer
if not is_full_node:
current_peer = TimestampedPeerInfo(
peer.host,
peer.port,
uint64(0),
)
peers_adjusted_timestamp.append(current_peer)
if is_full_node:
await self.address_manager.add_to_new_table(
peers_adjusted_timestamp, peer_src, 2 * 60 * 60
)
else:
await self.address_manager.add_to_new_table(
peers_adjusted_timestamp, None, 0
)
class FullNodePeers(FullNodeDiscovery):
def __init__(
self,
server,
root_path,
global_connections,
max_inbound_count,
target_outbound_count,
peer_db_path,
introducer_info,
peer_connect_interval,
log,
):
super().__init__(
server,
root_path,
global_connections,
target_outbound_count,
peer_db_path,
introducer_info,
peer_connect_interval,
log,
)
self.global_connections.max_inbound_count = max_inbound_count
self.relay_queue = asyncio.Queue()
self.lock = asyncio.Lock()
self.neighbour_known_peers = {}
self.key = randbits(256)
async def start(self):
await self.initialize_address_manager()
self.global_connections.set_full_node_peers_callback(self.add_message)
self.self_advertise_task = asyncio.create_task(
self._periodically_self_advertise()
)
self.address_relay_task = asyncio.create_task(self._address_relay())
await self.start_tasks()
async def close(self):
await self._close_common()
self.self_advertise_task.cancel()
self.address_relay_task.cancel()
async def _periodically_self_advertise(self):
while not self.is_closed:
await asyncio.sleep(3600 * 24)
# Clean up known nodes for neighbours every 24 hours.
async with self.lock:
for known_peers in self.neighbour_known_peers.values():
known_peers.clear()
# Self advertise every 24 hours.
peer = self.global_connections.get_local_peerinfo()
if peer is None:
continue
timestamped_peer = [
TimestampedPeerInfo(
peer.host,
peer.port,
uint64(int(time.time())),
)
]
outbound_message = OutboundMessage(
NodeType.FULL_NODE,
Message(
"respond_peers_full_node",
full_node_protocol.RespondPeers(timestamped_peer),
),
Delivery.BROADCAST,
)
if self.server is not None:
self.server.push_message(outbound_message)
async def add_peers_neighbour(self, peers, neighbour_info):
neighbour_data = (neighbour_info.host, neighbour_info.port)
async with self.lock:
for peer in peers:
if neighbour_data not in self.neighbour_known_peers:
self.neighbour_known_peers[neighbour_data] = set()
if peer.host not in self.neighbour_known_peers[neighbour_data]:
self.neighbour_known_peers[neighbour_data].add(peer.host)
async def request_peers(self, peer_info):
try:
conns = self.global_connections.get_outbound_connections()
is_outbound = False
for conn in conns:
conn_peer_info = conn.get_peer_info()
if conn_peer_info == peer_info:
is_outbound = True
break
# Prevent a fingerprint attack: do not send peers to inbound connections.
# This asymmetric behavior for inbound and outbound connections was introduced
# to prevent a fingerprinting attack: an attacker can send specific fake addresses
# to users' AddrMan and later request them by sending getaddr messages.
# Making nodes which are behind NAT and can only make outgoing connections ignore
# the request_peers message mitigates the attack.
if is_outbound:
return
peers = await self.address_manager.get_peers()
await self.add_peers_neighbour(peers, peer_info)
outbound_message = OutboundMessage(
NodeType.FULL_NODE,
Message(
"respond_peers_full_node",
full_node_protocol.RespondPeers(peers),
),
Delivery.RESPOND,
)
yield outbound_message
outbound_message2 = OutboundMessage(
NodeType.WALLET,
Message(
"respond_peers_full_node",
full_node_protocol.RespondPeers(peers),
),
Delivery.RESPOND,
)
yield outbound_message2
except Exception as e:
self.log.error(f"Request peers exception: {e}")
async def respond_peers(self, request, peer_src, is_full_node):
await self._respond_peers_common(request, peer_src, is_full_node)
if is_full_node:
await self.add_peers_neighbour(request.peer_list, peer_src)
if len(request.peer_list) == 1 and self.relay_queue is not None:
peer = request.peer_list[0]
if peer.timestamp > time.time() - 60 * 10:
self.relay_queue.put_nowait((peer, 2))
async def _address_relay(self):
while not self.is_closed:
try:
relay_peer, num_peers = await self.relay_queue.get()
# https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay
connections = self.global_connections.get_full_node_connections()
hashes = []
cur_day = int(time.time()) // (24 * 60 * 60)
for connection in connections:
peer_info = connection.get_peer_info()
cur_hash = int.from_bytes(
bytes(
std_hash(
self.key.to_bytes(32, byteorder="big")
+ peer_info.get_key()
+ cur_day.to_bytes(3, byteorder="big")
)
),
byteorder="big",
)
hashes.append((cur_hash, connection))
hashes.sort(key=lambda x: x[0])
for index, (_, connection) in enumerate(hashes):
if index >= num_peers:
break
peer_info = connection.get_peer_info()
pair = (peer_info.host, peer_info.port)
async with self.lock:
if (
pair in self.neighbour_known_peers
and relay_peer.host in self.neighbour_known_peers[pair]
):
continue
if pair not in self.neighbour_known_peers:
self.neighbour_known_peers[pair] = set()
self.neighbour_known_peers[pair].add(relay_peer.host)
if connection.node_id is None:
continue
msg = OutboundMessage(
NodeType.FULL_NODE,
Message(
"respond_peers_full_node",
full_node_protocol.RespondPeers([relay_peer]),
),
Delivery.SPECIFIC,
connection.node_id,
)
self.server.push_message(msg)
except Exception as e:
self.log.error(f"Exception in address relay: {e}")
self.log.error(f"Traceback: {traceback.format_exc()}")
class WalletPeers(FullNodeDiscovery):
def __init__(
self,
server,
root_path,
global_connections,
target_outbound_count,
peer_db_path,
introducer_info,
peer_connect_interval,
log,
):
super().__init__(
server,
root_path,
global_connections,
target_outbound_count,
peer_db_path,
introducer_info,
peer_connect_interval,
log,
)
async def start(self):
await self.initialize_address_manager()
self.global_connections.set_wallet_callback(self.add_message)
await self.start_tasks()
async def ensure_is_closed(self):
if self.is_closed:
return
await self._close_common()
async def respond_peers(self, request, peer_src, is_full_node):
await self._respond_peers_common(request, peer_src, is_full_node)

View File

@ -4,7 +4,6 @@ import logging
import random
import ssl
from typing import Any, AsyncGenerator, List, Optional, Tuple
from aiter import aiter_forker, iter_to_aiter, join_aiters, map_aiter, push_aiter
from src.protocols.shared_protocol import (
@ -149,7 +148,7 @@ async def initialize_pipeline(
async def stream_reader_writer_to_connection(
swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc],
swrt: Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc, bool, bool],
server_port: int,
local_type: NodeType,
log: logging.Logger,
@ -158,8 +157,10 @@ async def stream_reader_writer_to_connection(
Maps a tuple of (StreamReader, StreamWriter, on_connect) to a ChiaConnection object,
which also stores the type of connection (str). It is also added to the global list.
"""
sr, sw, on_connect = swrt
con = ChiaConnection(local_type, None, sr, sw, server_port, on_connect, log)
sr, sw, on_connect, is_outbound, is_feeler = swrt
con = ChiaConnection(
local_type, None, sr, sw, server_port, on_connect, log, is_outbound, is_feeler
)
con.log.info(f"Connection with {con.get_peername()} established")
return con
@ -206,23 +207,22 @@ async def perform_handshake(
):
raise ProtocolError(Err.INVALID_HANDSHAKE)
if inbound_handshake.node_id == outbound_handshake.data.node_id:
raise ProtocolError(Err.SELF_CONNECTION)
# Makes sure that we only start one connection with each peer
connection.node_id = inbound_handshake.node_id
connection.peer_server_port = int(inbound_handshake.server_port)
connection.connection_type = inbound_handshake.node_type
if inbound_handshake.node_id == outbound_handshake.data.node_id:
raise ProtocolError(Err.SELF_CONNECTION)
if srwt_aiter.is_stopped():
raise Exception("No longer accepting handshakes, closing.")
if not global_connections.add(connection):
raise ProtocolError(Err.DUPLICATE_CONNECTION, [False])
# Tries adding the connection, and raises an exception on failure.
global_connections.add(connection)
# Send Ack message
await connection.send(Message("handshake_ack", HandshakeAck()))
# Read Ack message
full_message = await connection.read_one_message()
if full_message.function != "handshake_ack":
@ -241,10 +241,13 @@ async def perform_handshake(
f" established"
)
)
# Only yield a connection if the handshake is successful and the connection is not a duplicate.
yield connection, global_connections
async for conn in global_connections.successful_handshake(connection):
yield conn, global_connections
except Exception as e:
connection.log.warning(f"{e}, handshake not completed. Connection not created.")
global_connections.failed_handshake(connection, e)
# Make sure to close the connection even if it's not in global connections
connection.close()
# Remove the connection from global connections
@ -323,15 +326,19 @@ async def handle_message(
Message("pong", Pong(ping_msg.nonce)),
Delivery.RESPOND,
)
global_connections.update_connection_time(connection)
yield connection, outbound_message, global_connections
return
elif full_message.function == "pong":
global_connections.update_connection_time(connection)
return
f_with_peer_name = getattr(api, full_message.function + "_with_peer_name", None)
f_with_peer_info = getattr(api, full_message.function + "_with_peer_info", None)
if f_with_peer_name is not None:
result = f_with_peer_name(full_message.data, connection.get_peername())
elif f_with_peer_info is not None:
result = f_with_peer_info(full_message.data, connection.get_peer_info())
else:
f = getattr(api, full_message.function, None)
@ -347,6 +354,8 @@ async def handle_message(
yield connection, outbound_message, global_connections
else:
await result
global_connections.update_connection_time(connection)
except Exception:
tb = traceback.format_exc()
connection.log.error(f"Error, closing connection {connection}. {tb}")

View File

@ -37,11 +37,12 @@ async def start_server(
def add_connection_type(
srw: Tuple[asyncio.StreamReader, asyncio.StreamWriter]
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc]:
) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter, OnConnectFunc, bool, bool]:
ssl_object = srw[1].get_extra_info(name="ssl_object")
peer_cert = ssl_object.getpeercert()
self.log.info(f"Client authed as {peer_cert}")
return (srw[0], srw[1], on_connect)
# Inbound peer, not a feeler.
return (srw[0], srw[1], on_connect, False, False)
srwt_aiter = map_aiter(add_connection_type, aiter)
@ -66,7 +67,7 @@ class ChiaServer:
name: str = None,
):
# Keeps track of all connections to and from this node.
self.global_connections: PeerConnections = PeerConnections([])
self.global_connections: PeerConnections = PeerConnections(local_type, [])
self._port = port # TCP port to identify our node
self._local_type = local_type # NodeType (farmer, full node, timelord, pool, harvester, wallet)
@ -116,25 +117,32 @@ class ChiaServer:
self.root_path = root_path
self.config = config
self._pending_connections: List = []
async def start_client(
self,
target_node: PeerInfo,
on_connect: OnConnectFunc = None,
auth: bool = False,
is_feeler: bool = False,
) -> bool:
"""
Tries to connect to the target node, adding one connection into the pipeline, if successful.
An on connect method can also be specified, and this will be saved into the instance variables.
"""
self.log.info(f"Trying to connect with {target_node}.")
if self._pipeline_task.done():
self.log.warning("Starting client after server closed")
self.log.error("Starting client after server closed")
return False
ssl_context = ssl_context_for_client(self.root_path, self.config, auth=auth)
try:
# Sometimes open_connection takes a long time, so we add it as a task, and cancel
# the task in the event of closing the node.
peer_info = (target_node.host, target_node.port)
if peer_info in self._pending_connections:
return False
self._pending_connections.append(peer_info)
oc_task: asyncio.Task = asyncio.create_task(
asyncio.open_connection(
target_node.host, int(target_node.port), ssl=ssl_context
@ -142,15 +150,22 @@ class ChiaServer:
)
self._oc_tasks.append(oc_task)
reader, writer = await oc_task
self._pending_connections.remove(peer_info)
self._oc_tasks.remove(oc_task)
except Exception as e:
self.log.warning(
f"Could not connect to {target_node}. {type(e)}{str(e)}. Aborting and removing peer."
)
self.global_connections.peers.remove(target_node)
self.global_connections.failed_connection(target_node)
if self.global_connections.introducer_peers is not None:
self.global_connections.introducer_peers.remove(target_node)
if peer_info in self._pending_connections:
self._pending_connections.remove(peer_info)
return False
if not self._srwt_aiter.is_stopped():
self._srwt_aiter.push(iter_to_aiter([(reader, writer, on_connect)]))
self._srwt_aiter.push(
iter_to_aiter([(reader, writer, on_connect, True, is_feeler)])
)
ssl_object = writer.get_extra_info(name="ssl_object")
peer_cert = ssl_object.getpeercert()

View File

@ -9,7 +9,6 @@ from src.util.config import load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.server.upnp import upnp_remap_port
from src.types.peer_info import PeerInfo
# See: https://bugs.python.org/issue29288
u"".encode("idna")
@ -21,9 +20,6 @@ def service_kwargs_for_full_node(root_path):
api = FullNode(config, root_path=root_path, consensus_constants=constants)
introducer = config["introducer_peer"]
peer_info = PeerInfo(introducer["host"], introducer["port"])
async def start_callback():
if config["enable_upnp"]:
upnp_remap_port(config["port"])
@ -46,11 +42,6 @@ def service_kwargs_for_full_node(root_path):
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
periodic_introducer_poll=(
peer_info,
config["introducer_connect_interval"],
config["target_peer_count"],
),
)
if config["start_rpc_server"]:
kwargs["rpc_info"] = (FullNodeRpcApi, config["rpc_port"])

View File

@ -4,15 +4,14 @@ import logging.config
import signal
from sys import platform
from typing import Any, AsyncGenerator, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple
try:
import uvloop
except ImportError:
uvloop = None
from src.protocols import introducer_protocol
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.outbound_message import NodeType
from src.server.server import ChiaServer, start_server
from src.types.peer_info import PeerInfo
from src.util.logging import initialize_logging
@ -23,8 +22,6 @@ from src.server.connection import OnConnectFunc
from .reconnect_task import start_reconnect_task
OutboundMessageGenerator = AsyncGenerator[OutboundMessage, None]
stopped_by_signal = False
@ -34,43 +31,6 @@ def global_signal_handler(*args):
stopped_by_signal = True
def create_periodic_introducer_poll_task(
server,
peer_info,
global_connections,
introducer_connect_interval,
target_peer_count,
):
"""
Start a background task connecting periodically to the introducer and
requesting the peer list.
"""
def _num_needed_peers() -> int:
diff = target_peer_count - len(global_connections.get_full_node_connections())
return diff if diff >= 0 else 0
async def introducer_client():
async def on_connect() -> OutboundMessageGenerator:
msg = Message("request_peers", introducer_protocol.RequestPeers())
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
while True:
# If we are still connected to introducer, disconnect
for connection in global_connections.get_connections():
if connection.connection_type == NodeType.INTRODUCER:
global_connections.close(connection)
# The first time connecting to introducer, keep trying to connect
if _num_needed_peers():
if not await server.start_client(peer_info, on_connect):
await asyncio.sleep(5)
continue
await asyncio.sleep(introducer_connect_interval)
return asyncio.create_task(introducer_client())
class Service:
def __init__(
self,
@ -87,7 +47,6 @@ class Service:
start_callback: Optional[Callable] = None,
stop_callback: Optional[Callable] = None,
await_closed_callback: Optional[Callable] = None,
periodic_introducer_poll: Optional[Tuple[PeerInfo, int, int]] = None,
parse_cli_args=True,
):
net_config = load_config(root_path, "config.yaml")
@ -136,7 +95,6 @@ class Service:
self._is_stopping = False
self._stopped_by_rpc = False
self._periodic_introducer_poll = periodic_introducer_poll
self._on_connect_callback = on_connect_callback
self._start_callback = start_callback
self._stop_callback = stop_callback
@ -152,21 +110,6 @@ class Service:
if self._start_callback:
await self._start_callback()
self._introducer_poll_task = None
if self._periodic_introducer_poll:
(
peer_info,
introducer_connect_interval,
target_peer_count,
) = self._periodic_introducer_poll
self._introducer_poll_task = create_periodic_introducer_poll_task(
self._server,
peer_info,
self._server.global_connections,
introducer_connect_interval,
target_peer_count,
)
self._rpc_task = None
self._rpc_close_task = None
if self._rpc_info:
@ -223,9 +166,6 @@ class Service:
self._log.info("Closing connections")
self._server.close_all()
self._api._shut_down = True
self._log.info("Stopping introducer task")
if self._introducer_poll_task:
self._introducer_poll_task.cancel()
self._log.info("Calling service stop callback")
if self._stop_callback:

View File

@ -27,11 +27,12 @@ def service_kwargs_for_wallet(root_path):
api = WalletNode(config, keychain, root_path, consensus_constants=wallet_constants)
introducer = config["introducer_peer"]
peer_info = PeerInfo(introducer["host"], introducer["port"])
connect_peers = [
PeerInfo(config["full_node_peer"]["host"], config["full_node_peer"]["port"])
]
if "full_node_peer" in config:
connect_peers = [
PeerInfo(config["full_node_peer"]["host"], config["full_node_peer"]["port"])
]
else:
connect_peers = []
async def start_callback():
await api._start()
@ -56,11 +57,6 @@ def service_kwargs_for_wallet(root_path):
rpc_info=(WalletRpcApi, config["rpc_port"]),
connect_peers=connect_peers,
auth_connect_peers=False,
periodic_introducer_poll=(
peer_info,
config["introducer_connect_interval"],
config["target_peer_count"],
),
)
return kwargs

View File

@ -3,14 +3,47 @@ from dataclasses import dataclass
from src.util.ints import uint16, uint64
from src.util.streamable import Streamable, streamable
import ipaddress
@dataclass(frozen=True)
@streamable
class PeerInfo(Streamable):
# TODO: Change `host` type to bytes16
host: str
port: uint16
# Functions related to peer bucketing in new/tried tables.
def get_key(self):
try:
ip = ipaddress.IPv6Address(self.host)
except ValueError:
ip_v4 = ipaddress.IPv4Address(self.host)
ip = ipaddress.IPv6Address(
int(ipaddress.IPv6Address("2002::")) | (int(ip_v4) << 80)
)
key = ip.packed
key += bytes(
[
self.port // 0x100,
self.port & 0x0FF,
]
)
return key
def get_group(self):
# TODO: Port everything from Bitcoin.
ipv4 = 1
try:
ip = ipaddress.IPv4Address(self.host)
except ValueError:
ip = ipaddress.IPv6Address(self.host)
ipv4 = 0
if ipv4:
group = bytes([1]) + ip.packed[:2]
else:
group = bytes([0]) + ip.packed[:4]
return group
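As a rough usage sketch of the bucketing helpers above (illustrative only; the import paths match those used elsewhere in this diff):
from src.types.peer_info import PeerInfo
from src.util.ints import uint16
peer = PeerInfo("203.0.113.5", uint16(8444))
key = peer.get_key()      # 16-byte 6to4-mapped IPv6 address followed by the 2-byte port
group = peer.get_group()  # b"\x01" + the first two octets (the /16 network) for IPv4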
@dataclass(frozen=True)
@streamable

View File

@ -16,6 +16,7 @@ class Err(Enum):
BLOCK_NOT_IN_BLOCKCHAIN = -10
NO_PROOF_OF_SPACE_FOUND = -11
PEERS_DONT_HAVE_BLOCK = -12
MAX_INBOUND_CONNECTIONS_REACHED = -13
UNKNOWN = -9999

View File

@ -114,6 +114,7 @@ full_node:
# Run multiple nodes with different databases by changing the database_path
database_path: db/blockchain_v20.db
peer_db_path: db/peer_table_node.db
simulator_database_path: sim_db/simulator_blockchain_v20.db
# If True, starts an RPC server at the following port
@ -131,9 +132,11 @@ full_node:
sync_blocks_behind_threshold: 20
# How often to connect to introducer if we need to learn more peers
introducer_connect_interval: 500
# Continue trying to connect to more peers until this number of connections
target_peer_count: 15
peer_connect_interval: 50
# Accept peers until this number of connections
target_peer_count: 60
# Initiate outbound connections until this number is hit.
target_outbound_peer_count: 10
# Only connect to peers who we have heard about in the last recent_peer_threshold seconds
recent_peer_threshold: 6000
@ -200,11 +203,12 @@ wallet:
testing: False
database_path: wallet/db/blockchain_wallet_v20.db
wallet_peers_path: wallet/db/wallet_peers.db
logging: *logging
target_peer_count: 5
introducer_connect_interval: 60
peer_connect_interval: 60
# The introducer will only return peers who it has seen in the last
# recent_peer_threshold seconds
recent_peer_threshold: 6000

View File

@ -10,7 +10,6 @@ import logging
import traceback
from blspy import PrivateKey
from src.full_node.full_node import OutboundMessageGenerator
from src.types.peer_info import PeerInfo
from src.util.byte_types import hexstr_to_bytes
from src.util.merkle_set import (
@ -23,7 +22,8 @@ from src.consensus.constants import ConsensusConstants
from src.server.connection import PeerConnections
from src.server.server import ChiaServer
from src.server.outbound_message import OutboundMessage, NodeType, Message, Delivery
from src.util.ints import uint16, uint32, uint64
from src.server.node_discovery import WalletPeers
from src.util.ints import uint32, uint64
from src.types.sized_bytes import bytes32
from src.util.api_decorators import api_request
from src.wallet.derivation_record import DerivationRecord
@ -49,7 +49,7 @@ class WalletNode:
constants: ConsensusConstants
server: Optional[ChiaServer]
log: logging.Logger
wallet_peers: WalletPeers
# Maintains the state of the wallet (blockchain and transactions), handles DB connections
wallet_state_manager: Optional[WalletStateManager]
@ -181,6 +181,18 @@ class WalletNode:
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
self._shut_down = False
self.wallet_peers = WalletPeers(
self.server,
self.root_path,
self.global_connections,
self.config["target_peer_count"],
self.config["wallet_peers_path"],
self.config["introducer_peer"],
self.config["peer_connect_interval"],
self.log,
)
await self.wallet_peers.start()
asyncio.create_task(self._periodically_check_full_node())
return True
def _close(self):
@ -191,6 +203,9 @@ class WalletNode:
self.wallet_state_manager.close_all_stores()
)
self.global_connections.close_all_connections()
self.wallet_peers_task = asyncio.create_task(
self.wallet_peers.ensure_is_closed()
)
async def _await_closed(self):
if self.sync_generator_task is not None:
@ -310,16 +325,16 @@ class WalletNode:
for msg in messages:
yield msg
def _num_needed_peers(self) -> int:
if self.wallet_state_manager is None or self.backup_initialized is False:
return 0
assert self.server is not None
diff = self.config["target_peer_count"] - len(
self.global_connections.get_full_node_connections()
)
if diff < 0:
return 0
async def _periodically_check_full_node(self):
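# Check every 3 minutes, up to 5 times, whether a trusted full_node_peer is
# configured; if so, wallet peer gossip is unnecessary and WalletPeers is closed.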
tries = 0
while not self._shut_down and tries < 5:
if self._has_full_node():
await self.wallet_peers.ensure_is_closed()
break
tries += 1
await asyncio.sleep(180)
def _has_full_node(self) -> bool:
if "full_node_peer" in self.config:
full_node_peer = PeerInfo(
self.config["full_node_peer"]["host"],
@ -345,39 +360,30 @@ class WalletNode:
f"Closing unnecessary connection to {connection.get_peer_info()}."
)
self.global_connections.close(connection)
return 0
return diff
return True
return False
@api_request
async def respond_peers(
self, request: introducer_protocol.RespondPeers
) -> OutboundMessageGenerator:
"""
We have received a list of full node peers that we can connect to.
"""
if self.server is None or self.wallet_state_manager is None:
return
conns = self.global_connections
for peer in request.peer_list:
conns.peers.add(PeerInfo(peer.host, uint16(peer.port)))
async def respond_peers_with_peer_info(
self,
request: introducer_protocol.RespondPeers,
peer_info: PeerInfo,
):
if not self._has_full_node():
await self.wallet_peers.respond_peers(request, peer_info, False)
else:
await self.wallet_peers.ensure_is_closed()
# Pseudo-message to close the connection
yield OutboundMessage(NodeType.INTRODUCER, Message("", None), Delivery.CLOSE)
unconnected = conns.get_unconnected_peers(
recent_threshold=self.config["recent_peer_threshold"]
)
to_connect = unconnected[: self._num_needed_peers()]
if not len(to_connect):
return
self.log.info(f"Trying to connect to peers: {to_connect}")
tasks = []
for target in to_connect:
tasks.append(
asyncio.create_task(self.server.start_client(target, self._on_connect))
)
await asyncio.gather(*tasks)
@api_request
async def respond_peers_full_node_with_peer_info(
self,
request: full_node_protocol.RespondPeers,
peer_info: PeerInfo,
):
if not self._has_full_node():
await self.wallet_peers.respond_peers(request, peer_info, True)
else:
await self.wallet_peers.ensure_is_closed()
@api_request
async def respond_peers_full_node(self, request: full_node_protocol.RespondPeers):

View File

@ -0,0 +1,581 @@
import asyncio
import pytest
import time
import math
import aiosqlite
from src.types.peer_info import PeerInfo, TimestampedPeerInfo
from src.server.address_manager import ExtendedPeerInfo, AddressManager
from src.server.address_manager_store import AddressManagerStore
from pathlib import Path
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
class AddressManagerTest(AddressManager):
def __init__(self, make_deterministic=True):
super().__init__()
if make_deterministic:
self.make_deterministic()
def make_deterministic(self):
# Fix seed.
self.key = 2 ** 256 - 1
async def simulate_connection_fail(self, peer):
await self.mark_good(peer.peer_info, True, 1)
await self.attempt(peer.peer_info, False, time.time() - 61)
async def add_peer_info(self, peers, peer_src=None):
timestamped_peers = [
TimestampedPeerInfo(
peer.host,
peer.port,
0,
)
for peer in peers
]
added = await self.add_to_new_table(timestamped_peers, peer_src)
return added
class TestPeerManager:
@pytest.mark.asyncio
async def test_addr_manager(self):
addrman = AddressManagerTest()
# Test: Does Addrman respond correctly when empty.
none_peer = await addrman.select_peer()
assert none_peer is None
assert await addrman.size() == 0
# Test: Does Add work as expected.
peer1 = PeerInfo("250.1.1.1", 8444)
assert await addrman.add_peer_info([peer1])
assert await addrman.size() == 1
peer1_ret = await addrman.select_peer()
assert peer1_ret.peer_info == peer1
# Test: Does IP address deduplication work correctly.
peer1_duplicate = PeerInfo("250.1.1.1", 8444)
assert not await addrman.add_peer_info([peer1_duplicate])
assert await addrman.size() == 1
# Test: New table has one addr; after adding a different addr, we should
# have at least one addr.
# Note that addrman's size cannot be tested reliably after insertion, as
# hash collisions may occur. But we can always be sure of at least one
# success.
peer2 = PeerInfo("250.1.1.2", 8444)
assert await addrman.add_peer_info([peer2])
assert await addrman.size() >= 1
# Test: AddrMan::Add multiple addresses works as expected
addrman2 = AddressManagerTest()
peers = [peer1, peer2]
assert await addrman2.add_peer_info(peers)
assert await addrman2.size() >= 1
@pytest.mark.asyncio
async def test_addr_manager_ports(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
source = PeerInfo("252.2.2.2", 8444)
# Test: Addr with same IP but diff port does not replace existing addr.
peer1 = PeerInfo("250.1.1.1", 8444)
assert await addrman.add_peer_info([peer1], source)
assert await addrman.size() == 1
peer2 = PeerInfo("250.1.1.1", 8445)
assert not await addrman.add_peer_info([peer2], source)
assert await addrman.size() == 1
peer3 = await addrman.select_peer()
assert peer3.peer_info == peer1
# Test: Add same IP but diff port to tried table, it doesn't get added.
# Perhaps this is not ideal behavior but it is the current behavior.
await addrman.mark_good(peer2)
assert await addrman.size() == 1
peer3_ret = await addrman.select_peer(True)
assert peer3_ret.peer_info == peer1
# This is a flaky test, since it uses randomness.
# TODO: Make sure it always succeeds.
@pytest.mark.asyncio
async def test_addrman_select(self):
addrman = AddressManagerTest()
source = PeerInfo("252.2.2.2", 8444)
# Test: Select from new with 1 addr in new.
peer1 = PeerInfo("250.1.1.1", 8444)
assert await addrman.add_peer_info([peer1], source)
assert await addrman.size() == 1
peer1_ret = await addrman.select_peer(True)
assert peer1_ret.peer_info == peer1
# Test: move addr to tried; select from new, expecting nothing returned.
await addrman.mark_good(peer1)
assert await addrman.size() == 1
peer2_ret = await addrman.select_peer(True)
assert peer2_ret is None
peer3_ret = await addrman.select_peer()
assert peer3_ret.peer_info == peer1
# Add three addresses to new table.
peer2 = PeerInfo("250.3.1.1", 8444)
peer3 = PeerInfo("250.3.2.2", 9999)
peer4 = PeerInfo("250.3.3.3", 9999)
assert await addrman.add_peer_info([peer2], PeerInfo("250.3.1.1", 8444))
assert await addrman.add_peer_info([peer3], PeerInfo("250.3.1.1", 8444))
assert await addrman.add_peer_info([peer4], PeerInfo("250.4.1.1", 8444))
# Add three addresses to tried table.
peer5 = PeerInfo("250.4.4.4", 8444)
peer6 = PeerInfo("250.4.5.5", 7777)
peer7 = PeerInfo("250.4.6.6", 8444)
assert await addrman.add_peer_info([peer5], PeerInfo("250.3.1.1", 8444))
await addrman.mark_good(peer5)
assert await addrman.add_peer_info([peer6], PeerInfo("250.3.1.1", 8444))
await addrman.mark_good(peer6)
assert await addrman.add_peer_info([peer7], PeerInfo("250.1.1.3", 8444))
await addrman.mark_good(peer7)
# Test: 6 addrs + 1 addr from last test = 7.
assert await addrman.size() == 7
# Test: Select pulls from new and tried regardless of port number.
ports = []
for _ in range(200):
peer = await addrman.select_peer()
if peer.peer_info.port not in ports:
ports.append(peer.peer_info.port)
if len(ports) == 3:
break
assert len(ports) == 3
@pytest.mark.asyncio
async def test_addrman_collisions_new(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
source = PeerInfo("252.2.2.2", 8444)
for i in range(1, 8):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
assert await addrman.size() == i
# Test: new table collision!
peer1 = PeerInfo("250.1.1.8", 8444)
assert await addrman.add_peer_info([peer1], source)
assert await addrman.size() == 7
peer2 = PeerInfo("250.1.1.9", 8444)
assert await addrman.add_peer_info([peer2], source)
assert await addrman.size() == 8
@pytest.mark.asyncio
async def test_addrman_collisions_tried(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
source = PeerInfo("252.2.2.2", 8444)
for i in range(1, 77):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
await addrman.mark_good(peer)
# Test: No collision in tried table yet.
assert await addrman.size() == i
# Test: tried table collision!
peer1 = PeerInfo("250.1.1.77", 8444)
assert await addrman.add_peer_info([peer1], source)
assert await addrman.size() == 76
peer2 = PeerInfo("250.1.1.78", 8444)
assert await addrman.add_peer_info([peer2], source)
assert await addrman.size() == 77
@pytest.mark.asyncio
async def test_addrman_find(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
peer1 = PeerInfo("250.1.2.1", 8333)
peer2 = PeerInfo("250.1.2.1", 9999)
peer3 = PeerInfo("251.255.2.1", 8333)
source1 = PeerInfo("250.1.2.1", 8444)
source2 = PeerInfo("250.1.2.2", 8444)
assert await addrman.add_peer_info([peer1], source1)
assert not await addrman.add_peer_info([peer2], source2)
assert await addrman.add_peer_info([peer3], source1)
# Test: ensure Find returns an IP matching what we searched on.
info1 = addrman.find_(peer1)
assert info1[0] is not None and info1[1] is not None
assert info1[0].peer_info == peer1
# Test: Find does not discriminate by port number.
info2 = addrman.find_(peer2)
assert info2[0] is not None and info2[1] is not None
assert info2 == info1
# Test: Find returns another IP matching what we searched on.
info3 = addrman.find_(peer3)
assert info3[0] is not None and info3[1] is not None
assert info3[0].peer_info == peer3
@pytest.mark.asyncio
async def test_addrman_create(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
peer1 = PeerInfo("250.1.2.1", 8444)
t_peer = TimestampedPeerInfo("250.1.2.1", 8444, 0)
info, node_id = addrman.create_(t_peer, peer1)
assert info.peer_info == peer1
info, _ = addrman.find_(peer1)
assert info.peer_info == peer1
@pytest.mark.asyncio
async def test_addrman_delete(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
peer1 = PeerInfo("250.1.2.1", 8444)
t_peer = TimestampedPeerInfo("250.1.2.1", 8444, 0)
info, node_id = addrman.create_(t_peer, peer1)
# Test: Delete should actually delete the addr.
assert await addrman.size() == 1
addrman.delete_new_entry_(node_id)
assert await addrman.size() == 0
info2, _ = addrman.find_(peer1)
assert info2 is None
@pytest.mark.asyncio
async def test_addrman_get_peers(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
peers1 = await addrman.get_peers()
assert len(peers1) == 0
peer1 = TimestampedPeerInfo("250.250.2.1", 8444, time.time())
peer2 = TimestampedPeerInfo("250.250.2.2", 9999, time.time())
peer3 = TimestampedPeerInfo("251.252.2.3", 8444, time.time())
peer4 = TimestampedPeerInfo("251.252.2.4", 8444, time.time())
peer5 = TimestampedPeerInfo("251.252.2.5", 8444, time.time())
source1 = PeerInfo("250.1.2.1", 8444)
source2 = PeerInfo("250.2.3.3", 8444)
# Test: Ensure GetPeers works with new addresses.
assert await addrman.add_to_new_table([peer1], source1)
assert await addrman.add_to_new_table([peer2], source2)
assert await addrman.add_to_new_table([peer3], source1)
assert await addrman.add_to_new_table([peer4], source1)
assert await addrman.add_to_new_table([peer5], source1)
# GetPeers returns 23% of addresses; 23% of 5 is 1.15, which rounds up to 2.
peers2 = await addrman.get_peers()
assert len(peers2) == 2
# Test: Ensure GetPeers works with new and tried addresses.
await addrman.mark_good(peer1)
await addrman.mark_good(peer2)
peers3 = await addrman.get_peers()
assert len(peers3) == 2
# Test: Ensure GetPeers still returns 23% when addrman has many addrs.
for i in range(1, 8 * 256):
octet1 = i % 256
octet2 = (i >> 8) % 256
peer = TimestampedPeerInfo(
str(octet1) + "." + str(octet2) + ".1.23", 8444, time.time()
)
await addrman.add_to_new_table([peer])
if i % 8 == 0:
await addrman.mark_good(peer)
peers4 = await addrman.get_peers()
percent = await addrman.size()
percent = math.ceil(percent * 23 / 100)
assert len(peers4) == percent
@pytest.mark.asyncio
async def test_addrman_tried_bucket(self):
peer1 = PeerInfo("250.1.1.1", 8444)
t_peer1 = TimestampedPeerInfo("250.1.1.1", 8444, 0)
peer2 = PeerInfo("250.1.1.1", 9999)
t_peer2 = TimestampedPeerInfo("250.1.1.1", 9999, 0)
source1 = PeerInfo("250.1.1.1", 8444)
peer_info1 = ExtendedPeerInfo(t_peer1, source1)
# Test: Make sure key actually randomizes bucket placement. A fail on
# this test could be a security issue.
key1 = 2 ** 256 - 1
key2 = 2 ** 128 - 1
bucket1 = peer_info1.get_tried_bucket(key1)
bucket2 = peer_info1.get_tried_bucket(key2)
assert bucket1 != bucket2
# Test: Two addresses with same IP but different ports can map to
# different buckets because they have different keys.
peer_info2 = ExtendedPeerInfo(t_peer2, source1)
assert peer1.get_key() != peer2.get_key()
assert peer_info1.get_tried_bucket(key1) != peer_info2.get_tried_bucket(key1)
# Test: IP addresses in the same group (/16 prefix for IPv4) should
# never get more than 8 buckets
buckets = []
for i in range(255):
peer = PeerInfo("250.1.1." + str(i), 8444)
t_peer = TimestampedPeerInfo("250.1.1." + str(i), 8444, 0)
extended_peer_info = ExtendedPeerInfo(t_peer, peer)
bucket = extended_peer_info.get_tried_bucket(key1)
if bucket not in buckets:
buckets.append(bucket)
assert len(buckets) == 8
# Test: IP addresses in different groups should map to more than
# 8 buckets.
buckets = []
for i in range(255):
peer = PeerInfo("250." + str(i) + ".1.1", 8444)
t_peer = TimestampedPeerInfo("250." + str(i) + ".1.1", 8444, 0)
extended_peer_info = ExtendedPeerInfo(t_peer, peer)
bucket = extended_peer_info.get_tried_bucket(key1)
if bucket not in buckets:
buckets.append(bucket)
assert len(buckets) > 8
@pytest.mark.asyncio
async def test_addrman_new_bucket(self):
t_peer1 = TimestampedPeerInfo("250.1.2.1", 8444, 0)
source1 = PeerInfo("250.1.2.1", 8444)
t_peer2 = TimestampedPeerInfo("250.1.2.1", 9999, 0)
peer_info1 = ExtendedPeerInfo(t_peer1, source1)
# Test: Make sure key actually randomizes bucket placement. A fail on
# this test could be a security issue.
key1 = 2 ** 256 - 1
key2 = 2 ** 128 - 1
bucket1 = peer_info1.get_new_bucket(key1)
bucket2 = peer_info1.get_new_bucket(key2)
assert bucket1 != bucket2
# Test: Ports should not affect bucket placement in the new table.
peer_info2 = ExtendedPeerInfo(t_peer2, source1)
assert peer_info1.get_new_bucket(key1) == peer_info2.get_new_bucket(key1)
# Test: IP addresses in the same group (/16 prefix for IPv4) should
# always map to the same bucket.
buckets = []
for i in range(255):
peer = PeerInfo("250.1.1." + str(i), 8444)
t_peer = TimestampedPeerInfo("250.1.1." + str(i), 8444, 0)
extended_peer_info = ExtendedPeerInfo(t_peer, peer)
bucket = extended_peer_info.get_new_bucket(key1)
if bucket not in buckets:
buckets.append(bucket)
assert len(buckets) == 1
# Test: IP addresses from the same source group should map to no more
# than 64 buckets.
buckets = []
for i in range(4 * 255):
src = PeerInfo("251.4.1.1", 8444)
peer = PeerInfo(str(250 + i // 255) + "." + str(i % 256) + ".1.1", 8444)
t_peer = TimestampedPeerInfo(
str(250 + i // 255) + "." + str(i % 256) + ".1.1", 8444, 0
)
extended_peer_info = ExtendedPeerInfo(t_peer, src)
bucket = extended_peer_info.get_new_bucket(key1)
if bucket not in buckets:
buckets.append(bucket)
assert len(buckets) <= 64
# Test: IP addresses in different source groups should map to more
# than 64 buckets.
buckets = []
for i in range(255):
src = PeerInfo("250." + str(i) + ".1.1", 8444)
peer = PeerInfo("250.1.1.1", 8444)
t_peer = TimestampedPeerInfo("250.1.1.1", 8444, 0)
extended_peer_info = ExtendedPeerInfo(t_peer, src)
bucket = extended_peer_info.get_new_bucket(key1)
if bucket not in buckets:
buckets.append(bucket)
assert len(buckets) > 64
@pytest.mark.asyncio
async def test_addrman_select_collision_no_collision(self):
addrman = AddressManagerTest()
collision = await addrman.select_tried_collision()
assert collision is None
# Add 17 addresses.
source = PeerInfo("252.2.2.2", 8444)
for i in range(1, 18):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
await addrman.mark_good(peer)
# No collisions yet.
assert await addrman.size() == i
collision = await addrman.select_tried_collision()
assert collision is None
# Ensure mark_good handles duplicates well.
for i in range(1, 18):
peer = PeerInfo("250.1.1." + str(i), 8444)
await addrman.mark_good(peer)
assert await addrman.size() == 17
collision = await addrman.select_tried_collision()
assert collision is None
@pytest.mark.asyncio
async def test_addrman_no_evict(self):
addrman = AddressManagerTest()
# Add 17 addresses.
source = PeerInfo("252.2.2.2", 8444)
for i in range(1, 18):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
await addrman.mark_good(peer)
# No collision yet.
assert await addrman.size() == i
collision = await addrman.select_tried_collision()
assert collision is None
peer18 = PeerInfo("250.1.1.18", 8444)
assert await addrman.add_peer_info([peer18], source)
await addrman.mark_good(peer18)
assert await addrman.size() == 18
collision = await addrman.select_tried_collision()
assert collision.peer_info == PeerInfo("250.1.1.16", 8444)
await addrman.resolve_tried_collisions()
collision = await addrman.select_tried_collision()
assert collision is None
# Let's create two collisions.
for i in range(19, 37):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
await addrman.mark_good(peer)
assert await addrman.size() == i
assert await addrman.select_tried_collision() is None
# Cause a collision.
peer37 = PeerInfo("250.1.1.37", 8444)
assert await addrman.add_peer_info([peer37], source)
await addrman.mark_good(peer37)
assert await addrman.size() == 37
# Cause a second collision.
assert not await addrman.add_peer_info([peer18], source)
await addrman.mark_good(peer18)
assert await addrman.size() == 37
collision = await addrman.select_tried_collision()
assert collision is not None
await addrman.resolve_tried_collisions()
collision = await addrman.select_tried_collision()
assert collision is None
@pytest.mark.asyncio
async def test_addrman_eviction_works(self):
addrman = AddressManagerTest()
assert await addrman.size() == 0
# Empty addrman should return blank addrman info.
assert await addrman.select_tried_collision() is None
# Add 17 addresses.
source = PeerInfo("252.2.2.2", 8444)
for i in range(1, 18):
peer = PeerInfo("250.1.1." + str(i), 8444)
assert await addrman.add_peer_info([peer], source)
await addrman.mark_good(peer)
# No collision yet.
assert await addrman.size() == i
assert await addrman.select_tried_collision() is None
# Collision between 18 and 16.
peer18 = PeerInfo("250.1.1.18", 8444)
assert await addrman.add_peer_info([peer18], source)
await addrman.mark_good(peer18)
assert await addrman.size() == 18
collision = await addrman.select_tried_collision()
assert collision.peer_info == PeerInfo("250.1.1.16", 8444)
await addrman.simulate_connection_fail(collision)
# Should swap 18 for 16.
await addrman.resolve_tried_collisions()
assert await addrman.select_tried_collision() is None
# If 18 was swapped for 16, then this should cause no collisions.
assert not await addrman.add_peer_info([peer18], source)
await addrman.mark_good(peer18)
assert await addrman.select_tried_collision() is None
# If we insert 16, it should collide with 18.
addr16 = PeerInfo("250.1.1.16", 8444)
assert not await addrman.add_peer_info([addr16], source)
await addrman.mark_good(addr16)
collision = await addrman.select_tried_collision()
assert collision.peer_info == PeerInfo("250.1.1.18", 8444)
await addrman.resolve_tried_collisions()
assert await addrman.select_tried_collision() is None
@pytest.mark.asyncio
async def test_serialization(self):
addrman = AddressManagerTest()
peer1 = PeerInfo("250.7.1.1", 8333)
peer2 = PeerInfo("250.7.2.2", 9999)
peer3 = PeerInfo("250.7.3.3", 9999)
source = PeerInfo("252.5.1.1", 8333)
await addrman.add_peer_info([peer1, peer2, peer3], source)
await addrman.mark_good(peer1)
db_filename = Path("peer_table.db")
if db_filename.exists():
db_filename.unlink()
connection = await aiosqlite.connect(db_filename)
address_manager_store = await AddressManagerStore.create(connection)
await address_manager_store.serialize(addrman)
addrman2 = await address_manager_store.deserialize()
retrieved_peers = []
for _ in range(20):
peer = await addrman2.select_peer()
if peer not in retrieved_peers:
retrieved_peers.append(peer)
if len(retrieved_peers) == 3:
break
assert len(retrieved_peers) == 3
t_peer1 = TimestampedPeerInfo("250.7.1.1", 8333, 0)
t_peer2 = TimestampedPeerInfo("250.7.2.2", 9999, 0)
t_peer3 = TimestampedPeerInfo("250.7.3.3", 9999, 0)
wanted_peers = [
ExtendedPeerInfo(t_peer1, source),
ExtendedPeerInfo(t_peer2, source),
ExtendedPeerInfo(t_peer3, source),
]
recovered = 0
for target_peer in wanted_peers:
for current_peer in retrieved_peers:
if (
current_peer.peer_info == target_peer.peer_info
and current_peer.src == target_peer.src
):
recovered += 1
assert recovered == 3
await connection.close()
db_filename.unlink()

View File

@ -12,7 +12,8 @@ from src.protocols import (
wallet_protocol,
)
from src.server.outbound_message import NodeType
from src.types.peer_info import PeerInfo
from src.types.peer_info import TimestampedPeerInfo, PeerInfo
from src.server.address_manager import AddressManager
from src.types.full_block import FullBlock
from src.types.proof_of_space import ProofOfSpace
from src.types.spend_bundle import SpendBundle
@ -85,6 +86,34 @@ async def wallet_blocks_five(two_nodes):
class TestFullNodeProtocol:
@pytest.mark.asyncio
async def test_request_peers(self, two_nodes, wallet_blocks):
full_node_1, full_node_2, server_1, server_2 = two_nodes
await server_2.start_client(PeerInfo("::1", uint16(server_1._port)), None)
async def have_msgs():
await full_node_1.full_node_peers.address_manager.add_to_new_table(
[
TimestampedPeerInfo(
"127.0.0.1", uint16(1000), uint64(int(time.time())) - 1000
),
],
None,
)
msgs = [
_
async for _ in full_node_2.request_peers_with_peer_info(
fnp.RequestPeers(), PeerInfo("::1", server_2._port)
)
]
if not (len(msgs) > 0 and len(msgs[0].message.data.peer_list) == 1):
return False
for peer in msgs[0].message.data.peer_list:
return peer.host == "127.0.0.1" and peer.port == 1000
await time_out_assert(10, have_msgs, True)
full_node_1.full_node_peers.address_manager = AddressManager()
@pytest.mark.asyncio
async def test_new_tip(self, two_nodes, wallet_blocks):
full_node_1, full_node_2, server_1, server_2 = two_nodes
@ -793,32 +822,6 @@ class TestFullNodeProtocol:
]
assert len(msgs) == 0
"""
This test will be added back soon.
@pytest.mark.asyncio
async def test_request_peers(self, two_nodes, wallet_blocks):
full_node_1, full_node_2, server_1, server_2 = two_nodes
wallet_a, wallet_receiver, blocks = wallet_blocks
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
async def num_connections():
return len(full_node_1.global_connections.get_connections())
await time_out_assert(10, num_connections, 1)
async def have_msgs():
msgs = [
_
async for _ in full_node_1.request_peers(
introducer_protocol.RequestPeers()
)
]
return len(msgs) > 0 and len(msgs[0].message.data.peer_list) > 0
await time_out_assert(10, have_msgs, True)
"""
class TestWalletProtocol:
@pytest.mark.asyncio

View File

@ -73,13 +73,11 @@ async def setup_full_node(
config = load_config(bt.root_path, "config.yaml", "full_node")
config["database_path"] = db_name
config["send_uncompact_interval"] = send_uncompact_interval
periodic_introducer_poll = None
config["peer_connect_interval"] = 3
config["introducer_peer"]["host"] = "::1"
if introducer_port is not None:
periodic_introducer_poll = (
PeerInfo(self_hostname, introducer_port),
30,
config["target_peer_count"],
)
config["introducer_peer"]["port"] = introducer_port
if not simulator:
api: FullNode = FullNode(
config=config,
@ -121,7 +119,6 @@ async def setup_full_node(
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
periodic_introducer_poll=periodic_introducer_poll,
parse_cli_args=False,
)
@ -169,13 +166,11 @@ async def setup_wallet_node(
consensus_constants=consensus_constants,
name="wallet1",
)
periodic_introducer_poll = None
config["introducer_peer"]["host"] = "::1"
if introducer_port is not None:
periodic_introducer_poll = (
PeerInfo(self_hostname, introducer_port),
30,
config["target_peer_count"],
)
config["introducer_peer"]["port"] = introducer_port
config["peer_connect_interval"] = 10
connect_peers: List[PeerInfo] = []
if full_node_port is not None:
connect_peers = [PeerInfo(self_hostname, full_node_port)]
@ -206,7 +201,6 @@ async def setup_wallet_node(
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
periodic_introducer_poll=periodic_introducer_poll,
parse_cli_args=False,
)

View File

@ -7,7 +7,14 @@ from src.protocols import full_node_protocol
from src.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32
from tests.setup_nodes import setup_simulators_and_wallets
from tests.setup_nodes import (
setup_simulators_and_wallets,
test_constants,
setup_full_node,
setup_wallet_node,
setup_introducer,
_teardown_nodes,
)
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
from tests.time_out_assert import time_out_assert, time_out_assert_not_None
@ -298,6 +305,41 @@ class TestWalletSimulator:
await time_out_assert(5, wallet_0.get_unconfirmed_balance, new_funds - 5)
await time_out_assert(5, wallet_1.get_confirmed_balance, 5)
@pytest.mark.asyncio
async def test_wallet_finds_full_node(self):
node_iters = [
setup_full_node(
test_constants,
"blockchain_test.db",
11234,
introducer_port=11236,
simulator=False,
),
setup_wallet_node(
11235,
test_constants,
None,
introducer_port=11236,
),
setup_introducer(11236),
]
full_node, s1 = await node_iters[0].__anext__()
wallet, s2 = await node_iters[1].__anext__()
introducer, introducer_server = await node_iters[2].__anext__()
async def has_full_node():
return (
wallet.wallet_peers.global_connections.count_outbound_connections() > 0
)
await time_out_assert(
2 * 60,
has_full_node,
True,
)
await _teardown_nodes(node_iters)
# @pytest.mark.asyncio
# async def test_wallet_make_transaction_with_fee(self, two_wallet_nodes):
# num_blocks = 5