Refactor Seeder & Crawler code + add tests (#15781)

* Clean up seeder & mypy all files

holy crap mother of all tech debt this was horrid

* Make UDP Protocol class a dataclass & separate out functions.

* add TCP protocol class to DNS server

* Fix mypy types & other cleanup

also fix a couple of bugs

* Add seeder and crawler tests

* change log levels

* re-add db lock timeout

oops

* add edns & edns tests

* fix repeated shutdown on close signal

* fix binding to use both ipv6 and ipv4

whyyyyyyy

* add ipv6 and ipv4 tests + add an ipv4 socket on windows (see the sketch below)
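For context on the last two commits, here is a minimal, standalone sketch (not part of this diff) of the dual-stack UDP binding pattern the DNS server now uses. `EchoProtocol` is a hypothetical stand-in for `UDPDNSServerProtocol`, and the dual-stack behavior of a `::0` socket assumes the usual `IPV6_V6ONLY=0` default on Linux/macOS:

```python
from __future__ import annotations

import asyncio
import sys


class EchoProtocol(asyncio.DatagramProtocol):
    """Hypothetical stand-in for UDPDNSServerProtocol."""

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.transport = transport

    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
        self.transport.sendto(data, addr)  # type: ignore[attr-defined]


async def main() -> None:
    loop = asyncio.get_running_loop()
    # On Linux/macOS a "::0" socket is dual stack by default and also accepts IPv4-mapped traffic.
    transport_v6, _ = await loop.create_datagram_endpoint(EchoProtocol, local_addr=("::0", 0))
    port = transport_v6.get_extra_info("sockname")[1]  # port 0 means the OS picked one; reuse it below
    transport_v4 = None
    if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"):
        # Windows has no dual-stack UDP sockets, so bind a second, IPv4-only endpoint on the same port.
        transport_v4, _ = await loop.create_datagram_endpoint(EchoProtocol, local_addr=("0.0.0.0", port))
    await asyncio.sleep(0.1)  # serve briefly; a real server runs until shutdown
    transport_v6.close()
    if transport_v4 is not None:
        transport_v4.close()


asyncio.run(main())
```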
Jack Nelson, 2023-08-11 06:25:19 -04:00, committed by GitHub
commit 0d36874caa (parent c9e6c9d5cc)
13 changed files with 1082 additions and 468 deletions

View File: chia/seeder/crawl_store.py

@@ -1,37 +1,34 @@
 from __future__ import annotations

-import asyncio
-import dataclasses
 import ipaddress
 import logging
 import random
 import time
+from dataclasses import dataclass, field, replace
 from typing import Dict, List

 import aiosqlite

 from chia.seeder.peer_record import PeerRecord, PeerReliability
+from chia.util.ints import uint64

 log = logging.getLogger(__name__)


+@dataclass
 class CrawlStore:
     crawl_db: aiosqlite.Connection
-    last_timestamp: int
-    lock: asyncio.Lock
-
-    host_to_records: Dict
-    host_to_selected_time: Dict
-    host_to_reliability: Dict
-    banned_peers: int
-    ignored_peers: int
-    reliable_peers: int
+    host_to_records: Dict[str, PeerRecord] = field(default_factory=dict)  # peer_id: PeerRecord
+    host_to_selected_time: Dict[str, float] = field(default_factory=dict)  # peer_id: timestamp (as a float)
+    host_to_reliability: Dict[str, PeerReliability] = field(default_factory=dict)  # peer_id: PeerReliability
+    banned_peers: int = 0
+    ignored_peers: int = 0
+    reliable_peers: int = 0

     @classmethod
-    async def create(cls, connection: aiosqlite.Connection):
-        self = cls()
-        self.crawl_db = connection
+    async def create(cls, connection: aiosqlite.Connection) -> CrawlStore:
+        self = cls(connection)
         await self.crawl_db.execute(
             (
                 "CREATE TABLE IF NOT EXISTS peer_records("
@@ -82,21 +79,16 @@ class CrawlStore:
         await self.crawl_db.execute("CREATE INDEX IF NOT EXISTS ignore_till on peer_reliability(ignore_till)")
         await self.crawl_db.commit()

-        self.last_timestamp = 0
-        self.ignored_peers = 0
-        self.banned_peers = 0
-        self.reliable_peers = 0
-        self.host_to_selected_time = {}
         await self.unload_from_db()
         return self

-    def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability):
+    def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability) -> None:
         if peer_record.peer_id not in self.host_to_records:
             self.host_to_records[peer_record.peer_id] = peer_record
         if peer_reliability.peer_id not in self.host_to_reliability:
             self.host_to_reliability[peer_reliability.peer_id] = peer_reliability

-    async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False):
+    async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False) -> None:
         if not save_db:
             self.host_to_records[peer_record.peer_id] = peer_record
             self.host_to_reliability[peer_reliability.peer_id] = peer_reliability
@@ -152,41 +144,41 @@ class CrawlStore:
     async def get_peer_reliability(self, peer_id: str) -> PeerReliability:
         return self.host_to_reliability[peer_id]

-    async def peer_failed_to_connect(self, peer: PeerRecord):
-        now = int(time.time())
+    async def peer_failed_to_connect(self, peer: PeerRecord) -> None:
+        now = uint64(time.time())
         age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
         if age_timestamp == 0:
             age_timestamp = now - 1000
-        replaced = dataclasses.replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
+        replaced = replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now)
         reliability = await self.get_peer_reliability(peer.peer_id)
         if reliability is None:
             reliability = PeerReliability(peer.peer_id)
         reliability.update(False, now - age_timestamp)
         await self.add_peer(replaced, reliability)

-    async def peer_connected(self, peer: PeerRecord, tls_version: str):
-        now = int(time.time())
+    async def peer_connected(self, peer: PeerRecord, tls_version: str) -> None:
+        now = uint64(time.time())
         age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp))
         if age_timestamp == 0:
             age_timestamp = now - 1000
-        replaced = dataclasses.replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
+        replaced = replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version)
         reliability = await self.get_peer_reliability(peer.peer_id)
         if reliability is None:
             reliability = PeerReliability(peer.peer_id)
         reliability.update(True, now - age_timestamp)
         await self.add_peer(replaced, reliability)

-    async def update_best_timestamp(self, host: str, timestamp):
+    async def update_best_timestamp(self, host: str, timestamp: uint64) -> None:
         if host not in self.host_to_records:
             return
         record = self.host_to_records[host]
-        replaced = dataclasses.replace(record, best_timestamp=timestamp)
+        replaced = replace(record, best_timestamp=timestamp)
         if host not in self.host_to_reliability:
             return
         reliability = self.host_to_reliability[host]
         await self.add_peer(replaced, reliability)

-    async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown"):
+    async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown") -> None:
         if host not in self.host_to_records:
             return
         record = self.host_to_records[host]
@@ -195,7 +187,7 @@ class CrawlStore:
         else:
             await self.peer_failed_to_connect(record)

-    async def get_peers_to_crawl(self, min_batch_size, max_batch_size) -> List[PeerRecord]:
+    async def get_peers_to_crawl(self, min_batch_size: int, max_batch_size: int) -> List[PeerRecord]:
         now = int(time.time())
         records = []
         records_v6 = []
@@ -270,20 +262,20 @@ class CrawlStore:
     def get_reliable_peers(self) -> int:
         return self.reliable_peers

-    async def load_to_db(self):
-        log.error("Saving peers to DB...")
+    async def load_to_db(self) -> None:
+        log.info("Saving peers to DB...")
         for peer_id in list(self.host_to_reliability.keys()):
             if peer_id in self.host_to_reliability and peer_id in self.host_to_records:
                 reliability = self.host_to_reliability[peer_id]
                 record = self.host_to_records[peer_id]
                 await self.add_peer(record, reliability, True)
         await self.crawl_db.commit()
-        log.error(" - Done saving peers to DB")
+        log.info(" - Done saving peers to DB")

-    async def unload_from_db(self):
+    async def unload_from_db(self) -> None:
         self.host_to_records = {}
         self.host_to_reliability = {}
-        log.error("Loading peer reliability records...")
+        log.info("Loading peer reliability records...")
         cursor = await self.crawl_db.execute(
             "SELECT * from peer_reliability",
         )
@@ -312,9 +304,9 @@ class CrawlStore:
                 row[18],
                 row[19],
             )
-            self.host_to_reliability[row[0]] = reliability
-        log.error(" - Done loading peer reliability records...")
-        log.error("Loading peer records...")
+            self.host_to_reliability[reliability.peer_id] = reliability
+        log.info(" - Done loading peer reliability records...")
+        log.info("Loading peer records...")
         cursor = await self.crawl_db.execute(
             "SELECT * from peer_records",
         )
@@ -324,24 +316,23 @@ class CrawlStore:
             peer = PeerRecord(
                 row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10], row[11]
             )
-            self.host_to_records[row[0]] = peer
-        log.error(" - Done loading peer records...")
+            self.host_to_records[peer.peer_id] = peer
+        log.info(" - Done loading peer records...")

     # Crawler -> DNS.
-    async def load_reliable_peers_to_db(self):
+    async def load_reliable_peers_to_db(self) -> None:
         peers = []
-        for peer_id in self.host_to_reliability:
-            reliability = self.host_to_reliability[peer_id]
+        for peer_id, reliability in self.host_to_reliability.items():
             if reliability.is_reliable():
                 peers.append(peer_id)
         self.reliable_peers = len(peers)
-        log.error("Deleting old good_peers from DB...")
+        log.info("Deleting old good_peers from DB...")
         cursor = await self.crawl_db.execute(
             "DELETE from good_peers",
         )
         await cursor.close()
-        log.error(" - Done deleting old good_peers...")
-        log.error("Saving new good_peers to DB...")
+        log.info(" - Done deleting old good_peers...")
+        log.info("Saving new good_peers to DB...")
         for peer in peers:
             cursor = await self.crawl_db.execute(
                 "INSERT OR REPLACE INTO good_peers VALUES(?)",
@@ -349,9 +340,9 @@ class CrawlStore:
             )
             await cursor.close()
         await self.crawl_db.commit()
-        log.error(" - Done saving new good_peers to DB...")
+        log.info(" - Done saving new good_peers to DB...")

-    def load_host_to_version(self):
+    def load_host_to_version(self) -> tuple[dict[str, str], dict[str, uint64]]:
         versions = {}
         handshake = {}
@@ -366,19 +357,30 @@ class CrawlStore:
             versions[host] = record.version
             handshake[host] = record.handshake_time

-        return (versions, handshake)
+        return versions, handshake

-    def load_best_peer_reliability(self):
+    def load_best_peer_reliability(self) -> dict[str, uint64]:
         best_timestamp = {}
         for host, record in self.host_to_records.items():
             if record.best_timestamp > time.time() - 5 * 24 * 3600:
                 best_timestamp[host] = record.best_timestamp
         return best_timestamp

-    async def update_version(self, host, version, now):
+    async def update_version(self, host: str, version: str, timestamp_now: uint64) -> None:
         record = self.host_to_records.get(host, None)
         reliability = self.host_to_reliability.get(host, None)
         if record is None or reliability is None:
             return
-        record.update_version(version, now)
+        record.update_version(version, timestamp_now)
         await self.add_peer(record, reliability)
+
+    async def get_good_peers(self) -> list[str]:  # This is for the DNS server
+        cursor = await self.crawl_db.execute(
+            "SELECT * from good_peers",
+        )
+        rows = await cursor.fetchall()
+        await cursor.close()
+        result = [row[0] for row in rows]
+        if len(result) > 0:
+            random.shuffle(result)  # mix up the peers
+        return result
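A minimal usage sketch of the reworked store (not part of the diff), assuming a chia-blockchain dev environment and that `CrawlStore.create` also creates the `good_peers` table that the DELETE/INSERT statements above operate on:

```python
from __future__ import annotations

import asyncio

import aiosqlite

from chia.seeder.crawl_store import CrawlStore


async def main() -> None:
    # An in-memory DB keeps the sketch self-contained; the real service
    # passes a connection to its crawler.db file instead.
    store = await CrawlStore.create(await aiosqlite.connect(":memory:"))
    print(await store.get_good_peers())  # [] until the crawler populates good_peers
    await store.crawl_db.close()


asyncio.run(main())
```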

View File: chia/seeder/crawler.py

@@ -6,15 +6,16 @@ import logging
 import time
 import traceback
 from collections import defaultdict
+from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple

 import aiosqlite

 from chia.consensus.constants import ConsensusConstants
-from chia.full_node.coin_store import CoinStore
 from chia.full_node.full_node_api import FullNodeAPI
 from chia.protocols import full_node_protocol
+from chia.protocols.full_node_protocol import RespondPeers
 from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
 from chia.seeder.crawl_store import CrawlStore
 from chia.seeder.peer_record import PeerRecord, PeerReliability
@@ -29,20 +30,28 @@ from chia.util.path import path_from_root
 log = logging.getLogger(__name__)


+@dataclass
 class Crawler:
-    sync_store: Any
-    coin_store: CoinStore
-    connection: Optional[aiosqlite.Connection]
-    config: Dict
-    _server: Optional[ChiaServer]
-    crawl_store: Optional[CrawlStore]
-    log: logging.Logger
-    constants: ConsensusConstants
-    _shut_down: bool
+    config: Dict[str, Any]
     root_path: Path
-    peer_count: int
-    with_peak: set
-    minimum_version_count: int
+    constants: ConsensusConstants
+    print_status: bool = True
+    state_changed_callback: Optional[StateChangedProtocol] = None
+    _server: Optional[ChiaServer] = None
+    crawl_task: Optional[asyncio.Task[None]] = None
+    crawl_store: Optional[CrawlStore] = None
+    log: logging.Logger = log
+    _shut_down: bool = False
+    peer_count: int = 0
+    with_peak: Set[PeerInfo] = field(default_factory=set)
+    seen_nodes: Set[str] = field(default_factory=set)
+    minimum_version_count: int = 0
+    peers_retrieved: List[RespondPeers] = field(default_factory=list)
+    host_to_version: Dict[str, str] = field(default_factory=dict)
+    versions: Dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
+    version_cache: List[Tuple[str, str]] = field(default_factory=list)
+    handshake_time: Dict[str, uint64] = field(default_factory=dict)
+    best_timestamp_per_peer: Dict[str, uint64] = field(default_factory=dict)

     @property
     def server(self) -> ChiaServer:
@@ -53,38 +62,16 @@ class Crawler:
         return self._server

-    def __init__(
-        self,
-        config: Dict,
-        root_path: Path,
-        consensus_constants: ConsensusConstants,
-        name: str = None,
-    ):
-        self.initialized = False
-        self.root_path = root_path
-        self.config = config
-        self.connection = None
-        self._server = None
-        self._shut_down = False  # Set to true to close all infinite loops
-        self.constants = consensus_constants
-        self.state_changed_callback: Optional[StateChangedProtocol] = None
-        self.crawl_store = None
-        self.log = log
-        self.peer_count = 0
-        self.with_peak = set()
-        self.peers_retrieved: List[Any] = []
-        self.host_to_version: Dict[str, str] = {}
-        self.version_cache: List[Tuple[str, str]] = []
-        self.handshake_time: Dict[str, int] = {}
-        self.best_timestamp_per_peer: Dict[str, int] = {}
-        crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
-        self.db_path = path_from_root(root_path, crawler_db_path)
+    def __post_init__(self) -> None:
+        # get db path
+        crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
+        self.db_path = path_from_root(self.root_path, crawler_db_path)
         self.db_path.parent.mkdir(parents=True, exist_ok=True)
-        self.bootstrap_peers = config["bootstrap_peers"]
-        self.minimum_height = config["minimum_height"]
-        self.other_peers_port = config["other_peers_port"]
-        self.versions: Dict[str, int] = defaultdict(lambda: 0)
-        self.minimum_version_count = self.config.get("minimum_version_count", 100)
+        # load data from config
+        self.bootstrap_peers = self.config["bootstrap_peers"]
+        self.minimum_height = self.config["minimum_height"]
+        self.other_peers_port = self.config["other_peers_port"]
+        self.minimum_version_count: int = self.config.get("minimum_version_count", 100)
         if self.minimum_version_count < 1:
             self.log.warning(
                 f"Crawler configuration minimum_version_count expected to be greater than zero: "
@@ -97,11 +84,16 @@ class Crawler:
     def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
         return default_get_connections(server=self.server, request_node_type=request_node_type)

-    async def create_client(self, peer_info, on_connect):
+    async def create_client(
+        self, peer_info: PeerInfo, on_connect: Callable[[WSChiaConnection], Awaitable[None]]
+    ) -> bool:
         return await self.server.start_client(peer_info, on_connect)

-    async def connect_task(self, peer):
-        async def peer_action(peer: WSChiaConnection):
+    async def connect_task(self, peer: PeerRecord) -> None:
+        if self.crawl_store is None:
+            raise ValueError("Not Connected to DB")
+
+        async def peer_action(peer: WSChiaConnection) -> None:
             peer_info = peer.get_peer_info()
             version = peer.get_version()
             if peer_info is not None and version is not None:
@@ -134,41 +126,34 @@ class Crawler:
             if not connected:
                 await self.crawl_store.peer_failed_to_connect(peer)
         except Exception as e:
-            self.log.info(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
+            self.log.warning(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
             await self.crawl_store.peer_failed_to_connect(peer)

-    async def _start(self):
+    async def _start(self) -> None:
         # We override the default peer_connect_timeout when running from the crawler
         crawler_peer_timeout = self.config.get("peer_connect_timeout", 2)
         self.server.config["peer_connect_timeout"] = crawler_peer_timeout
-        self.task = asyncio.create_task(self.crawl())
+        # Connect to the DB
+        self.crawl_store: CrawlStore = await CrawlStore.create(await aiosqlite.connect(self.db_path))
+        # Bootstrap the initial peers
+        await self.load_bootstrap_peers()
+        self.crawl_task = asyncio.create_task(self.crawl())

-    async def crawl(self):
-        # Ensure the state_changed callback is set up before moving on
-        # Sometimes, the daemon connection + state changed callback isn't up and ready
-        # by the time we get to the first _state_changed call, so this just ensures it's there before moving on
-        while self.state_changed_callback is None:
-            self.log.info("Waiting for state changed callback...")
-            await asyncio.sleep(0.1)
-
+    async def load_bootstrap_peers(self) -> None:
+        assert self.crawl_store is not None
         try:
-            self.connection = await aiosqlite.connect(self.db_path)
-            self.crawl_store = await CrawlStore.create(self.connection)
-            self.log.info("Started")
+            self.log.warning("Bootstrapping initial peers...")
             t_start = time.time()
-            total_nodes = 0
-            self.seen_nodes = set()
-            tried_nodes = set()
             for peer in self.bootstrap_peers:
                 new_peer = PeerRecord(
                     peer,
                     peer,
                     self.other_peers_port,
                     False,
-                    0,
-                    0,
-                    0,
+                    uint64(0),
+                    uint32(0),
+                    uint64(0),
                     uint64(int(time.time())),
                     uint64(0),
                     "undefined",
@@ -180,14 +165,26 @@ class Crawler:
             self.host_to_version, self.handshake_time = self.crawl_store.load_host_to_version()
             self.best_timestamp_per_peer = self.crawl_store.load_best_peer_reliability()
-            self.versions = defaultdict(lambda: 0)
             for host, version in self.host_to_version.items():
                 self.versions[version] += 1
-            self._state_changed("loaded_initial_peers")
+            self.log.warning(f"Bootstrapped initial peers in {time.time() - t_start} seconds")
+        except Exception as e:
+            self.log.error(f"Error bootstrapping initial peers: {e}")

-            while True:
-                self.with_peak = set()
+    async def crawl(self) -> None:
+        # Ensure the state_changed callback is set up before moving on
+        # Sometimes, the daemon connection + state changed callback isn't up and ready
+        # by the time we get to the first _state_changed call, so this just ensures it's there before moving on
+        while self.state_changed_callback is None:
+            self.log.info("Waiting for state changed callback...")
+            await asyncio.sleep(0.1)
+        assert self.crawl_store is not None
+        t_start = time.time()
+        total_nodes = 0
+        tried_nodes = set()
+        try:
+            while not self._shut_down:
                 peers_to_crawl = await self.crawl_store.get_peers_to_crawl(25000, 250000)
                 tasks = set()
                 for peer in peers_to_crawl:
@@ -200,7 +197,6 @@ class Crawler:
                     if len(tasks) >= 250:
                         await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                         tasks = set(filter(lambda t: not t.done(), tasks))
-
                 if len(tasks) > 0:
                     await asyncio.wait(tasks, timeout=30)
@@ -231,16 +227,15 @@ class Crawler:
                             tls_version="unknown",
                         )
                         new_peer_reliability = PeerReliability(response_peer.host)
-                        if self.crawl_store is not None:
-                            self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
+                        self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability)
                         await self.crawl_store.update_best_timestamp(
                             response_peer.host,
                             self.best_timestamp_per_peer[response_peer.host],
                         )
                 for host, version in self.version_cache:
-                    self.handshake_time[host] = int(time.time())
+                    self.handshake_time[host] = uint64(time.time())
                     self.host_to_version[host] = version
-                    await self.crawl_store.update_version(host, version, int(time.time()))
+                    await self.crawl_store.update_version(host, version, uint64(time.time()))

                 to_remove = set()
                 now = int(time.time())
@@ -264,94 +259,46 @@ class Crawler:
                 self.versions = defaultdict(lambda: 0)
                 for host, version in self.host_to_version.items():
                     self.versions[version] += 1
-                self.version_cache = []
-                self.peers_retrieved = []
+
+                # clear caches
+                self.version_cache: List[Tuple[str, str]] = []
+                self.peers_retrieved = []
                 self.server.banned_peers = {}
+                self.with_peak = set()
                 if len(peers_to_crawl) == 0:
                     continue
-                # Try up to 5 times to write to the DB in case there is a lock that causes a timeout
-                for i in range(1, 5):
-                    try:
-                        await self.crawl_store.load_to_db()
-                        await self.crawl_store.load_reliable_peers_to_db()
-                    except Exception as e:
-                        self.log.error(f"Exception while saving to DB: {e}.")
-                        self.log.error("Waiting 5 seconds before retry...")
-                        await asyncio.sleep(5)
-                        continue
-                    break
-                total_records = self.crawl_store.get_total_records()
-                ipv6_count = self.crawl_store.get_ipv6_peers()
-                self.log.error("***")
-                self.log.error("Finished batch:")
-                self.log.error(f"Total IPs stored in DB: {total_records}.")
-                self.log.error(f"Total IPV6 addresses stored in DB: {ipv6_count}")
-                self.log.error(f"Total connections attempted since crawler started: {total_nodes}.")
-                self.log.error(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
-                t_now = time.time()
-                t_delta = int(t_now - t_start)
-                if t_delta > 0:
-                    self.log.error(f"Avg connections per second: {total_nodes // t_delta}.")
-                # Periodically print detailed stats.
-                reliable_peers = self.crawl_store.get_reliable_peers()
-                self.log.error(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
-                banned_peers = self.crawl_store.get_banned_peers()
-                ignored_peers = self.crawl_store.get_ignored_peers()
-                available_peers = len(self.host_to_version)
-                addresses_count = len(self.best_timestamp_per_peer)
-                total_records = self.crawl_store.get_total_records()
-                ipv6_addresses_count = 0
-                for host in self.best_timestamp_per_peer.keys():
-                    try:
-                        ipaddress.IPv6Address(host)
-                        ipv6_addresses_count += 1
-                    except ipaddress.AddressValueError:
-                        continue
-                self.log.error(
-                    "IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
-                    f"{addresses_count - ipv6_addresses_count}."
-                )
-                self.log.error(
-                    "IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
-                    f"{ipv6_addresses_count}."
-                )
-                ipv6_available_peers = 0
-                for host in self.host_to_version.keys():
-                    try:
-                        ipaddress.IPv6Address(host)
-                        ipv6_available_peers += 1
-                    except ipaddress.AddressValueError:
-                        continue
-                self.log.error(
-                    f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}."
-                )
-                self.log.error(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
-                self.log.error("Version distribution among reachable in the last 5 days (at least 100 nodes):")
-                for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
-                    if count >= self.minimum_version_count:
-                        self.log.error(f"Version: {version} - Count: {count}")
-                self.log.error(f"Banned addresses in the DB: {banned_peers}")
-                self.log.error(f"Temporary ignored addresses in the DB: {ignored_peers}")
-                self.log.error(
-                    "Peers to crawl from in the next batch (total IPs - ignored - banned): "
-                    f"{total_records - banned_peers - ignored_peers}"
-                )
-                self.log.error("***")
+                await self.save_to_db()
+                await self.print_summary(t_start, total_nodes, tried_nodes)
+                await asyncio.sleep(15)  # 15 seconds between db updates
                 self._state_changed("crawl_batch_completed")
         except Exception as e:
             self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")

-    def set_server(self, server: ChiaServer):
+    async def save_to_db(self) -> None:
+        # Try up to 5 times to write to the DB in case there is a lock that causes a timeout
+        if self.crawl_store is None:
+            raise ValueError("Not Connected to DB")
+        for i in range(1, 5):
+            try:
+                await self.crawl_store.load_to_db()
+                await self.crawl_store.load_reliable_peers_to_db()
+                return
+            except Exception as e:
+                self.log.error(f"Exception while saving to DB: {e}.")
+                self.log.error("Waiting 5 seconds before retry...")
+                await asyncio.sleep(5)
+                continue
+
+    def set_server(self, server: ChiaServer) -> None:
         self._server = server

-    def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None):
+    def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
         if self.state_changed_callback is not None:
             self.state_changed_callback(change, change_data)

-    async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection):
+    async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection) -> None:
         try:
             peer_info = peer.get_peer_info()
             tls_version = peer.get_tls_version()
@@ -359,6 +306,11 @@ class Crawler:
                 tls_version = "unknown"
             if peer_info is None:
                 return
+            # validate peer ip address:
+            try:
+                ipaddress.ip_address(peer_info.host)
+            except ValueError:
+                raise ValueError(f"Invalid peer ip address: {peer_info.host}")
             if request.height >= self.minimum_height:
                 if self.crawl_store is not None:
                     await self.crawl_store.peer_connected_hostname(peer_info.host, True, tls_version)
@@ -366,12 +318,79 @@ class Crawler:
         except Exception as e:
             self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")

-    async def on_connect(self, connection: WSChiaConnection):
+    async def on_connect(self, connection: WSChiaConnection) -> None:
         pass

-    def _close(self):
+    def _close(self) -> None:
         self._shut_down = True

-    async def _await_closed(self):
-        if self.connection is not None:
-            await self.connection.close()
+    async def _await_closed(self) -> None:
+        if self.crawl_task is not None:
+            try:
+                await asyncio.wait_for(self.crawl_task, timeout=10)  # wait 10 seconds before giving up
+            except asyncio.TimeoutError:
+                self.log.error("Crawl task did not exit in time, killing task.")
+                self.crawl_task.cancel()
+        if self.crawl_store is not None:
+            self.log.info("Closing connection to DB.")
+            await self.crawl_store.crawl_db.close()
+
+    async def print_summary(self, t_start: float, total_nodes: int, tried_nodes: Set[str]) -> None:
+        assert self.crawl_store is not None  # this is only ever called from the crawl task
+        if not self.print_status:
+            return
+        total_records = self.crawl_store.get_total_records()
+        ipv6_count = self.crawl_store.get_ipv6_peers()
+        self.log.warning("***")
+        self.log.warning("Finished batch:")
+        self.log.warning(f"Total IPs stored in DB: {total_records}.")
+        self.log.warning(f"Total IPV6 addresses stored in DB: {ipv6_count}")
+        self.log.warning(f"Total connections attempted since crawler started: {total_nodes}.")
+        self.log.warning(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.")
+        t_now = time.time()
+        t_delta = int(t_now - t_start)
+        if t_delta > 0:
+            self.log.warning(f"Avg connections per second: {total_nodes // t_delta}.")
+        # Periodically print detailed stats.
+        reliable_peers = self.crawl_store.get_reliable_peers()
+        self.log.warning(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}")
+        banned_peers = self.crawl_store.get_banned_peers()
+        ignored_peers = self.crawl_store.get_ignored_peers()
+        available_peers = len(self.host_to_version)
+        addresses_count = len(self.best_timestamp_per_peer)
+        total_records = self.crawl_store.get_total_records()
+        ipv6_addresses_count = 0
+        for host in self.best_timestamp_per_peer.keys():
+            try:
+                ipaddress.IPv6Address(host)
+                ipv6_addresses_count += 1
+            except ipaddress.AddressValueError:
+                continue
+        self.log.warning(
+            "IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
+            f"{addresses_count - ipv6_addresses_count}."
+        )
+        self.log.warning(
+            "IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: "
+            f"{ipv6_addresses_count}."
+        )
+        ipv6_available_peers = 0
+        for host in self.host_to_version.keys():
+            try:
+                ipaddress.IPv6Address(host)
+                ipv6_available_peers += 1
+            except ipaddress.AddressValueError:
+                continue
+        self.log.warning(f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}.")
+        self.log.warning(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.")
+        self.log.warning("Version distribution among reachable in the last 5 days (at least 100 nodes):")
+        for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True):
+            if count >= self.minimum_version_count:
+                self.log.warning(f"Version: {version} - Count: {count}")
+        self.log.warning(f"Banned addresses in the DB: {banned_peers}")
+        self.log.warning(f"Temporary ignored addresses in the DB: {ignored_peers}")
+        self.log.warning(
+            "Peers to crawl from in the next batch (total IPs - ignored - banned): "
+            f"{total_records - banned_peers - ignored_peers}"
+        )
+        self.log.warning("***")

View File: chia/seeder/dns_server.py

@@ -1,294 +1,534 @@
 from __future__ import annotations

 import asyncio
-import ipaddress
 import logging
-import random
 import signal
+import sys
 import traceback
+from asyncio import AbstractEventLoop
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from ipaddress import IPv4Address, IPv6Address, ip_address
+from multiprocessing import freeze_support
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional

 import aiosqlite
-from dnslib import AAAA, CNAME, MX, NS, QTYPE, RR, SOA, A, DNSHeader, DNSRecord
+from dnslib import AAAA, EDNS0, NS, QTYPE, RCODE, RD, RR, SOA, A, DNSError, DNSHeader, DNSQuestion, DNSRecord

-from chia.util.chia_logging import initialize_logging
-from chia.util.config import load_config
+from chia.seeder.crawl_store import CrawlStore
+from chia.util.chia_logging import initialize_service_logging
+from chia.util.config import load_config, load_config_cli
 from chia.util.default_root import DEFAULT_ROOT_PATH
 from chia.util.path import path_from_root

 SERVICE_NAME = "seeder"
 log = logging.getLogger(__name__)
+
+DnsCallback = Callable[[DNSRecord], Awaitable[DNSRecord]]
 # DNS snippet taken from: https://gist.github.com/pklaus/b5a7876d4d2cf7271873

 class DomainName(str):
-    def __getattr__(self, item):
-        return DomainName(item + "." + self)
+    def __getattr__(self, item: str) -> DomainName:
+        return DomainName(item + "." + self)  # DomainName.NS becomes DomainName("NS.DomainName")

-D = None
-ns = None
-IP = "127.0.0.1"
-TTL = None
-soa_record = None
-ns_records: List[Any] = []
+@dataclass(frozen=True)
+class PeerList:
+    ipv4: List[IPv4Address]
+    ipv6: List[IPv6Address]
+
+    @property
+    def no_peers(self) -> bool:
+        return not self.ipv4 and not self.ipv6
-class EchoServerProtocol(asyncio.DatagramProtocol):
-    def __init__(self, callback):
-        self.data_queue = asyncio.Queue()
-        self.callback = callback
-        asyncio.ensure_future(self.respond())
-
-    def connection_made(self, transport):
-        self.transport = transport
-
-    def datagram_received(self, data, addr):
-        asyncio.ensure_future(self.handler(data, addr))
-
-    async def respond(self):
-        while True:
-            try:
-                resp, caller = await self.data_queue.get()
-                self.transport.sendto(resp, caller)
-            except Exception as e:
-                log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
-
-    async def handler(self, data, caller):
-        try:
-            data = await self.callback(data)
-            if data is None:
-                return
-            await self.data_queue.put((data, caller))
-        except Exception as e:
-            log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
+@dataclass
+class UDPDNSServerProtocol(asyncio.DatagramProtocol):
+    """
+    This is a really simple UDP Server, that converts all requests to DNSRecord objects and passes them to the callback.
+    """
+
+    callback: DnsCallback
+    transport: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
+    data_queue: asyncio.Queue[tuple[DNSRecord, tuple[str, int]]] = field(default_factory=asyncio.Queue)
+    queue_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
+
+    def start(self) -> None:
+        self.queue_task = asyncio.create_task(self.respond())  # This starts the dns respond loop.
+
+    async def stop(self) -> None:
+        if self.queue_task is not None:
+            self.queue_task.cancel()
+            try:
+                await self.queue_task
+            except asyncio.CancelledError:  # we dont care
+                pass
+        if self.transport is not None:
+            self.transport.close()
+
+    def connection_made(self, transport: asyncio.BaseTransport) -> None:
+        # we use the #ignore because transport is a subclass of BaseTransport, but we need the real type.
+        self.transport = transport  # type: ignore[assignment]
+
+    def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
+        log.debug(f"Received UDP DNS request from {addr}.")
+        dns_request: Optional[DNSRecord] = parse_dns_request(data)
+        if dns_request is None:  # Invalid Request, we can just drop it and move on.
+            return
+        asyncio.create_task(self.handler(dns_request, addr))
+
+    async def respond(self) -> None:
+        log.info("UDP DNS responder started.")
+        while self.transport is None:  # we wait for the transport to be set.
+            await asyncio.sleep(0.1)
+        while not self.transport.is_closing():
+            try:
+                edns_max_size = 0
+                reply, caller = await self.data_queue.get()
+                if len(reply.ar) > 0 and reply.ar[0].rtype == QTYPE.OPT:
+                    edns_max_size = reply.ar[0].edns_len
+                reply_packed = reply.pack()
+                if len(reply_packed) > max(512, edns_max_size):  # 512 is the default max size for DNS:
+                    log.debug(f"DNS response to {caller} is too large, truncating.")
+                    reply_packed = reply.truncate().pack()
+                self.transport.sendto(reply_packed, caller)
+                log.debug(f"Sent UDP DNS response to {caller}, of size {len(reply_packed)}.")
+            except Exception as e:
+                log.error(f"Exception while responding to UDP DNS request: {e}. Traceback: {traceback.format_exc()}.")
+        log.info("UDP DNS responder stopped.")
+
+    async def handler(self, data: DNSRecord, caller: tuple[str, int]) -> None:
+        r_data = await get_dns_reply(self.callback, data)  # process the request, returning a DNSRecord response.
+        await self.data_queue.put((r_data, caller))
+
+
+@dataclass
+class TCPDNSServerProtocol(asyncio.BufferedProtocol):
+    """
+    This TCP server is a little more complicated, because we need to handle the length field, however it still
+    converts all requests to DNSRecord objects and passes them to the callback, after receiving the full message.
+    """
+
+    callback: DnsCallback
+    transport: Optional[asyncio.Transport] = field(init=False, default=None)
+    peer_info: str = field(init=False, default="")
+    expected_length: int = 0
+    buffer: bytearray = field(init=False, default_factory=lambda: bytearray(2))
+    futures: List[asyncio.Future[None]] = field(init=False, default_factory=list)
+
+    def connection_made(self, transport: asyncio.BaseTransport) -> None:
+        """
+        This is called whenever we get a new connection.
+        """
+        # we use the #ignore because transport is a subclass of BaseTransport, but we need the real type.
+        self.transport = transport  # type: ignore[assignment]
+        peer_info = transport.get_extra_info("peername")
+        self.peer_info = f"{peer_info[0]}:{peer_info[1]}"
+        log.debug(f"TCP connection established with {self.peer_info}.")
+
+    def connection_lost(self, exc: Optional[Exception]) -> None:
+        """
+        This is called whenever a connection is lost, or closed.
+        """
+        if exc is not None:
+            log.debug(f"TCP DNS connection lost with {self.peer_info}. Exception: {exc}.")
+        else:
+            log.debug(f"TCP DNS connection closed with {self.peer_info}.")
+        # reset the state of the protocol.
+        for future in self.futures:
+            future.cancel()
+        self.futures = []
+        self.buffer = bytearray(2)
+        self.expected_length = 0
+
+    def get_buffer(self, sizehint: int) -> memoryview:
+        """
+        This is the first function called after connection_made, it returns a buffer that the tcp server will write to.
+        Once a buffer is written to, buffer_updated is called.
+        """
+        return memoryview(self.buffer)
+
+    def buffer_updated(self, nbytes: int) -> None:
+        """
+        This is called whenever the buffer is written to, and it loops through the buffer, grouping them into messages
+        and then dns records.
+        """
+        while not len(self.buffer) == 0 and self.transport is not None:
+            if not self.expected_length:
+                # Length field received (This is the first part of the message)
+                self.expected_length = int.from_bytes(self.buffer, byteorder="big")
+                self.buffer = self.buffer[2:]  # Remove the length field from the buffer.
+            if len(self.buffer) >= self.expected_length:
+                # This is the rest of the message (after the length field)
+                message = self.buffer[: self.expected_length]
+                self.buffer = self.buffer[self.expected_length :]  # Remove the message from the buffer
+                self.expected_length = 0  # Reset the expected length
+                dns_request: Optional[DNSRecord] = parse_dns_request(message)
+                if dns_request is None:  # Invalid Request, so we disconnect and don't send anything back.
+                    self.transport.close()
+                    return
+                self.futures.append(asyncio.create_task(self.handle_and_respond(dns_request)))
+        self.buffer = bytearray(2 if self.expected_length == 0 else self.expected_length)  # Reset the buffer if empty.
+
+    def eof_received(self) -> Optional[bool]:
+        """
+        This is called when the client closes the connection, False or None means we close the connection.
+        True means we keep the connection open.
+        """
+        if len(self.futures) > 0:  # Successful requests
+            if self.expected_length != 0:  # Incomplete requests
+                log.warning(
+                    f"Received incomplete TCP DNS request of length {self.expected_length} from {self.peer_info}, "
+                    f"closing connection after dns replies are sent."
+                )
+            asyncio.create_task(self.wait_for_futures())
+            return True  # Keep connection open, until the futures are done.
+        log.info(f"Received early EOF from {self.peer_info}, closing connection.")
+        return False
+
+    async def wait_for_futures(self) -> None:
+        """
+        Waits for all the futures to complete, and then closes the connection.
+        """
+        try:
+            await asyncio.wait_for(asyncio.gather(*self.futures), timeout=10)
+        except asyncio.TimeoutError:
+            log.warning(f"Timed out waiting for DNS replies to be sent to {self.peer_info}.")
+        if self.transport is not None:
+            self.transport.close()
+
+    async def handle_and_respond(self, data: DNSRecord) -> None:
+        r_data = await get_dns_reply(self.callback, data)  # process the request, returning a DNSRecord response.
+        try:
+            # If the client closed the connection, we don't want to send anything.
+            if self.transport is not None and not self.transport.is_closing():
+                self.transport.write(dns_response_to_tcp(r_data))  # send data back to the client
+                log.debug(f"Sent DNS response for {data.q.qname}, to {self.peer_info}.")
+        except Exception as e:
+            log.error(f"Exception while responding to TCP DNS request: {e}. Traceback: {traceback.format_exc()}.")
+
+
+def dns_response_to_tcp(data: DNSRecord) -> bytes:
+    """
+    Converts a DNSRecord response to a TCP DNS response, by adding a 2 byte length field to the start.
+    """
+    dns_response = data.pack()
+    dns_response_length = len(dns_response).to_bytes(2, byteorder="big")
+    return bytes(dns_response_length + dns_response)
+def create_dns_reply(dns_request: DNSRecord) -> DNSRecord:
+    """
+    Creates a DNS response with the correct header and section flags set.
+    """
+    # QR means query response, AA means authoritative answer, RA means recursion available
+    return DNSRecord(DNSHeader(id=dns_request.header.id, qr=1, aa=1, ra=0), q=dns_request.q)
+
+
+def parse_dns_request(data: bytes) -> Optional[DNSRecord]:
+    """
+    Parses the DNS request, and returns a DNSRecord object, or None if the request is invalid.
+    """
+    dns_request: Optional[DNSRecord] = None
+    try:
+        dns_request = DNSRecord.parse(data)
+    except DNSError as e:
+        log.warning(f"Received invalid DNS request: {e}. Traceback: {traceback.format_exc()}.")
+    return dns_request
+
+
+async def get_dns_reply(callback: DnsCallback, dns_request: DNSRecord) -> DNSRecord:
+    """
+    This function calls the callback, and returns SERVFAIL if the callback raises an exception.
+    """
+    try:
+        dns_reply = await callback(dns_request)
+    except Exception as e:
+        log.error(f"Exception during DNS record processing: {e}. Traceback: {traceback.format_exc()}.")
+        # we return an empty response with an error code
+        dns_reply = create_dns_reply(dns_request)  # This is an empty response, with only the header set.
+        dns_reply.header.rcode = RCODE.SERVFAIL
+    return dns_reply
+@dataclass
 class DNSServer:
-    reliable_peers_v4: List[str]
-    reliable_peers_v6: List[str]
-    lock: asyncio.Lock
-    pointer: int
-    crawl_db: aiosqlite.Connection
+    config: Dict[str, Any]
+    root_path: Path
+    lock: asyncio.Lock = field(default_factory=asyncio.Lock)
+    shutdown_event: asyncio.Event = field(default_factory=asyncio.Event)
+    crawl_store: Optional[CrawlStore] = field(init=False, default=None)
+    reliable_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
+    shutdown_task: Optional[asyncio.Task[None]] = field(init=False, default=None)
+    udp_transport_ipv4: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
+    udp_protocol_ipv4: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
+    udp_transport_ipv6: Optional[asyncio.DatagramTransport] = field(init=False, default=None)
+    udp_protocol_ipv6: Optional[UDPDNSServerProtocol] = field(init=False, default=None)
+    # TODO: After 3.10 is dropped change to asyncio.Server
+    tcp_server: Optional[asyncio.base_events.Server] = field(init=False, default=None)
+    # these are all set in __post_init__
+    dns_port: int = field(init=False)
+    db_path: Path = field(init=False)
+    domain: DomainName = field(init=False)
+    ns1: DomainName = field(init=False)
+    ns_records: List[RR] = field(init=False)
+    ttl: int = field(init=False)
+    soa_record: RR = field(init=False)
+    reliable_peers_v4: List[IPv4Address] = field(default_factory=list)
+    reliable_peers_v6: List[IPv6Address] = field(default_factory=list)
+    pointer_v4: int = 0
+    pointer_v6: int = 0

-    def __init__(self, config: Dict, root_path: Path):
-        self.reliable_peers_v4 = []
-        self.reliable_peers_v6 = []
-        self.lock = asyncio.Lock()
-        self.pointer_v4 = 0
-        self.pointer_v6 = 0
-
-        crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
-        self.db_path = path_from_root(root_path, crawler_db_path)
+    def __post_init__(self) -> None:
+        """
+        We initialize all the variables set to field(init=False) here.
+        """
+        # From Config
+        self.dns_port: int = self.config.get("dns_port", 53)
+        # DB Path
+        crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db")
+        self.db_path: Path = path_from_root(self.root_path, crawler_db_path)
         self.db_path.parent.mkdir(parents=True, exist_ok=True)
+        # DNS info
+        self.domain: DomainName = DomainName(self.config["domain_name"])
+        if not self.domain.endswith("."):
+            self.domain = DomainName(self.domain + ".")  # Make sure the domain ends with a period, as per RFC 1035.
+        self.ns1: DomainName = DomainName(self.config["nameserver"])
+        self.ns_records: List[NS] = [NS(self.ns1)]
+        self.ttl: int = self.config["ttl"]
+        self.soa_record: SOA = SOA(
+            mname=self.ns1,  # primary name server
+            rname=self.config["soa"]["rname"],  # email of the domain administrator
+            times=(
+                self.config["soa"]["serial_number"],
+                self.config["soa"]["refresh"],
+                self.config["soa"]["retry"],
+                self.config["soa"]["expire"],
+                self.config["soa"]["minimum"],
+            ),
+        )

-    async def start(self):
-        # self.crawl_db = await aiosqlite.connect(self.db_path)
-        # Get a reference to the event loop as we plan to use
-        # low-level APIs.
-        loop = asyncio.get_running_loop()
-
-        # One protocol instance will be created to serve all
-        # client requests.
-        self.transport, self.protocol = await loop.create_datagram_endpoint(
-            lambda: EchoServerProtocol(self.dns_response), local_addr=("::0", 53)
-        )
-        self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers())
+    @asynccontextmanager
+    async def run(self) -> AsyncIterator[None]:
+        log.info("Starting DNS server.")
+        # Get a reference to the event loop as we plan to use low-level APIs.
+        loop = asyncio.get_running_loop()
+        await self.setup_signal_handlers(loop)
+
+        # Set up the crawl store and the peer update task.
+        self.crawl_store = await CrawlStore.create(await aiosqlite.connect(self.db_path, timeout=120))
+        self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers())
+
+        # One protocol instance will be created for each udp transport, so that we can accept ipv4 and ipv6
+        self.udp_transport_ipv6, self.udp_protocol_ipv6 = await loop.create_datagram_endpoint(
+            lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("::0", self.dns_port)
+        )
+        self.udp_protocol_ipv6.start()  # start ipv6 udp transmit task
+
+        # in case the port is 0 we need all protocols on the same port.
+        self.dns_port = self.udp_transport_ipv6.get_extra_info("sockname")[1]  # get the actual port we are listening to
+        if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"):
+            # Windows does not support dual stack sockets, so we need to create a new socket for ipv4.
+            self.udp_transport_ipv4, self.udp_protocol_ipv4 = await loop.create_datagram_endpoint(
+                lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("0.0.0.0", self.dns_port)
+            )
+            self.udp_protocol_ipv4.start()  # start ipv4 udp transmit task
+
+        # One tcp server will handle both ipv4 and ipv6 on both linux and windows.
+        self.tcp_server = await loop.create_server(
+            lambda: TCPDNSServerProtocol(self.dns_response), ["::0", "0.0.0.0"], self.dns_port
+        )
+        log.info("DNS server started.")
+        try:
+            yield
+        finally:  # catches any errors and properly shuts down the server
+            await self.stop()
+            log.info("DNS server stopped.")
+
+    async def setup_signal_handlers(self, loop: AbstractEventLoop) -> None:
+        try:
+            loop.add_signal_handler(signal.SIGINT, self._accept_signal)
+            loop.add_signal_handler(signal.SIGTERM, self._accept_signal)
+        except NotImplementedError:
+            log.warning("signal handlers unsupported on this platform")
+
+    def _accept_signal(self) -> None:  # pragma: no cover
+        if self.shutdown_task is None:  # otherwise we are already shutting down, so we ignore the signal
+            self.shutdown_task = asyncio.create_task(self.stop())
+
+    async def stop(self) -> None:
+        log.info("Stopping DNS server...")
+        if self.reliable_task is not None:
+            self.reliable_task.cancel()  # cancel the peer update task
+        if self.crawl_store is not None:
+            await self.crawl_store.crawl_db.close()
+        if self.udp_protocol_ipv6 is not None:
+            await self.udp_protocol_ipv6.stop()  # stop responding to and accepting udp requests (ipv6) & ipv4 if linux.
+        if self.udp_protocol_ipv4 is not None:
+            await self.udp_protocol_ipv4.stop()  # stop responding to and accepting udp requests (ipv4) if windows.
+        if self.tcp_server is not None:
+            self.tcp_server.close()  # stop accepting new tcp requests (ipv4 and ipv6)
+            await self.tcp_server.wait_closed()  # wait for existing TCP requests to finish (ipv4 and ipv6)
+        self.shutdown_event.set()

-    async def periodically_get_reliable_peers(self):
+    async def periodically_get_reliable_peers(self) -> None:
         sleep_interval = 0
-        while True:
+        while not self.shutdown_event.is_set() and self.crawl_store is not None:
             try:
-                # TODO: double check this. It shouldn't take this long to connect.
-                crawl_db = await aiosqlite.connect(self.db_path, timeout=600)
-                cursor = await crawl_db.execute(
-                    "SELECT * from good_peers",
-                )
-                new_reliable_peers = []
-                rows = await cursor.fetchall()
-                await cursor.close()
-                await crawl_db.close()
-                for row in rows:
-                    new_reliable_peers.append(row[0])
-                if len(new_reliable_peers) > 0:
-                    random.shuffle(new_reliable_peers)
-                async with self.lock:
-                    self.reliable_peers_v4 = []
-                    self.reliable_peers_v6 = []
-                    for peer in new_reliable_peers:
-                        ipv4 = True
-                        try:
-                            _ = ipaddress.IPv4Address(peer)
-                        except ValueError:
-                            ipv4 = False
-                        if ipv4:
-                            self.reliable_peers_v4.append(peer)
-                        else:
-                            try:
-                                _ = ipaddress.IPv6Address(peer)
-                            except ValueError:
-                                continue
-                            self.reliable_peers_v6.append(peer)
-                    self.pointer_v4 = 0
-                    self.pointer_v6 = 0
-                log.error(
+                new_reliable_peers = await self.crawl_store.get_good_peers()
+            except Exception as e:
+                log.error(f"Error loading reliable peers from database: {e}. Traceback: {traceback.format_exc()}.")
+                continue
+            if len(new_reliable_peers) == 0:
+                log.info("No reliable peers found in database, waiting for db to be populated.")
+                await asyncio.sleep(2)  # sleep for 2 seconds, because the db has not been populated yet.
+                continue
+            async with self.lock:
+                self.reliable_peers_v4 = []
+                self.reliable_peers_v6 = []
+                self.pointer_v4 = 0
+                self.pointer_v6 = 0
+                for peer in new_reliable_peers:
+                    try:
+                        validated_peer = ip_address(peer)
+                        if validated_peer.version == 4:
+                            self.reliable_peers_v4.append(validated_peer)
+                        elif validated_peer.version == 6:
+                            self.reliable_peers_v6.append(validated_peer)
+                    except ValueError:
+                        log.error(f"Invalid peer: {peer}")
+                        continue
+                log.info(
                     f"Number of reliable peers discovered in dns server:"
                     f" IPv4 count - {len(self.reliable_peers_v4)}"
                     f" IPv6 count - {len(self.reliable_peers_v6)}"
                 )
-            except Exception as e:
-                log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
             sleep_interval = min(15, sleep_interval + 1)
             await asyncio.sleep(sleep_interval * 60)
-    async def get_peers_to_respond(self, ipv4_count, ipv6_count):
-        peers = []
+    async def get_peers_to_respond(self, ipv4_count: int, ipv6_count: int) -> PeerList:
         async with self.lock:
             # Append IPv4.
+            ipv4_peers: List[IPv4Address] = []
             size = len(self.reliable_peers_v4)
             if ipv4_count > 0 and size <= ipv4_count:
-                peers = self.reliable_peers_v4
+                ipv4_peers = self.reliable_peers_v4
             elif ipv4_count > 0:
-                peers = [self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)]
-                self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size
+                ipv4_peers = [
+                    self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)
+                ]
+                self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size  # mark where we left off
             # Append IPv6.
+            ipv6_peers: List[IPv6Address] = []
             size = len(self.reliable_peers_v6)
             if ipv6_count > 0 and size <= ipv6_count:
-                peers = peers + self.reliable_peers_v6
+                ipv6_peers = self.reliable_peers_v6
             elif ipv6_count > 0:
-                peers = peers + [
+                ipv6_peers = [
                     self.reliable_peers_v6[i % size] for i in range(self.pointer_v6, self.pointer_v6 + ipv6_count)
                 ]
-                self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size
-        return peers
+                self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size  # mark where we left off
+        return PeerList(ipv4_peers, ipv6_peers)
-    async def dns_response(self, data):
-        try:
-            request = DNSRecord.parse(data)
-            IPs = [MX(D.mail), soa_record] + ns_records
-            ipv4_count = 0
-            ipv6_count = 0
-            if request.q.qtype == 1:
-                ipv4_count = 32
-            elif request.q.qtype == 28:
-                ipv6_count = 32
-            elif request.q.qtype == 255:
-                ipv4_count = 16
-                ipv6_count = 16
-            else:
-                ipv4_count = 32
-            peers = await self.get_peers_to_respond(ipv4_count, ipv6_count)
-            if len(peers) == 0:
-                return None
-            for peer in peers:
-                ipv4 = True
-                try:
-                    _ = ipaddress.IPv4Address(peer)
-                except ValueError:
-                    ipv4 = False
-                if ipv4:
-                    IPs.append(A(peer))
-                else:
-                    try:
-                        _ = ipaddress.IPv6Address(peer)
-                    except ValueError:
-                        continue
-                    IPs.append(AAAA(peer))
-            reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=len(IPs), ra=1), q=request.q)
-            records = {
-                D: IPs,
-                D.ns1: [A(IP)],  # MX and NS records must never point to a CNAME alias (RFC 2181 section 10.3)
-                D.ns2: [A(IP)],
-                D.mail: [A(IP)],
-                D.andrei: [CNAME(D)],
-            }
-            qname = request.q.qname
-            # DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve
-            # transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
-            qn = str(qname).lower()
-            qtype = request.q.qtype
-            qt = QTYPE[qtype]
-            if qn == D or qn.endswith("." + D):
-                for name, rrs in records.items():
-                    if name == qn:
-                        for rdata in rrs:
-                            rqt = rdata.__class__.__name__
-                            if qt in ["*", rqt] or (qt == "ANY" and (rqt == "A" or rqt == "AAAA")):
-                                reply.add_answer(
-                                    RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata)
-                                )
-                for rdata in ns_records:
-                    reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
-                reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
-            return reply.pack()
-        except Exception as e:
-            log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
+    async def dns_response(self, request: DNSRecord) -> DNSRecord:
+        """
+        This function is called when a DNS request is received, and it returns a DNS response.
+        It does not catch any errors as it is called from within a try-except block.
+        """
+        reply = create_dns_reply(request)
+        dns_question: DNSQuestion = request.q  # this is the question / request
+        question_type: int = dns_question.qtype  # the type of the record being requested
+        qname = dns_question.qname  # the name being queried / requested
+        # ADD EDNS0 to response if supported
+        if len(request.ar) > 0 and request.ar[0].rtype == QTYPE.OPT:  # OPT Means EDNS
+            udp_len = min(4096, request.ar[0].edns_len)
+            edns_reply = EDNS0(udp_len=udp_len)
+            reply.add_ar(edns_reply)
+        # DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve
+        # transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00
+        qname_str = str(qname).lower()
+        if qname_str != self.domain and not qname_str.endswith("." + self.domain):
+            # we don't answer for other domains (we have the not recursive bit set)
+            log.warning(f"Invalid request for {qname_str}, returning REFUSED.")
+            reply.header.rcode = RCODE.REFUSED
+            return reply
+        ttl: int = self.ttl
+        # we add these to the list as it will allow us to respond to ns and soa requests
+        ips: List[RD] = [self.soa_record] + self.ns_records
+        ipv4_count = 0
+        ipv6_count = 0
+        if question_type is QTYPE.A:
+            ipv4_count = 32
+        elif question_type is QTYPE.AAAA:
+            ipv6_count = 32
+        elif question_type is QTYPE.ANY:
+            ipv4_count = 16
+            ipv6_count = 16
+        else:
+            ipv4_count = 32
+        peers: PeerList = await self.get_peers_to_respond(ipv4_count, ipv6_count)
+        if peers.no_peers:
+            log.error("No peers found, returning SOA and NS records only.")
+            ttl = 60  # 1 minute as we should have some peers very soon
+        # we always return the SOA and NS records, so we continue even if there are no peers
+        ips.extend([A(str(peer)) for peer in peers.ipv4])
+        ips.extend([AAAA(str(peer)) for peer in peers.ipv6])
+        records: Dict[DomainName, List[RD]] = {  # this is where we can add other records we want to serve
+            self.domain: ips,
+        }
+        valid_domain = False
+        for domain_name, domain_responses in records.items():
+            if domain_name == qname_str:  # if the dns name is the same as the requested name
+                valid_domain = True
+                for response in domain_responses:
+                    rqt: int = getattr(QTYPE, response.__class__.__name__)
+                    if question_type == rqt or question_type == QTYPE.ANY:
+                        reply.add_answer(RR(rname=qname, rtype=rqt, rclass=1, ttl=ttl, rdata=response))
+        if not valid_domain and len(reply.rr) == 0:  # if we didn't find any records to return
+            reply.header.rcode = RCODE.NXDOMAIN
+        # always put nameservers and the SOA records
+        for nameserver in self.ns_records:
+            reply.add_auth(RR(rname=self.domain, rtype=QTYPE.NS, rclass=1, ttl=ttl, rdata=nameserver))
+        reply.add_auth(RR(rname=self.domain, rtype=QTYPE.SOA, rclass=1, ttl=ttl, rdata=self.soa_record))
+        return reply
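Aside: the handler above builds its reply with dnslib. A minimal standalone sketch of the same shape, using dnslib's stock DNSRecord.reply() where the commit uses its own create_dns_reply helper (zone name and addresses are made up):

from dnslib import A, DNSRecord, QTYPE, RCODE, RR

request = DNSRecord.question("seeder.example.com", "A")
reply = request.reply()  # copies the request id and question into the response

qname_str = str(request.q.qname).lower()  # normalize 0x20 mixed-case labels
if qname_str.rstrip(".") != "seeder.example.com":
    reply.header.rcode = RCODE.REFUSED  # only answer for our own zone
else:
    for ip in ["192.168.0.1", "192.168.0.2"]:
        reply.add_answer(RR(rname=request.q.qname, rtype=QTYPE.A, rclass=1, ttl=300, rdata=A(ip)))

print(reply)         # zone-style dump for debugging
wire = reply.pack()  # wire bytes, as sent over UDP/TCP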
-async def serve_dns(config: Dict, root_path: Path):
-    dns_server = DNSServer(config, root_path)
-    await dns_server.start()
-
-    # TODO: Make this cleaner?
-    while True:
-        await asyncio.sleep(3600)
-
-
-async def kill_processes():
-    # TODO: implement.
-    pass
-
-
-def signal_received():
-    asyncio.create_task(kill_processes())
-
-
-async def async_main(config, root_path):
-    loop = asyncio.get_running_loop()
-    try:
-        loop.add_signal_handler(signal.SIGINT, signal_received)
-        loop.add_signal_handler(signal.SIGTERM, signal_received)
-    except NotImplementedError:
-        log.info("signal handlers unsupported")
-
-    await serve_dns(config, root_path)
+async def run_dns_server(dns_server: DNSServer) -> None:  # pragma: no cover
+    async with dns_server.run():
+        await dns_server.shutdown_event.wait()  # this is released on SIGINT or SIGTERM or any unhandled exception
+
+
+def create_dns_server_service(config: Dict[str, Any], root_path: Path) -> DNSServer:
+    service_config = config[SERVICE_NAME]
+    return DNSServer(service_config, root_path)


-def main():
+def main() -> None:  # pragma: no cover
+    freeze_support()
     root_path = DEFAULT_ROOT_PATH
-    config = load_config(root_path, "config.yaml", SERVICE_NAME)
-    initialize_logging(SERVICE_NAME, config["logging"], root_path)
-
-    global D
-    global ns
-    global TTL
-    global soa_record
-    global ns_records
-
-    D = DomainName(config["domain_name"])
-    ns = DomainName(config["nameserver"])
-    TTL = config["ttl"]
-    soa_record = SOA(
-        mname=ns,  # primary name server
-        rname=config["soa"]["rname"],  # email of the domain administrator
-        times=(
-            config["soa"]["serial_number"],
-            config["soa"]["refresh"],
-            config["soa"]["retry"],
-            config["soa"]["expire"],
-            config["soa"]["minimum"],
-        ),
-    )
-    ns_records = [NS(ns)]
-    asyncio.run(async_main(config=config, root_path=root_path))
+    # TODO: refactor to avoid the double load
+    config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
+    service_config = load_config_cli(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
+    config[SERVICE_NAME] = service_config
+    initialize_service_logging(service_name=SERVICE_NAME, config=config)
+
+    dns_server = create_dns_server_service(config, root_path)
+    asyncio.run(run_dns_server(dns_server))


 if __name__ == "__main__":

View File

@@ -24,23 +24,19 @@ class PeerRecord(Streamable):
     handshake_time: uint64
     tls_version: str

-    def update_version(self, version, now):
+    def update_version(self, version: str, now: uint64) -> None:
         if version != "undefined":
             object.__setattr__(self, "version", version)
-        object.__setattr__(self, "handshake_time", uint64(now))
+        object.__setattr__(self, "handshake_time", now)


+@dataclass
 class PeerStat:
     weight: float
     count: float
     reliability: float

-    def __init__(self, weight, count, reliability):
-        self.weight = weight
-        self.count = count
-        self.reliability = reliability
-
-    def update(self, is_reachable: bool, age: int, tau: int):
+    def update(self, is_reachable: bool, age: int, tau: int) -> None:
         f = math.exp(-age / tau)
         self.reliability = self.reliability * f + (1.0 - f if is_reachable else 0.0)
         self.count = self.count * f + 1.0
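Aside: PeerStat.update is an exponentially decayed moving average with time constant tau: the old state is scaled by f = exp(-age / tau) and a reachable observation fills in the remaining 1 - f. A small worked example with illustrative values:

import math

reliability, count = 1.0, 10.0  # a peer that has been perfectly reachable so far
tau = 2 * 3600                  # 2-hour time constant, as in stat_2h
age = 3600                      # one hour since the last update

f = math.exp(-age / tau)                   # ~0.6065: how much of the old state survives
reliability = reliability * f + (1.0 - f)  # reachable -> pulled back toward 1.0
count = count * f + 1.0                    # ~7.07: the sample count decays the same way
# an unreachable probe would add 0.0 instead, dropping reliability to ~0.6065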
@@ -81,7 +77,7 @@ class PeerReliability:
         stat_1m_reliability: float = 0.0,
         tries: int = 0,
         successes: int = 0,
-    ):
+    ) -> None:
         self.peer_id = peer_id
         self.ignore_till = ignore_till
         self.ban_till = ban_till
@@ -132,7 +128,7 @@ class PeerReliability:
             return 2 * 3600
         return 0

-    def update(self, is_reachable: bool, age: int):
+    def update(self, is_reachable: bool, age: int) -> None:
         self.stat_2h.update(is_reachable, age, 2 * 3600)
         self.stat_8h.update(is_reachable, age, 8 * 3600)
         self.stat_1d.update(is_reachable, age, 24 * 3600)

View File

@@ -4,7 +4,7 @@ import logging
 import pathlib
 import sys
 from multiprocessing import freeze_support
-from typing import Dict, Optional
+from typing import Any, Dict, Optional

 from chia.consensus.constants import ConsensusConstants
 from chia.consensus.default_constants import DEFAULT_CONSTANTS
@@ -26,24 +26,25 @@ log = logging.getLogger(__name__)
 def create_full_node_crawler_service(
     root_path: pathlib.Path,
-    config: Dict,
+    config: Dict[str, Any],
     consensus_constants: ConsensusConstants,
     connect_to_daemon: bool = True,
 ) -> Service[Crawler, CrawlerAPI]:
     service_config = config[SERVICE_NAME]
+    crawler_config = service_config["crawler"]

     crawler = Crawler(
         service_config,
         root_path=root_path,
-        consensus_constants=consensus_constants,
+        constants=consensus_constants,
     )
     api = CrawlerAPI(crawler)

     network_id = service_config["selected_network"]

     rpc_info: Optional[RpcInfo] = None
-    if service_config.get("start_rpc_server", True):
-        rpc_info = (CrawlerRpcApi, service_config.get("rpc_port", 8561))
+    if crawler_config.get("start_rpc_server", True):
+        rpc_info = (CrawlerRpcApi, crawler_config.get("rpc_port", 8561))

     return Service(
         root_path=root_path,
View File

@@ -24,6 +24,7 @@ from chia.introducer.introducer_api import IntroducerAPI
 from chia.protocols.shared_protocol import Capability, capabilities
 from chia.seeder.crawler import Crawler
 from chia.seeder.crawler_api import CrawlerAPI
+from chia.seeder.dns_server import DNSServer, create_dns_server_service
 from chia.seeder.start_crawler import create_full_node_crawler_service
 from chia.server.start_farmer import create_farmer_service
 from chia.server.start_full_node import create_full_node_service
@@ -32,15 +33,17 @@ from chia.server.start_introducer import create_introducer_service
 from chia.server.start_service import Service
 from chia.server.start_timelord import create_timelord_service
 from chia.server.start_wallet import create_wallet_service
-from chia.simulator.block_tools import BlockTools
+from chia.simulator.block_tools import BlockTools, test_constants
 from chia.simulator.keyring import TempKeyring
+from chia.simulator.ssl_certs import get_next_nodes_certs_and_keys, get_next_private_ca_cert_and_key
 from chia.simulator.start_simulator import create_full_node_simulator_service
+from chia.ssl.create_ssl import create_all_ssl
 from chia.timelord.timelord import Timelord
 from chia.timelord.timelord_api import TimelordAPI
 from chia.timelord.timelord_launcher import kill_processes, spawn_process
 from chia.types.peer_info import UnresolvedPeerInfo
 from chia.util.bech32m import encode_puzzle_hash
-from chia.util.config import config_path_for_filename, lock_and_load_config, save_config
+from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config
 from chia.util.ints import uint16
 from chia.util.keychain import bytes_to_mnemonic
 from chia.util.lock import Lockfile
@@ -175,24 +178,36 @@ async def setup_full_node(
 async def setup_crawler(
-    bt: BlockTools,
+    root_path_populated_with_config: Path, database_uri: str
 ) -> AsyncGenerator[Service[Crawler, CrawlerAPI], None]:
-    config = bt.config
+    create_all_ssl(
+        root_path=root_path_populated_with_config,
+        private_ca_crt_and_key=get_next_private_ca_cert_and_key().collateral.cert_and_key,
+        node_certs_and_keys=get_next_nodes_certs_and_keys().collateral.certs_and_keys,
+    )
+    config = load_config(root_path_populated_with_config, "config.yaml")
     service_config = config["seeder"]
     service_config["selected_network"] = "testnet0"
     service_config["port"] = 0
-    service_config["start_rpc_server"] = False
+    service_config["crawler"]["start_rpc_server"] = False
+    service_config["other_peers_port"] = 58444
+    service_config["crawler_db_path"] = database_uri
     overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]]
-    updated_constants = bt.constants.replace_str_to_bytes(**overrides)
-    bt.change_config(config)
+    updated_constants = test_constants.replace_str_to_bytes(**overrides)

     service = create_full_node_crawler_service(
-        bt.root_path,
+        root_path_populated_with_config,
         config,
         updated_constants,
         connect_to_daemon=False,
     )
     await service.start()
+    if not service_config["crawler"]["start_rpc_server"]:  # otherwise the loops don't work.
+        service._node.state_changed_callback = lambda x, y: None

     try:
         yield service
     finally:
@@ -200,6 +215,24 @@ async def setup_crawler(
     await service.wait_closed()


+async def setup_seeder(root_path_populated_with_config: Path, database_uri: str) -> AsyncGenerator[DNSServer, None]:
+    config = load_config(root_path_populated_with_config, "config.yaml")
+    service_config = config["seeder"]
+    service_config["selected_network"] = "testnet0"
+    if service_config["domain_name"].endswith("."):  # remove the trailing . so that we can test that logic.
+        service_config["domain_name"] = service_config["domain_name"][:-1]
+    service_config["dns_port"] = 0
+    service_config["crawler_db_path"] = database_uri
+    service = create_dns_server_service(
+        config,
+        root_path_populated_with_config,
+    )
+    async with service.run():
+        yield service
+
+
 # Note: convert these setup functions to fixtures, or push it one layer up,
 # keeping these usable independently?


 async def setup_wallet_node(

View File

@@ -139,6 +139,8 @@ seeder:
   port: 8444
   # Most full nodes on the network run on this port. (i.e. 8444 for mainnet, 58444 for testnet).
   other_peers_port: 8444
+  # What port to run the DNS server on (useful if you are already using port 53 for DNS).
+  dns_port: 53
   # This will override the default full_node.peer_connect_timeout for the crawler full node
   peer_connect_timeout: 2
   # Path to crawler DB. Defaults to $CHIA_ROOT/crawler.db
@@ -154,7 +156,7 @@ seeder:
   nameserver: "example.com."
   ttl: 300
   soa:
-    rname: "hostmaster.example.com."
+    rname: "hostmaster.example.com"  # all @ symbols need to be replaced with . in dns records.
     serial_number: 1619105223
     refresh: 10800
     retry: 10800
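Aside: the rname comment refers to how RFC 1035 encodes the SOA RNAME: the administrator mailbox's @ becomes a dot, so hostmaster@example.com is written hostmaster.example.com. A tiny sketch of that mapping (email_to_rname is a hypothetical helper; it ignores the escaping needed when the local part itself contains a dot):

def email_to_rname(email: str) -> str:
    """Encode an admin mailbox as a DNS SOA RNAME: the '@' becomes a '.'."""
    local, _, domain = email.partition("@")
    return f"{local}.{domain}"

assert email_to_rname("hostmaster@example.com") == "hostmaster.example.com"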

View File

@@ -16,11 +16,6 @@ chia.rpc.rpc_client
 chia.rpc.util
 chia.rpc.wallet_rpc_api
 chia.rpc.wallet_rpc_client
-chia.seeder.crawl_store
-chia.seeder.crawler
-chia.seeder.dns_server
-chia.seeder.peer_record
-chia.seeder.start_crawler
 chia.simulator.full_node_simulator
 chia.simulator.keyring
 chia.simulator.setup_services

View File

@@ -4,6 +4,7 @@ from __future__ import annotations
 import datetime
 import multiprocessing
 import os
+import random
 import sysconfig
 import tempfile
 from enum import Enum
@@ -32,6 +33,7 @@ from chia.rpc.harvester_rpc_client import HarvesterRpcClient
 from chia.rpc.wallet_rpc_client import WalletRpcClient
 from chia.seeder.crawler import Crawler
 from chia.seeder.crawler_api import CrawlerAPI
+from chia.seeder.dns_server import DNSServer
 from chia.server.server import ChiaServer
 from chia.server.start_service import Service
 from chia.simulator.full_node_simulator import FullNodeSimulator
@@ -44,7 +46,7 @@ from chia.simulator.setup_nodes import (
     setup_simulators_and_wallets_service,
     setup_two_nodes,
 )
-from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_timelord
+from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_seeder, setup_timelord
 from chia.simulator.time_out_assert import time_out_assert
 from chia.simulator.wallet_tools import WalletTool
 from chia.types.peer_info import PeerInfo
@@ -872,11 +874,19 @@ async def timelord_service(bt):
 @pytest_asyncio.fixture(scope="function")
-async def crawler_service(bt: BlockTools) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
-    async for service in setup_crawler(bt):
+async def crawler_service(
+    root_path_populated_with_config: Path, database_uri: str
+) -> AsyncIterator[Service[Crawler, CrawlerAPI]]:
+    async for service in setup_crawler(root_path_populated_with_config, database_uri):
         yield service


+@pytest_asyncio.fixture(scope="function")
+async def seeder_service(root_path_populated_with_config: Path, database_uri: str) -> AsyncIterator[DNSServer]:
+    async for seeder in setup_seeder(root_path_populated_with_config, database_uri):
+        yield seeder
+
+
 @pytest.fixture(scope="function")
 def tmp_chia_root(tmp_path):
     """
@@ -980,3 +990,8 @@ async def harvester_farmer_environment(
     harvester_rpc_cl.close()
     await farmer_rpc_cl.await_closed()
     await harvester_rpc_cl.await_closed()
+
+
+@pytest.fixture(name="database_uri")
+def database_uri_fixture() -> str:
+    return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
 import os
 import pathlib
-import random
 import sys
 import time
 from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator
@@ -54,11 +53,6 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32]
     return request.param  # type: ignore[no-any-return]


-@pytest.fixture(name="database_uri")
-def database_uri_fixture() -> str:
-    return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
-
-
 @pytest.fixture(name="tree_id", scope="function")
 def tree_id_fixture() -> bytes32:
     base = b"a tree id"

View File

@@ -1,6 +1,7 @@
 from __future__ import annotations

 import logging
+import time
 from typing import cast

 import pytest
@@ -11,13 +12,14 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes
 from chia.protocols.wallet_protocol import RequestChildren
 from chia.seeder.crawler import Crawler
 from chia.seeder.crawler_api import CrawlerAPI
+from chia.seeder.peer_record import PeerRecord, PeerReliability
 from chia.server.outbound_message import make_msg
 from chia.server.start_service import Service
 from chia.simulator.setup_nodes import SimulatorsAndWalletsServices
 from chia.simulator.time_out_assert import time_out_assert
 from chia.types.blockchain_format.sized_bytes import bytes32
 from chia.types.peer_info import PeerInfo
-from chia.util.ints import uint16, uint32, uint128
+from chia.util.ints import uint16, uint32, uint64, uint128

 @pytest.mark.asyncio
@@ -68,3 +70,43 @@ async def test_valid_message(
     )
     assert await connection.send_message(msg)
     await time_out_assert(10, peer_added)
@pytest.mark.asyncio
async def test_crawler_to_db(
crawler_service: Service[Crawler, CrawlerAPI], one_node: SimulatorsAndWalletsServices
) -> None:
"""
This is more of an integration test: it exercises the whole flow. We add a node to the crawler, save it to
the db, and validate the stored data.
"""
[full_node_service], _, _ = one_node
full_node = full_node_service._node
crawler = crawler_service._node
crawl_store = crawler.crawl_store
assert crawl_store is not None
peer_address = "127.0.0.1"
# create peer records
peer_record = PeerRecord(
peer_address,
peer_address,
uint32(full_node.server.get_port()),
False,
uint64(0),
uint32(0),
uint64(0),
uint64(int(time.time())),
uint64(0),
"undefined",
uint64(0),
tls_version="unknown",
)
peer_reliability = PeerReliability(peer_address, tries=1, successes=1)
# add peer to the db & mark it as connected
await crawl_store.add_peer(peer_record, peer_reliability)
assert peer_record == crawl_store.host_to_records[peer_address]
# validate the db data
await time_out_assert(20, crawl_store.get_good_peers, [peer_address])

View File

@@ -9,7 +9,7 @@ from chia.seeder.crawler import Crawler
 class TestCrawlerRpc:
     @pytest.mark.asyncio
     async def test_get_ips_after_timestamp(self, bt):
-        crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, consensus_constants=bt.constants)
+        crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, constants=bt.constants)
         crawler_rpc_api = CrawlerRpcApi(crawler)

         # Should raise ValueError when `after` is not supplied
tests/core/test_seeder.py (new file, 275 lines)
View File

@@ -0,0 +1,275 @@
from __future__ import annotations
import time
from dataclasses import dataclass
from ipaddress import IPv4Address, IPv6Address
from socket import AF_INET6, SOCK_STREAM
from typing import List, Tuple, cast
import dns
import pytest
from chia.seeder.dns_server import DNSServer
from chia.seeder.peer_record import PeerRecord, PeerReliability
from chia.simulator.time_out_assert import time_out_assert
from chia.util.ints import uint32, uint64
timeout = 0.5
def generate_test_combs() -> List[Tuple[bool, str, dns.rdatatype.RdataType]]:
"""
Generates all the combinations of tests we want to run.
"""
output = []
use_tcp = [True, False]
target_address = ["::1", "127.0.0.1"]
request_types = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.ANY, dns.rdatatype.NS, dns.rdatatype.SOA]
# (use_tcp, target_address, request_type): UDP and TCP, over IPv6 and IPv4, for every request type we support.
for addr in target_address:
for tcp in use_tcp:
for req in request_types:
output.append((tcp, addr, req))
return output
all_test_combinations = generate_test_combs()
@dataclass(frozen=True)
class FakeDnsPacket:
real_packet: dns.message.Message
def to_wire(self) -> bytes:
return self.real_packet.to_wire()[23:]
def __getattr__(self, item: object) -> None:
# This is definitely cheating, but it works
return None
async def make_dns_query(
use_tcp: bool, target_address: str, port: int, dns_message: dns.message.Message, d_timeout: float = timeout
) -> dns.message.Message:
"""
Makes a DNS query for the given domain name using the given protocol type.
"""
if use_tcp:
return await dns.asyncquery.tcp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
return await dns.asyncquery.udp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
def get_addresses(num_subnets: int = 10) -> Tuple[List[IPv4Address], List[IPv6Address]]:
ipv4 = []
ipv6 = []
# generate num_subnets * 250 IPv4 and IPv6 addresses each (2500 apiece by default)
for s in range(num_subnets):
for i in range(1, 251): # 250 hosts per subnet keeps us under the per-subnet limit of 255
ipv4.append(IPv4Address(f"192.168.{s}.{i}"))
ipv6.append(IPv6Address(f"2001:db8::{s}:{i}"))
return ipv4, ipv6
def assert_standard_results(
std_query_answer: List[dns.rrset.RRset], request_type: dns.rdatatype.RdataType, num_ns: int
) -> None:
if request_type == dns.rdatatype.A:
assert len(std_query_answer) == 1 # only 1 kind of answer
a_answer = std_query_answer[0]
assert a_answer.rdtype == dns.rdatatype.A
assert len(a_answer) == 32 # 32 ipv4 addresses
elif request_type == dns.rdatatype.AAAA:
assert len(std_query_answer) == 1 # only 1 kind of answer
aaaa_answer = std_query_answer[0]
assert aaaa_answer.rdtype == dns.rdatatype.AAAA
assert len(aaaa_answer) == 32 # 32 ipv6 addresses
elif request_type == dns.rdatatype.ANY:
assert len(std_query_answer) == 4 # 4 kinds of answers
for answer in std_query_answer:
if answer.rdtype == dns.rdatatype.A:
assert len(answer) == 16
elif answer.rdtype == dns.rdatatype.AAAA:
assert len(answer) == 16
elif answer.rdtype == dns.rdatatype.NS:
assert len(answer) == num_ns
else:
assert len(answer) == 1
elif request_type == dns.rdatatype.NS:
assert len(std_query_answer) == 1 # only 1 kind of answer
ns_answer = std_query_answer[0]
assert ns_answer.rdtype == dns.rdatatype.NS
assert len(ns_answer) == num_ns # ns records
else:
assert len(std_query_answer) == 1 # soa
soa_answer = std_query_answer[0]
assert soa_answer.rdtype == dns.rdatatype.SOA
assert len(soa_answer) == 1
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_error_conditions(
seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
"""
We check having no peers, an invalid packet, an early EOF, and a packet then an EOF halfway through (tcp only).
We also check for a dns record that does not exist, and a dns record outside the domain.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
# No peers
no_peers = dns.message.make_query(domain, request_type)
no_peers_response = await make_dns_query(use_tcp, target_address, port, no_peers)
assert no_peers_response.rcode() == dns.rcode.NOERROR
if request_type == dns.rdatatype.A or request_type == dns.rdatatype.AAAA:
assert len(no_peers_response.answer) == 0 # no response, as expected
elif request_type == dns.rdatatype.ANY: # ns + soa
assert len(no_peers_response.answer) == 2
for answer in no_peers_response.answer:
if answer.rdtype == dns.rdatatype.NS:
assert len(answer.items) == num_ns
else:
assert len(answer.items) == 1
elif request_type == dns.rdatatype.NS:
assert len(no_peers_response.answer) == 1 # ns
assert no_peers_response.answer[0].rdtype == dns.rdatatype.NS
assert len(no_peers_response.answer[0].items) == num_ns
else:
assert len(no_peers_response.answer) == 1 # soa
assert no_peers_response.answer[0].rdtype == dns.rdatatype.SOA
assert len(no_peers_response.answer[0].items) == 1
# Authority Records
assert len(no_peers_response.authority) == num_ns + 1 # ns + soa
for record_list in no_peers_response.authority:
if record_list.rdtype == dns.rdatatype.NS:
assert len(record_list.items) == num_ns
else:
assert len(record_list.items) == 1 # soa
# Invalid packet (awkward to construct, hence the FakeDnsPacket wrapper above)
invalid_packet = cast(dns.message.Message, FakeDnsPacket(dns.message.make_query(domain, request_type)))
with pytest.raises(EOFError if use_tcp else dns.exception.Timeout): # UDP will time out, TCP will EOF
await make_dns_query(use_tcp, target_address, port, invalid_packet)
# early EOF packet
if use_tcp:
backend = dns.asyncbackend.get_default_backend()
tcp_socket = await backend.make_socket( # type: ignore[no-untyped-call]
AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
)
async with tcp_socket as socket:
await socket.close()
with pytest.raises(EOFError):
await dns.asyncquery.receive_tcp(tcp_socket, timeout)
# Packet then length header then EOF
if use_tcp:
backend = dns.asyncbackend.get_default_backend()
tcp_socket = await backend.make_socket( # type: ignore[no-untyped-call]
AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
)
async with tcp_socket as socket: # send packet then length then eof
r = await dns.asyncquery.tcp(q=no_peers, where="::1", timeout=timeout, sock=socket)
assert r.answer == no_peers_response.answer
# send 120 as the 2-byte length prefix, so the server expects 120 more bytes that never arrive.
await socket.sendall(int(120).to_bytes(2, byteorder="big"), int(time.time() + timeout))
await socket.close()
with pytest.raises(EOFError):
await dns.asyncquery.receive_tcp(tcp_socket, timeout)
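Aside: this works because DNS over TCP frames every message with a two-byte big-endian length (RFC 1035 section 4.2.2); the test sends only a length header and closes, so the server hits EOF waiting for the promised body. A minimal sketch of that framing (frame/unframe are illustrative helpers, not part of this commit):

def frame(message: bytes) -> bytes:
    """Prefix a DNS message with its 2-byte big-endian length."""
    return len(message).to_bytes(2, byteorder="big") + message

def unframe(stream: bytes) -> bytes:
    """Strip the length prefix; an incomplete body means the peer hung up early."""
    length = int.from_bytes(stream[:2], byteorder="big")
    body = stream[2 : 2 + length]
    if len(body) < length:
        raise EOFError("peer closed the connection mid-message")
    return body

assert unframe(frame(b"dns-packet-bytes")) == b"dns-packet-bytes"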
# Record does not exist
record_does_not_exist = dns.message.make_query("doesnotexist." + domain, request_type)
record_does_not_exist_response = await make_dns_query(use_tcp, target_address, port, record_does_not_exist)
assert record_does_not_exist_response.rcode() == dns.rcode.NXDOMAIN
assert len(record_does_not_exist_response.answer) == 0
assert len(record_does_not_exist_response.authority) == num_ns + 1 # ns + soa
# Record outside domain
record_outside_domain = dns.message.make_query("chia.net", request_type)
record_outside_domain_response = await make_dns_query(use_tcp, target_address, port, record_outside_domain)
assert record_outside_domain_response.rcode() == dns.rcode.REFUSED
assert len(record_outside_domain_response.answer) == 0
assert len(record_outside_domain_response.authority) == 0
@pytest.mark.asyncio
@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
async def test_dns_queries(
seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType
) -> None:
"""
We add 5000 peers directly, then try every kind of query many times over both the TCP and UDP protocols.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
# add 5000 peers (2500 ipv4, 2500 ipv6)
seeder_service.reliable_peers_v4, seeder_service.reliable_peers_v6 = get_addresses()
# now we query for each type of record a lot of times and make sure we get the right number of responses
for i in range(150):
query = dns.message.make_query(domain, request_type, use_edns=True) # we need to generate a new request id.
std_query_response = await make_dns_query(use_tcp, target_address, port, query)
assert std_query_response.rcode() == dns.rcode.NOERROR
assert_standard_results(std_query_response.answer, request_type, num_ns)
# Assert Authority Records
assert len(std_query_response.authority) == num_ns + 1 # ns + soa
for record_list in std_query_response.authority:
if record_list.rdtype == dns.rdatatype.NS:
assert len(record_list.items) == num_ns
else:
assert len(record_list.items) == 1 # soa
# Validate EDNS
e_query = dns.message.make_query(domain, dns.rdatatype.ANY, use_edns=False)
with pytest.raises(dns.query.BadResponse): # response is truncated without EDNS
await make_dns_query(False, target_address, port, e_query)
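Aside: a plain UDP response is capped at 512 bytes, which the 16 A plus 16 AAAA records of an ANY answer overflow; EDNS(0) lets the client advertise a larger buffer. A sketch of the difference with dnspython (assuming make_query's use_edns/payload arguments and the edns/payload attributes behave as shown):

import dns.message

plain = dns.message.make_query("seeder.example.com", "ANY", use_edns=False)
edns = dns.message.make_query("seeder.example.com", "ANY", use_edns=0, payload=4096)

assert plain.edns == -1  # no OPT pseudo-record: the 512-byte UDP limit applies
assert edns.edns == 0 and edns.payload == 4096  # OPT record advertises a 4096-byte buffer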
@pytest.mark.asyncio
async def test_db_processing(seeder_service: DNSServer) -> None:
"""
We add 1000 peers through the db, then try every kind of query over both the TCP and UDP protocols.
"""
port = seeder_service.dns_port
domain = seeder_service.domain # default is: seeder.example.com
num_ns = len(seeder_service.ns_records)
crawl_store = seeder_service.crawl_store
assert crawl_store is not None
ipv4, ipv6 = get_addresses(2) # get 1000 addresses
# add peers to the in memory part of the db.
for peer in [str(peer) for pair in zip(ipv4, ipv6) for peer in pair]:
new_peer = PeerRecord(
peer,
peer,
uint32(58444),
False,
uint64(0),
uint32(0),
uint64(0),
uint64(int(time.time())),
uint64(0),
"undefined",
uint64(0),
tls_version="unknown",
)
new_peer_reliability = PeerReliability(peer, tries=3, successes=3) # make sure the peer starts as reliable.
crawl_store.maybe_add_peer(new_peer, new_peer_reliability) # we don't fully add it because we don't need to.
# Write these new peers to db.
await crawl_store.load_reliable_peers_to_db()
# wait for the new db to be read.
await time_out_assert(30, lambda: seeder_service.reliable_peers_v4 != [])
# now we check every combination once (a single pass is enough here)
for use_tcp, target_address, request_type in all_test_combinations:
query = dns.message.make_query(domain, request_type, use_edns=True) # we need to generate a new request id.
std_query_response = await make_dns_query(use_tcp, target_address, port, query)
assert std_query_response.rcode() == dns.rcode.NOERROR
assert_standard_results(std_query_response.answer, request_type, num_ns)