From 0d36874caadace6921dca14029ebc565ee36167c Mon Sep 17 00:00:00 2001
From: Jack Nelson
Date: Fri, 11 Aug 2023 06:25:19 -0400
Subject: [PATCH] Refactor Seeder & Crawler code + add tests (#15781)

* Cleanup seeder & mypy all files

holy crap mother of all tech debt this was horrid

* Make UDP Protocol class a dataclass & separate out functions.

* add TCP protocol class to DNS server

* Fix mypy types & other cleanup

also fix a couple bugs

* Add seeder and crawler tests

* change log levels

* re add db lock timeout

oops

* add edns & edns tests

* fix repeated shutdown on close signal

* fix binding to use both ipv6 and ipv4

whyyyyyyy

* add ipv6 and ipv4 tests + add ipv4 if windows
---
 chia/seeder/crawl_store.py        | 108 ++---
 chia/seeder/crawler.py            | 331 ++++++++-------
 chia/seeder/dns_server.py         | 678 ++++++++++++++++++++----------
 chia/seeder/peer_record.py        |  16 +-
 chia/seeder/start_crawler.py      |  11 +-
 chia/simulator/setup_services.py  |  49 ++-
 chia/util/initial-config.yaml     |   4 +-
 mypy-exclusions.txt               |   5 -
 tests/conftest.py                 |  21 +-
 tests/core/data_layer/conftest.py |   6 -
 tests/core/test_crawler.py        |  44 +-
 tests/core/test_crawler_rpc.py    |   2 +-
 tests/core/test_seeder.py         | 275 ++++++++++++
 13 files changed, 1082 insertions(+), 468 deletions(-)
 create mode 100644 tests/core/test_seeder.py

diff --git a/chia/seeder/crawl_store.py b/chia/seeder/crawl_store.py
index be70e05f62b0..702f893036a3 100644
--- a/chia/seeder/crawl_store.py
+++ b/chia/seeder/crawl_store.py
@@ -1,37 +1,34 @@
 from __future__ import annotations

-import asyncio
-import dataclasses
 import ipaddress
 import logging
 import random
 import time
+from dataclasses import dataclass, field, replace
 from typing import Dict, List

 import aiosqlite

 from chia.seeder.peer_record import PeerRecord, PeerReliability
+from chia.util.ints import uint64

 log = logging.getLogger(__name__)


+@dataclass
 class CrawlStore:
     crawl_db: aiosqlite.Connection
-    last_timestamp: int
-    lock: asyncio.Lock
-
-    host_to_records: Dict
-    host_to_selected_time: Dict
-    host_to_reliability: Dict
-    banned_peers: int
-    ignored_peers: int
-    reliable_peers: int
+    host_to_records: Dict[str, PeerRecord] = field(default_factory=dict)  # peer_id: PeerRecord
+    host_to_selected_time: Dict[str, float] = field(default_factory=dict)  # peer_id: timestamp (as a float)
+    host_to_reliability: Dict[str, PeerReliability] = field(default_factory=dict)  # peer_id: PeerReliability
+    banned_peers: int = 0
+    ignored_peers: int = 0
+    reliable_peers: int = 0

     @classmethod
-    async def create(cls, connection: aiosqlite.Connection):
-        self = cls()
+    async def create(cls, connection: aiosqlite.Connection) -> CrawlStore:
+        self = cls(connection)

-        self.crawl_db = connection
         await self.crawl_db.execute(
             (
                 "CREATE TABLE IF NOT EXISTS peer_records("
@@ -82,21 +79,16 @@ class CrawlStore:
         await self.crawl_db.execute("CREATE INDEX IF NOT EXISTS ignore_till on peer_reliability(ignore_till)")
         await self.crawl_db.commit()

-        self.last_timestamp = 0
-        self.ignored_peers = 0
-        self.banned_peers = 0
-        self.reliable_peers = 0
-        self.host_to_selected_time = {}
         await self.unload_from_db()
         return self

-    def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability):
+    def maybe_add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability) -> None:
         if peer_record.peer_id not in self.host_to_records:
             self.host_to_records[peer_record.peer_id] = peer_record
         if peer_reliability.peer_id not in self.host_to_reliability:
             self.host_to_reliability[peer_reliability.peer_id] = peer_reliability

-    async def
add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False): + async def add_peer(self, peer_record: PeerRecord, peer_reliability: PeerReliability, save_db: bool = False) -> None: if not save_db: self.host_to_records[peer_record.peer_id] = peer_record self.host_to_reliability[peer_reliability.peer_id] = peer_reliability @@ -152,41 +144,41 @@ class CrawlStore: async def get_peer_reliability(self, peer_id: str) -> PeerReliability: return self.host_to_reliability[peer_id] - async def peer_failed_to_connect(self, peer: PeerRecord): - now = int(time.time()) + async def peer_failed_to_connect(self, peer: PeerRecord) -> None: + now = uint64(time.time()) age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp)) if age_timestamp == 0: age_timestamp = now - 1000 - replaced = dataclasses.replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now) + replaced = replace(peer, try_count=peer.try_count + 1, last_try_timestamp=now) reliability = await self.get_peer_reliability(peer.peer_id) if reliability is None: reliability = PeerReliability(peer.peer_id) reliability.update(False, now - age_timestamp) await self.add_peer(replaced, reliability) - async def peer_connected(self, peer: PeerRecord, tls_version: str): - now = int(time.time()) + async def peer_connected(self, peer: PeerRecord, tls_version: str) -> None: + now = uint64(time.time()) age_timestamp = int(max(peer.last_try_timestamp, peer.connected_timestamp)) if age_timestamp == 0: age_timestamp = now - 1000 - replaced = dataclasses.replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version) + replaced = replace(peer, connected=True, connected_timestamp=now, tls_version=tls_version) reliability = await self.get_peer_reliability(peer.peer_id) if reliability is None: reliability = PeerReliability(peer.peer_id) reliability.update(True, now - age_timestamp) await self.add_peer(replaced, reliability) - async def update_best_timestamp(self, host: str, timestamp): + async def update_best_timestamp(self, host: str, timestamp: uint64) -> None: if host not in self.host_to_records: return record = self.host_to_records[host] - replaced = dataclasses.replace(record, best_timestamp=timestamp) + replaced = replace(record, best_timestamp=timestamp) if host not in self.host_to_reliability: return reliability = self.host_to_reliability[host] await self.add_peer(replaced, reliability) - async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown"): + async def peer_connected_hostname(self, host: str, connected: bool = True, tls_version: str = "unknown") -> None: if host not in self.host_to_records: return record = self.host_to_records[host] @@ -195,7 +187,7 @@ class CrawlStore: else: await self.peer_failed_to_connect(record) - async def get_peers_to_crawl(self, min_batch_size, max_batch_size) -> List[PeerRecord]: + async def get_peers_to_crawl(self, min_batch_size: int, max_batch_size: int) -> List[PeerRecord]: now = int(time.time()) records = [] records_v6 = [] @@ -270,20 +262,20 @@ class CrawlStore: def get_reliable_peers(self) -> int: return self.reliable_peers - async def load_to_db(self): - log.error("Saving peers to DB...") + async def load_to_db(self) -> None: + log.info("Saving peers to DB...") for peer_id in list(self.host_to_reliability.keys()): if peer_id in self.host_to_reliability and peer_id in self.host_to_records: reliability = self.host_to_reliability[peer_id] record = self.host_to_records[peer_id] await self.add_peer(record, 
reliability, True) await self.crawl_db.commit() - log.error(" - Done saving peers to DB") + log.info(" - Done saving peers to DB") - async def unload_from_db(self): + async def unload_from_db(self) -> None: self.host_to_records = {} self.host_to_reliability = {} - log.error("Loading peer reliability records...") + log.info("Loading peer reliability records...") cursor = await self.crawl_db.execute( "SELECT * from peer_reliability", ) @@ -312,9 +304,9 @@ class CrawlStore: row[18], row[19], ) - self.host_to_reliability[row[0]] = reliability - log.error(" - Done loading peer reliability records...") - log.error("Loading peer records...") + self.host_to_reliability[reliability.peer_id] = reliability + log.info(" - Done loading peer reliability records...") + log.info("Loading peer records...") cursor = await self.crawl_db.execute( "SELECT * from peer_records", ) @@ -324,24 +316,23 @@ class CrawlStore: peer = PeerRecord( row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7], row[8], row[9], row[10], row[11] ) - self.host_to_records[row[0]] = peer - log.error(" - Done loading peer records...") + self.host_to_records[peer.peer_id] = peer + log.info(" - Done loading peer records...") # Crawler -> DNS. - async def load_reliable_peers_to_db(self): + async def load_reliable_peers_to_db(self) -> None: peers = [] - for peer_id in self.host_to_reliability: - reliability = self.host_to_reliability[peer_id] + for peer_id, reliability in self.host_to_reliability.items(): if reliability.is_reliable(): peers.append(peer_id) self.reliable_peers = len(peers) - log.error("Deleting old good_peers from DB...") + log.info("Deleting old good_peers from DB...") cursor = await self.crawl_db.execute( "DELETE from good_peers", ) await cursor.close() - log.error(" - Done deleting old good_peers...") - log.error("Saving new good_peers to DB...") + log.info(" - Done deleting old good_peers...") + log.info("Saving new good_peers to DB...") for peer in peers: cursor = await self.crawl_db.execute( "INSERT OR REPLACE INTO good_peers VALUES(?)", @@ -349,9 +340,9 @@ class CrawlStore: ) await cursor.close() await self.crawl_db.commit() - log.error(" - Done saving new good_peers to DB...") + log.info(" - Done saving new good_peers to DB...") - def load_host_to_version(self): + def load_host_to_version(self) -> tuple[dict[str, str], dict[str, uint64]]: versions = {} handshake = {} @@ -366,19 +357,30 @@ class CrawlStore: versions[host] = record.version handshake[host] = record.handshake_time - return (versions, handshake) + return versions, handshake - def load_best_peer_reliability(self): + def load_best_peer_reliability(self) -> dict[str, uint64]: best_timestamp = {} for host, record in self.host_to_records.items(): if record.best_timestamp > time.time() - 5 * 24 * 3600: best_timestamp[host] = record.best_timestamp return best_timestamp - async def update_version(self, host, version, now): + async def update_version(self, host: str, version: str, timestamp_now: uint64) -> None: record = self.host_to_records.get(host, None) reliability = self.host_to_reliability.get(host, None) if record is None or reliability is None: return - record.update_version(version, now) + record.update_version(version, timestamp_now) await self.add_peer(record, reliability) + + async def get_good_peers(self) -> list[str]: # This is for the DNS server + cursor = await self.crawl_db.execute( + "SELECT * from good_peers", + ) + rows = await cursor.fetchall() + await cursor.close() + result = [row[0] for row in rows] + if len(result) > 0: + 
random.shuffle(result) # mix up the peers + return result diff --git a/chia/seeder/crawler.py b/chia/seeder/crawler.py index bced767578ee..9f5265eeec7c 100644 --- a/chia/seeder/crawler.py +++ b/chia/seeder/crawler.py @@ -6,15 +6,16 @@ import logging import time import traceback from collections import defaultdict +from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple import aiosqlite from chia.consensus.constants import ConsensusConstants -from chia.full_node.coin_store import CoinStore from chia.full_node.full_node_api import FullNodeAPI from chia.protocols import full_node_protocol +from chia.protocols.full_node_protocol import RespondPeers from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections from chia.seeder.crawl_store import CrawlStore from chia.seeder.peer_record import PeerRecord, PeerReliability @@ -29,20 +30,28 @@ from chia.util.path import path_from_root log = logging.getLogger(__name__) +@dataclass class Crawler: - sync_store: Any - coin_store: CoinStore - connection: Optional[aiosqlite.Connection] - config: Dict - _server: Optional[ChiaServer] - crawl_store: Optional[CrawlStore] - log: logging.Logger - constants: ConsensusConstants - _shut_down: bool + config: Dict[str, Any] root_path: Path - peer_count: int - with_peak: set - minimum_version_count: int + constants: ConsensusConstants + print_status: bool = True + state_changed_callback: Optional[StateChangedProtocol] = None + _server: Optional[ChiaServer] = None + crawl_task: Optional[asyncio.Task[None]] = None + crawl_store: Optional[CrawlStore] = None + log: logging.Logger = log + _shut_down: bool = False + peer_count: int = 0 + with_peak: Set[PeerInfo] = field(default_factory=set) + seen_nodes: Set[str] = field(default_factory=set) + minimum_version_count: int = 0 + peers_retrieved: List[RespondPeers] = field(default_factory=list) + host_to_version: Dict[str, str] = field(default_factory=dict) + versions: Dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0)) + version_cache: List[Tuple[str, str]] = field(default_factory=list) + handshake_time: Dict[str, uint64] = field(default_factory=dict) + best_timestamp_per_peer: Dict[str, uint64] = field(default_factory=dict) @property def server(self) -> ChiaServer: @@ -53,38 +62,16 @@ class Crawler: return self._server - def __init__( - self, - config: Dict, - root_path: Path, - consensus_constants: ConsensusConstants, - name: str = None, - ): - self.initialized = False - self.root_path = root_path - self.config = config - self.connection = None - self._server = None - self._shut_down = False # Set to true to close all infinite loops - self.constants = consensus_constants - self.state_changed_callback: Optional[StateChangedProtocol] = None - self.crawl_store = None - self.log = log - self.peer_count = 0 - self.with_peak = set() - self.peers_retrieved: List[Any] = [] - self.host_to_version: Dict[str, str] = {} - self.version_cache: List[Tuple[str, str]] = [] - self.handshake_time: Dict[str, int] = {} - self.best_timestamp_per_peer: Dict[str, int] = {} - crawler_db_path: str = config.get("crawler_db_path", "crawler.db") - self.db_path = path_from_root(root_path, crawler_db_path) + def __post_init__(self) -> None: + # get db path + crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db") + self.db_path = path_from_root(self.root_path, crawler_db_path) 
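
The Crawler class above now gets its state from dataclass fields rather than a hand-written __init__, so every mutable attribute goes through field(default_factory=...). A minimal sketch of why that is required; the class and field names here are illustrative, not from this patch:

from dataclasses import dataclass, field
from typing import Dict, Set

@dataclass
class Example:
    # A bare `versions: Dict[str, int] = {}` would raise ValueError when the
    # class is defined: dataclasses reject mutable defaults because every
    # instance would share one dict.
    versions: Dict[str, int] = field(default_factory=dict)
    seen_nodes: Set[str] = field(default_factory=set)
    peer_count: int = 0  # immutable defaults need no factory

a = Example()
b = Example()
a.versions["2.0.0"] = 1
assert b.versions == {}  # each instance gets its own dict
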
self.db_path.parent.mkdir(parents=True, exist_ok=True) - self.bootstrap_peers = config["bootstrap_peers"] - self.minimum_height = config["minimum_height"] - self.other_peers_port = config["other_peers_port"] - self.versions: Dict[str, int] = defaultdict(lambda: 0) - self.minimum_version_count = self.config.get("minimum_version_count", 100) + # load data from config + self.bootstrap_peers = self.config["bootstrap_peers"] + self.minimum_height = self.config["minimum_height"] + self.other_peers_port = self.config["other_peers_port"] + self.minimum_version_count: int = self.config.get("minimum_version_count", 100) if self.minimum_version_count < 1: self.log.warning( f"Crawler configuration minimum_version_count expected to be greater than zero: " @@ -97,11 +84,16 @@ class Crawler: def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]: return default_get_connections(server=self.server, request_node_type=request_node_type) - async def create_client(self, peer_info, on_connect): + async def create_client( + self, peer_info: PeerInfo, on_connect: Callable[[WSChiaConnection], Awaitable[None]] + ) -> bool: return await self.server.start_client(peer_info, on_connect) - async def connect_task(self, peer): - async def peer_action(peer: WSChiaConnection): + async def connect_task(self, peer: PeerRecord) -> None: + if self.crawl_store is None: + raise ValueError("Not Connected to DB") + + async def peer_action(peer: WSChiaConnection) -> None: peer_info = peer.get_peer_info() version = peer.get_version() if peer_info is not None and version is not None: @@ -134,41 +126,34 @@ class Crawler: if not connected: await self.crawl_store.peer_failed_to_connect(peer) except Exception as e: - self.log.info(f"Exception: {e}. Traceback: {traceback.format_exc()}.") + self.log.warning(f"Exception: {e}. 
Traceback: {traceback.format_exc()}.") await self.crawl_store.peer_failed_to_connect(peer) - async def _start(self): + async def _start(self) -> None: # We override the default peer_connect_timeout when running from the crawler crawler_peer_timeout = self.config.get("peer_connect_timeout", 2) self.server.config["peer_connect_timeout"] = crawler_peer_timeout - self.task = asyncio.create_task(self.crawl()) - - async def crawl(self): - # Ensure the state_changed callback is set up before moving on - # Sometimes, the daemon connection + state changed callback isn't up and ready - # by the time we get to the first _state_changed call, so this just ensures it's there before moving on - while self.state_changed_callback is None: - self.log.info("Waiting for state changed callback...") - await asyncio.sleep(0.1) + # Connect to the DB + self.crawl_store: CrawlStore = await CrawlStore.create(await aiosqlite.connect(self.db_path)) + # Bootstrap the initial peers + await self.load_bootstrap_peers() + self.crawl_task = asyncio.create_task(self.crawl()) + async def load_bootstrap_peers(self) -> None: + assert self.crawl_store is not None try: - self.connection = await aiosqlite.connect(self.db_path) - self.crawl_store = await CrawlStore.create(self.connection) - self.log.info("Started") + self.log.warning("Bootstrapping initial peers...") t_start = time.time() - total_nodes = 0 - self.seen_nodes = set() - tried_nodes = set() for peer in self.bootstrap_peers: new_peer = PeerRecord( peer, peer, self.other_peers_port, False, - 0, - 0, - 0, + uint64(0), + uint32(0), + uint64(0), uint64(int(time.time())), uint64(0), "undefined", @@ -180,14 +165,26 @@ class Crawler: self.host_to_version, self.handshake_time = self.crawl_store.load_host_to_version() self.best_timestamp_per_peer = self.crawl_store.load_best_peer_reliability() - self.versions = defaultdict(lambda: 0) for host, version in self.host_to_version.items(): self.versions[version] += 1 - self._state_changed("loaded_initial_peers") + self.log.warning(f"Bootstrapped initial peers in {time.time() - t_start} seconds") + except Exception as e: + self.log.error(f"Error bootstrapping initial peers: {e}") - while True: - self.with_peak = set() + async def crawl(self) -> None: + # Ensure the state_changed callback is set up before moving on + # Sometimes, the daemon connection + state changed callback isn't up and ready + # by the time we get to the first _state_changed call, so this just ensures it's there before moving on + while self.state_changed_callback is None: + self.log.info("Waiting for state changed callback...") + await asyncio.sleep(0.1) + assert self.crawl_store is not None + t_start = time.time() + total_nodes = 0 + tried_nodes = set() + try: + while not self._shut_down: peers_to_crawl = await self.crawl_store.get_peers_to_crawl(25000, 250000) tasks = set() for peer in peers_to_crawl: @@ -200,7 +197,6 @@ class Crawler: if len(tasks) >= 250: await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) tasks = set(filter(lambda t: not t.done(), tasks)) - if len(tasks) > 0: await asyncio.wait(tasks, timeout=30) @@ -231,16 +227,15 @@ class Crawler: tls_version="unknown", ) new_peer_reliability = PeerReliability(response_peer.host) - if self.crawl_store is not None: - self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability) + self.crawl_store.maybe_add_peer(new_peer, new_peer_reliability) await self.crawl_store.update_best_timestamp( response_peer.host, self.best_timestamp_per_peer[response_peer.host], ) for host, version in 
self.version_cache: - self.handshake_time[host] = int(time.time()) + self.handshake_time[host] = uint64(time.time()) self.host_to_version[host] = version - await self.crawl_store.update_version(host, version, int(time.time())) + await self.crawl_store.update_version(host, version, uint64(time.time())) to_remove = set() now = int(time.time()) @@ -264,94 +259,46 @@ class Crawler: self.versions = defaultdict(lambda: 0) for host, version in self.host_to_version.items(): self.versions[version] += 1 - self.version_cache = [] - self.peers_retrieved = [] + # clear caches + self.version_cache: List[Tuple[str, str]] = [] + self.peers_retrieved = [] self.server.banned_peers = {} + self.with_peak = set() + if len(peers_to_crawl) == 0: continue - # Try up to 5 times to write to the DB in case there is a lock that causes a timeout - for i in range(1, 5): - try: - await self.crawl_store.load_to_db() - await self.crawl_store.load_reliable_peers_to_db() - except Exception as e: - self.log.error(f"Exception while saving to DB: {e}.") - self.log.error("Waiting 5 seconds before retry...") - await asyncio.sleep(5) - continue - break - total_records = self.crawl_store.get_total_records() - ipv6_count = self.crawl_store.get_ipv6_peers() - self.log.error("***") - self.log.error("Finished batch:") - self.log.error(f"Total IPs stored in DB: {total_records}.") - self.log.error(f"Total IPV6 addresses stored in DB: {ipv6_count}") - self.log.error(f"Total connections attempted since crawler started: {total_nodes}.") - self.log.error(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.") - t_now = time.time() - t_delta = int(t_now - t_start) - if t_delta > 0: - self.log.error(f"Avg connections per second: {total_nodes // t_delta}.") - # Periodically print detailed stats. - reliable_peers = self.crawl_store.get_reliable_peers() - self.log.error(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}") - banned_peers = self.crawl_store.get_banned_peers() - ignored_peers = self.crawl_store.get_ignored_peers() - available_peers = len(self.host_to_version) - addresses_count = len(self.best_timestamp_per_peer) - total_records = self.crawl_store.get_total_records() - ipv6_addresses_count = 0 - for host in self.best_timestamp_per_peer.keys(): - try: - ipaddress.IPv6Address(host) - ipv6_addresses_count += 1 - except ipaddress.AddressValueError: - continue - self.log.error( - "IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: " - f"{addresses_count - ipv6_addresses_count}." - ) - self.log.error( - "IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: " - f"{ipv6_addresses_count}." - ) - ipv6_available_peers = 0 - for host in self.host_to_version.keys(): - try: - ipaddress.IPv6Address(host) - ipv6_available_peers += 1 - except ipaddress.AddressValueError: - continue - self.log.error( - f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}." 
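
The crawl loop shown above caps in-flight connection tasks at 250 by waiting for the first completion and pruning finished tasks before spawning more. A self-contained sketch of that throttling pattern, with a dummy coroutine standing in for connect_task:

import asyncio
from typing import List, Set

async def connect(peer: str) -> None:  # stand-in for Crawler.connect_task
    await asyncio.sleep(0.01)

async def crawl(peers: List[str], limit: int = 250) -> None:
    tasks: Set[asyncio.Task] = set()
    for peer in peers:
        tasks.add(asyncio.create_task(connect(peer)))
        if len(tasks) >= limit:
            # Block until at least one task finishes, then drop completed ones.
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
            tasks = {task for task in tasks if not task.done()}
    if tasks:
        await asyncio.wait(tasks, timeout=30)

asyncio.run(crawl([f"peer{i}" for i in range(1000)]))
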
- ) - self.log.error(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.") - self.log.error("Version distribution among reachable in the last 5 days (at least 100 nodes):") - for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True): - if count >= self.minimum_version_count: - self.log.error(f"Version: {version} - Count: {count}") - self.log.error(f"Banned addresses in the DB: {banned_peers}") - self.log.error(f"Temporary ignored addresses in the DB: {ignored_peers}") - self.log.error( - "Peers to crawl from in the next batch (total IPs - ignored - banned): " - f"{total_records - banned_peers - ignored_peers}" - ) - self.log.error("***") - + await self.save_to_db() + await self.print_summary(t_start, total_nodes, tried_nodes) + await asyncio.sleep(15) # 15 seconds between db updates self._state_changed("crawl_batch_completed") except Exception as e: self.log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.") - def set_server(self, server: ChiaServer): + async def save_to_db(self) -> None: + # Try up to 5 times to write to the DB in case there is a lock that causes a timeout + if self.crawl_store is None: + raise ValueError("Not Connected to DB") + for i in range(1, 5): + try: + await self.crawl_store.load_to_db() + await self.crawl_store.load_reliable_peers_to_db() + return + except Exception as e: + self.log.error(f"Exception while saving to DB: {e}.") + self.log.error("Waiting 5 seconds before retry...") + await asyncio.sleep(5) + continue + + def set_server(self, server: ChiaServer) -> None: self._server = server - def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None): + def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None: if self.state_changed_callback is not None: self.state_changed_callback(change, change_data) - async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection): + async def new_peak(self, request: full_node_protocol.NewPeak, peer: WSChiaConnection) -> None: try: peer_info = peer.get_peer_info() tls_version = peer.get_tls_version() @@ -359,6 +306,11 @@ class Crawler: tls_version = "unknown" if peer_info is None: return + # validate peer ip address: + try: + ipaddress.ip_address(peer_info.host) + except ValueError: + raise ValueError(f"Invalid peer ip address: {peer_info.host}") if request.height >= self.minimum_height: if self.crawl_store is not None: await self.crawl_store.peer_connected_hostname(peer_info.host, True, tls_version) @@ -366,12 +318,79 @@ class Crawler: except Exception as e: self.log.error(f"Exception: {e}. 
Traceback: {traceback.format_exc()}.") - async def on_connect(self, connection: WSChiaConnection): + async def on_connect(self, connection: WSChiaConnection) -> None: pass - def _close(self): + def _close(self) -> None: self._shut_down = True - async def _await_closed(self): - if self.connection is not None: - await self.connection.close() + async def _await_closed(self) -> None: + if self.crawl_task is not None: + try: + await asyncio.wait_for(self.crawl_task, timeout=10) # wait 10 seconds before giving up + except asyncio.TimeoutError: + self.log.error("Crawl task did not exit in time, killing task.") + self.crawl_task.cancel() + if self.crawl_store is not None: + self.log.info("Closing connection to DB.") + await self.crawl_store.crawl_db.close() + + async def print_summary(self, t_start: float, total_nodes: int, tried_nodes: Set[str]) -> None: + assert self.crawl_store is not None # this is only ever called from the crawl task + if not self.print_status: + return + total_records = self.crawl_store.get_total_records() + ipv6_count = self.crawl_store.get_ipv6_peers() + self.log.warning("***") + self.log.warning("Finished batch:") + self.log.warning(f"Total IPs stored in DB: {total_records}.") + self.log.warning(f"Total IPV6 addresses stored in DB: {ipv6_count}") + self.log.warning(f"Total connections attempted since crawler started: {total_nodes}.") + self.log.warning(f"Total unique nodes attempted since crawler started: {len(tried_nodes)}.") + t_now = time.time() + t_delta = int(t_now - t_start) + if t_delta > 0: + self.log.warning(f"Avg connections per second: {total_nodes // t_delta}.") + # Periodically print detailed stats. + reliable_peers = self.crawl_store.get_reliable_peers() + self.log.warning(f"High quality reachable nodes, used by DNS introducer in replies: {reliable_peers}") + banned_peers = self.crawl_store.get_banned_peers() + ignored_peers = self.crawl_store.get_ignored_peers() + available_peers = len(self.host_to_version) + addresses_count = len(self.best_timestamp_per_peer) + total_records = self.crawl_store.get_total_records() + ipv6_addresses_count = 0 + for host in self.best_timestamp_per_peer.keys(): + try: + ipaddress.IPv6Address(host) + ipv6_addresses_count += 1 + except ipaddress.AddressValueError: + continue + self.log.warning( + "IPv4 addresses gossiped with timestamp in the last 5 days with respond_peers messages: " + f"{addresses_count - ipv6_addresses_count}." + ) + self.log.warning( + "IPv6 addresses gossiped with timestamp in the last 5 days with respond_peers messages: " + f"{ipv6_addresses_count}." 
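
The summary statistics split reachable peers into IPv4 and IPv6 by attempting to parse each host as an IPv6 address and treating AddressValueError as "not IPv6". The same classification in isolation, with made-up documentation addresses:

import ipaddress
from typing import Iterable

def count_ipv6(hosts: Iterable[str]) -> int:
    # Anything that fails to parse as IPv6 is counted as IPv4 by the caller;
    # hostnames are not expected in these maps.
    count = 0
    for host in hosts:
        try:
            ipaddress.IPv6Address(host)
            count += 1
        except ipaddress.AddressValueError:
            continue
    return count

hosts = ["203.0.113.5", "2001:db8::1", "198.51.100.7"]
assert count_ipv6(hosts) == 1
assert len(hosts) - count_ipv6(hosts) == 2  # the IPv4 share, as in the summary
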
+ ) + ipv6_available_peers = 0 + for host in self.host_to_version.keys(): + try: + ipaddress.IPv6Address(host) + ipv6_available_peers += 1 + except ipaddress.AddressValueError: + continue + self.log.warning(f"Total IPv4 nodes reachable in the last 5 days: {available_peers - ipv6_available_peers}.") + self.log.warning(f"Total IPv6 nodes reachable in the last 5 days: {ipv6_available_peers}.") + self.log.warning("Version distribution among reachable in the last 5 days (at least 100 nodes):") + for version, count in sorted(self.versions.items(), key=lambda kv: kv[1], reverse=True): + if count >= self.minimum_version_count: + self.log.warning(f"Version: {version} - Count: {count}") + self.log.warning(f"Banned addresses in the DB: {banned_peers}") + self.log.warning(f"Temporary ignored addresses in the DB: {ignored_peers}") + self.log.warning( + "Peers to crawl from in the next batch (total IPs - ignored - banned): " + f"{total_records - banned_peers - ignored_peers}" + ) + self.log.warning("***") diff --git a/chia/seeder/dns_server.py b/chia/seeder/dns_server.py index fba7692b20af..bcb32d4a2cea 100644 --- a/chia/seeder/dns_server.py +++ b/chia/seeder/dns_server.py @@ -1,294 +1,534 @@ from __future__ import annotations import asyncio -import ipaddress import logging -import random import signal +import sys import traceback +from asyncio import AbstractEventLoop +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from ipaddress import IPv4Address, IPv6Address, ip_address +from multiprocessing import freeze_support from pathlib import Path -from typing import Any, Dict, List +from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional import aiosqlite -from dnslib import AAAA, CNAME, MX, NS, QTYPE, RR, SOA, A, DNSHeader, DNSRecord +from dnslib import AAAA, EDNS0, NS, QTYPE, RCODE, RD, RR, SOA, A, DNSError, DNSHeader, DNSQuestion, DNSRecord -from chia.util.chia_logging import initialize_logging -from chia.util.config import load_config +from chia.seeder.crawl_store import CrawlStore +from chia.util.chia_logging import initialize_service_logging +from chia.util.config import load_config, load_config_cli from chia.util.default_root import DEFAULT_ROOT_PATH from chia.util.path import path_from_root SERVICE_NAME = "seeder" log = logging.getLogger(__name__) +DnsCallback = Callable[[DNSRecord], Awaitable[DNSRecord]] + # DNS snippet taken from: https://gist.github.com/pklaus/b5a7876d4d2cf7271873 class DomainName(str): - def __getattr__(self, item): - return DomainName(item + "." + self) + def __getattr__(self, item: str) -> DomainName: + return DomainName(item + "." + self) # DomainName.NS becomes DomainName("NS.DomainName") -D = None -ns = None -IP = "127.0.0.1" -TTL = None -soa_record = None -ns_records: List[Any] = [] +@dataclass(frozen=True) +class PeerList: + ipv4: List[IPv4Address] + ipv6: List[IPv6Address] + + @property + def no_peers(self) -> bool: + return not self.ipv4 and not self.ipv6 -class EchoServerProtocol(asyncio.DatagramProtocol): - def __init__(self, callback): - self.data_queue = asyncio.Queue() - self.callback = callback - asyncio.ensure_future(self.respond()) +@dataclass +class UDPDNSServerProtocol(asyncio.DatagramProtocol): + """ + This is a really simple UDP Server, that converts all requests to DNSRecord objects and passes them to the callback. 
+ """ - def connection_made(self, transport): - self.transport = transport + callback: DnsCallback + transport: Optional[asyncio.DatagramTransport] = field(init=False, default=None) + data_queue: asyncio.Queue[tuple[DNSRecord, tuple[str, int]]] = field(default_factory=asyncio.Queue) + queue_task: Optional[asyncio.Task[None]] = field(init=False, default=None) - def datagram_received(self, data, addr): - asyncio.ensure_future(self.handler(data, addr)) + def start(self) -> None: + self.queue_task = asyncio.create_task(self.respond()) # This starts the dns respond loop. - async def respond(self): - while True: + async def stop(self) -> None: + if self.queue_task is not None: + self.queue_task.cancel() try: - resp, caller = await self.data_queue.get() - self.transport.sendto(resp, caller) + await self.queue_task + except asyncio.CancelledError: # we dont care + pass + if self.transport is not None: + self.transport.close() + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + # we use the #ignore because transport is a subclass of BaseTransport, but we need the real type. + self.transport = transport # type: ignore[assignment] + + def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: + log.debug(f"Received UDP DNS request from {addr}.") + dns_request: Optional[DNSRecord] = parse_dns_request(data) + if dns_request is None: # Invalid Request, we can just drop it and move on. + return + asyncio.create_task(self.handler(dns_request, addr)) + + async def respond(self) -> None: + log.info("UDP DNS responder started.") + while self.transport is None: # we wait for the transport to be set. + await asyncio.sleep(0.1) + while not self.transport.is_closing(): + try: + edns_max_size = 0 + reply, caller = await self.data_queue.get() + if len(reply.ar) > 0 and reply.ar[0].rtype == QTYPE.OPT: + edns_max_size = reply.ar[0].edns_len + + reply_packed = reply.pack() + + if len(reply_packed) > max(512, edns_max_size): # 512 is the default max size for DNS: + log.debug(f"DNS response to {caller} is too large, truncating.") + reply_packed = reply.truncate().pack() + + self.transport.sendto(reply_packed, caller) + log.debug(f"Sent UDP DNS response to {caller}, of size {len(reply_packed)}.") except Exception as e: - log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.") + log.error(f"Exception while responding to UDP DNS request: {e}. Traceback: {traceback.format_exc()}.") + log.info("UDP DNS responder stopped.") - async def handler(self, data, caller): + async def handler(self, data: DNSRecord, caller: tuple[str, int]) -> None: + r_data = await get_dns_reply(self.callback, data) # process the request, returning a DNSRecord response. + await self.data_queue.put((r_data, caller)) + + +@dataclass +class TCPDNSServerProtocol(asyncio.BufferedProtocol): + """ + This TCP server is a little more complicated, because we need to handle the length field, however it still + converts all requests to DNSRecord objects and passes them to the callback, after receiving the full message. + """ + + callback: DnsCallback + transport: Optional[asyncio.Transport] = field(init=False, default=None) + peer_info: str = field(init=False, default="") + expected_length: int = 0 + buffer: bytearray = field(init=False, default_factory=lambda: bytearray(2)) + futures: List[asyncio.Future[None]] = field(init=False, default_factory=list) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """ + This is called whenever we get a new connection. 
+ """ + # we use the #ignore because transport is a subclass of BaseTransport, but we need the real type. + self.transport = transport # type: ignore[assignment] + peer_info = transport.get_extra_info("peername") + self.peer_info = f"{peer_info[0]}:{peer_info[1]}" + log.debug(f"TCP connection established with {self.peer_info}.") + + def connection_lost(self, exc: Optional[Exception]) -> None: + """ + This is called whenever a connection is lost, or closed. + """ + if exc is not None: + log.debug(f"TCP DNS connection lost with {self.peer_info}. Exception: {exc}.") + else: + log.debug(f"TCP DNS connection closed with {self.peer_info}.") + # reset the state of the protocol. + for future in self.futures: + future.cancel() + self.futures = [] + self.buffer = bytearray(2) + self.expected_length = 0 + + def get_buffer(self, sizehint: int) -> memoryview: + """ + This is the first function called after connection_made, it returns a buffer that the tcp server will write to. + Once a buffer is written to, buffer_updated is called. + """ + return memoryview(self.buffer) + + def buffer_updated(self, nbytes: int) -> None: + """ + This is called whenever the buffer is written to, and it loops through the buffer, grouping them into messages + and then dns records. + """ + while not len(self.buffer) == 0 and self.transport is not None: + if not self.expected_length: + # Length field received (This is the first part of the message) + self.expected_length = int.from_bytes(self.buffer, byteorder="big") + self.buffer = self.buffer[2:] # Remove the length field from the buffer. + + if len(self.buffer) >= self.expected_length: + # This is the rest of the message (after the length field) + message = self.buffer[: self.expected_length] + self.buffer = self.buffer[self.expected_length :] # Remove the message from the buffer + self.expected_length = 0 # Reset the expected length + + dns_request: Optional[DNSRecord] = parse_dns_request(message) + if dns_request is None: # Invalid Request, so we disconnect and don't send anything back. + self.transport.close() + return + self.futures.append(asyncio.create_task(self.handle_and_respond(dns_request))) + + self.buffer = bytearray(2 if self.expected_length == 0 else self.expected_length) # Reset the buffer if empty. + + def eof_received(self) -> Optional[bool]: + """ + This is called when the client closes the connection, False or None means we close the connection. + True means we keep the connection open. + """ + if len(self.futures) > 0: # Successful requests + if self.expected_length != 0: # Incomplete requests + log.warning( + f"Received incomplete TCP DNS request of length {self.expected_length} from {self.peer_info}, " + f"closing connection after dns replies are sent." + ) + asyncio.create_task(self.wait_for_futures()) + return True # Keep connection open, until the futures are done. + log.info(f"Received early EOF from {self.peer_info}, closing connection.") + return False + + async def wait_for_futures(self) -> None: + """ + Waits for all the futures to complete, and then closes the connection. 
+ """ try: - data = await self.callback(data) - if data is None: - return - await self.data_queue.put((data, caller)) + await asyncio.wait_for(asyncio.gather(*self.futures), timeout=10) + except asyncio.TimeoutError: + log.warning(f"Timed out waiting for DNS replies to be sent to {self.peer_info}.") + if self.transport is not None: + self.transport.close() + + async def handle_and_respond(self, data: DNSRecord) -> None: + r_data = await get_dns_reply(self.callback, data) # process the request, returning a DNSRecord response. + try: + # If the client closed the connection, we don't want to send anything. + if self.transport is not None and not self.transport.is_closing(): + self.transport.write(dns_response_to_tcp(r_data)) # send data back to the client + log.debug(f"Sent DNS response for {data.q.qname}, to {self.peer_info}.") except Exception as e: - log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.") + log.error(f"Exception while responding to TCP DNS request: {e}. Traceback: {traceback.format_exc()}.") +def dns_response_to_tcp(data: DNSRecord) -> bytes: + """ + Converts a DNSRecord response to a TCP DNS response, by adding a 2 byte length field to the start. + """ + dns_response = data.pack() + dns_response_length = len(dns_response).to_bytes(2, byteorder="big") + return bytes(dns_response_length + dns_response) + + +def create_dns_reply(dns_request: DNSRecord) -> DNSRecord: + """ + Creates a DNS response with the correct header and section flags set. + """ + # QR means query response, AA means authoritative answer, RA means recursion available + return DNSRecord(DNSHeader(id=dns_request.header.id, qr=1, aa=1, ra=0), q=dns_request.q) + + +def parse_dns_request(data: bytes) -> Optional[DNSRecord]: + """ + Parses the DNS request, and returns a DNSRecord object, or None if the request is invalid. + """ + dns_request: Optional[DNSRecord] = None + try: + dns_request = DNSRecord.parse(data) + except DNSError as e: + log.warning(f"Received invalid DNS request: {e}. Traceback: {traceback.format_exc()}.") + return dns_request + + +async def get_dns_reply(callback: DnsCallback, dns_request: DNSRecord) -> DNSRecord: + """ + This function calls the callback, and returns SERVFAIL if the callback raises an exception. + """ + try: + dns_reply = await callback(dns_request) + except Exception as e: + log.error(f"Exception during DNS record processing: {e}. Traceback: {traceback.format_exc()}.") + # we return an empty response with an error code + dns_reply = create_dns_reply(dns_request) # This is an empty response, with only the header set. 
+ dns_reply.header.rcode = RCODE.SERVFAIL + return dns_reply + + +@dataclass class DNSServer: - reliable_peers_v4: List[str] - reliable_peers_v6: List[str] - lock: asyncio.Lock - pointer: int - crawl_db: aiosqlite.Connection + config: Dict[str, Any] + root_path: Path + lock: asyncio.Lock = field(default_factory=asyncio.Lock) + shutdown_event: asyncio.Event = field(default_factory=asyncio.Event) + crawl_store: Optional[CrawlStore] = field(init=False, default=None) + reliable_task: Optional[asyncio.Task[None]] = field(init=False, default=None) + shutdown_task: Optional[asyncio.Task[None]] = field(init=False, default=None) + udp_transport_ipv4: Optional[asyncio.DatagramTransport] = field(init=False, default=None) + udp_protocol_ipv4: Optional[UDPDNSServerProtocol] = field(init=False, default=None) + udp_transport_ipv6: Optional[asyncio.DatagramTransport] = field(init=False, default=None) + udp_protocol_ipv6: Optional[UDPDNSServerProtocol] = field(init=False, default=None) + # TODO: After 3.10 is dropped change to asyncio.Server + tcp_server: Optional[asyncio.base_events.Server] = field(init=False, default=None) + # these are all set in __post_init__ + dns_port: int = field(init=False) + db_path: Path = field(init=False) + domain: DomainName = field(init=False) + ns1: DomainName = field(init=False) + ns_records: List[RR] = field(init=False) + ttl: int = field(init=False) + soa_record: RR = field(init=False) + reliable_peers_v4: List[IPv4Address] = field(default_factory=list) + reliable_peers_v6: List[IPv6Address] = field(default_factory=list) + pointer_v4: int = 0 + pointer_v6: int = 0 - def __init__(self, config: Dict, root_path: Path): - self.reliable_peers_v4 = [] - self.reliable_peers_v6 = [] - self.lock = asyncio.Lock() - self.pointer_v4 = 0 - self.pointer_v6 = 0 - - crawler_db_path: str = config.get("crawler_db_path", "crawler.db") - self.db_path = path_from_root(root_path, crawler_db_path) + def __post_init__(self) -> None: + """ + We initialize all the variables set to field(init=False) here. + """ + # From Config + self.dns_port: int = self.config.get("dns_port", 53) + # DB Path + crawler_db_path: str = self.config.get("crawler_db_path", "crawler.db") + self.db_path: Path = path_from_root(self.root_path, crawler_db_path) self.db_path.parent.mkdir(parents=True, exist_ok=True) - - async def start(self): - # self.crawl_db = await aiosqlite.connect(self.db_path) - # Get a reference to the event loop as we plan to use - # low-level APIs. - loop = asyncio.get_running_loop() - - # One protocol instance will be created to serve all - # client requests. - self.transport, self.protocol = await loop.create_datagram_endpoint( - lambda: EchoServerProtocol(self.dns_response), local_addr=("::0", 53) + # DNS info + self.domain: DomainName = DomainName(self.config["domain_name"]) + if not self.domain.endswith("."): + self.domain = DomainName(self.domain + ".") # Make sure the domain ends with a period, as per RFC 1035. 
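
The __post_init__ above normalizes the configured domain to a fully qualified name, and the DomainName str subclass defined near the top of this file then derives child names through attribute access. A small illustration with an assumed config value:

class DomainName(str):
    def __getattr__(self, item: str) -> "DomainName":
        return DomainName(item + "." + self)

domain = "seeder.example.com"  # assumed config value
if not domain.endswith("."):
    domain += "."  # fully qualified, per RFC 1035
d = DomainName(domain)
assert d.ns1 == "ns1.seeder.example.com."
assert d.mail.extra == "extra.mail.seeder.example.com."
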
+ self.ns1: DomainName = DomainName(self.config["nameserver"]) + self.ns_records: List[NS] = [NS(self.ns1)] + self.ttl: int = self.config["ttl"] + self.soa_record: SOA = SOA( + mname=self.ns1, # primary name server + rname=self.config["soa"]["rname"], # email of the domain administrator + times=( + self.config["soa"]["serial_number"], + self.config["soa"]["refresh"], + self.config["soa"]["retry"], + self.config["soa"]["expire"], + self.config["soa"]["minimum"], + ), ) + + @asynccontextmanager + async def run(self) -> AsyncIterator[None]: + log.info("Starting DNS server.") + # Get a reference to the event loop as we plan to use low-level APIs. + loop = asyncio.get_running_loop() + await self.setup_signal_handlers(loop) + + # Set up the crawl store and the peer update task. + self.crawl_store = await CrawlStore.create(await aiosqlite.connect(self.db_path, timeout=120)) self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers()) - async def periodically_get_reliable_peers(self): + # One protocol instance will be created for each udp transport, so that we can accept ipv4 and ipv6 + self.udp_transport_ipv6, self.udp_protocol_ipv6 = await loop.create_datagram_endpoint( + lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("::0", self.dns_port) + ) + self.udp_protocol_ipv6.start() # start ipv6 udp transmit task + + # in case the port is 0 we need all protocols on the same port. + self.dns_port = self.udp_transport_ipv6.get_extra_info("sockname")[1] # get the actual port we are listening to + + if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"): + # Windows does not support dual stack sockets, so we need to create a new socket for ipv4. + self.udp_transport_ipv4, self.udp_protocol_ipv4 = await loop.create_datagram_endpoint( + lambda: UDPDNSServerProtocol(self.dns_response), local_addr=("0.0.0.0", self.dns_port) + ) + self.udp_protocol_ipv4.start() # start ipv4 udp transmit task + + # One tcp server will handle both ipv4 and ipv6 on both linux and windows. + self.tcp_server = await loop.create_server( + lambda: TCPDNSServerProtocol(self.dns_response), ["::0", "0.0.0.0"], self.dns_port + ) + + log.info("DNS server started.") + try: + yield + finally: # catches any errors and properly shuts down the server + await self.stop() + log.info("DNS server stopped.") + + async def setup_signal_handlers(self, loop: AbstractEventLoop) -> None: + try: + loop.add_signal_handler(signal.SIGINT, self._accept_signal) + loop.add_signal_handler(signal.SIGTERM, self._accept_signal) + except NotImplementedError: + log.warning("signal handlers unsupported on this platform") + + def _accept_signal(self) -> None: # pragma: no cover + if self.shutdown_task is None: # otherwise we are already shutting down, so we ignore the signal + self.shutdown_task = asyncio.create_task(self.stop()) + + async def stop(self) -> None: + log.info("Stopping DNS server...") + if self.reliable_task is not None: + self.reliable_task.cancel() # cancel the peer update task + if self.crawl_store is not None: + await self.crawl_store.crawl_db.close() + if self.udp_protocol_ipv6 is not None: + await self.udp_protocol_ipv6.stop() # stop responding to and accepting udp requests (ipv6) & ipv4 if linux. + if self.udp_protocol_ipv4 is not None: + await self.udp_protocol_ipv4.stop() # stop responding to and accepting udp requests (ipv4) if windows. 
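
run() binds the IPv6 UDP socket first, possibly on port 0, reads back the port the kernel actually assigned, and only adds a separate IPv4 socket on Windows, which lacks dual-stack support. A stripped-down sketch of that sequence, using a placeholder protocol:

import asyncio
import sys

class PlaceholderProtocol(asyncio.DatagramProtocol):
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.transport = transport

async def main() -> None:
    loop = asyncio.get_running_loop()
    # Bind IPv6 first; port 0 asks the kernel for any free port.
    transport, _ = await loop.create_datagram_endpoint(PlaceholderProtocol, local_addr=("::0", 0))
    port = transport.get_extra_info("sockname")[1]  # the port actually assigned
    if sys.platform.startswith("win32") or sys.platform.startswith("cygwin"):
        # No dual-stack sockets on Windows: bind IPv4 separately on the same port.
        transport4, _ = await loop.create_datagram_endpoint(
            PlaceholderProtocol, local_addr=("0.0.0.0", port)
        )
        transport4.close()
    transport.close()

asyncio.run(main())
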
+ if self.tcp_server is not None: + self.tcp_server.close() # stop accepting new tcp requests (ipv4 and ipv6) + await self.tcp_server.wait_closed() # wait for existing TCP requests to finish (ipv4 and ipv6) + self.shutdown_event.set() + + async def periodically_get_reliable_peers(self) -> None: sleep_interval = 0 - while True: + while not self.shutdown_event.is_set() and self.crawl_store is not None: try: - # TODO: double check this. It shouldn't take this long to connect. - crawl_db = await aiosqlite.connect(self.db_path, timeout=600) - cursor = await crawl_db.execute( - "SELECT * from good_peers", - ) - new_reliable_peers = [] - rows = await cursor.fetchall() - await cursor.close() - await crawl_db.close() - for row in rows: - new_reliable_peers.append(row[0]) - if len(new_reliable_peers) > 0: - random.shuffle(new_reliable_peers) - async with self.lock: - self.reliable_peers_v4 = [] - self.reliable_peers_v6 = [] - for peer in new_reliable_peers: - ipv4 = True - try: - _ = ipaddress.IPv4Address(peer) - except ValueError: - ipv4 = False - if ipv4: - self.reliable_peers_v4.append(peer) - else: - try: - _ = ipaddress.IPv6Address(peer) - except ValueError: - continue - self.reliable_peers_v6.append(peer) - self.pointer_v4 = 0 - self.pointer_v6 = 0 - log.error( + new_reliable_peers = await self.crawl_store.get_good_peers() + except Exception as e: + log.error(f"Error loading reliable peers from database: {e}. Traceback: {traceback.format_exc()}.") + continue + if len(new_reliable_peers) == 0: + log.info("No reliable peers found in database, waiting for db to be populated.") + await asyncio.sleep(2) # sleep for 2 seconds, because the db has not been populated yet. + continue + async with self.lock: + self.reliable_peers_v4 = [] + self.reliable_peers_v6 = [] + self.pointer_v4 = 0 + self.pointer_v6 = 0 + for peer in new_reliable_peers: + try: + validated_peer = ip_address(peer) + if validated_peer.version == 4: + self.reliable_peers_v4.append(validated_peer) + elif validated_peer.version == 6: + self.reliable_peers_v6.append(validated_peer) + except ValueError: + log.error(f"Invalid peer: {peer}") + continue + log.info( f"Number of reliable peers discovered in dns server:" f" IPv4 count - {len(self.reliable_peers_v4)}" f" IPv6 count - {len(self.reliable_peers_v6)}" ) - except Exception as e: - log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.") - sleep_interval = min(15, sleep_interval + 1) await asyncio.sleep(sleep_interval * 60) - async def get_peers_to_respond(self, ipv4_count, ipv6_count): - peers = [] + async def get_peers_to_respond(self, ipv4_count: int, ipv6_count: int) -> PeerList: async with self.lock: # Append IPv4. + ipv4_peers: List[IPv4Address] = [] size = len(self.reliable_peers_v4) if ipv4_count > 0 and size <= ipv4_count: - peers = self.reliable_peers_v4 + ipv4_peers = self.reliable_peers_v4 elif ipv4_count > 0: - peers = [self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count)] - self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size + ipv4_peers = [ + self.reliable_peers_v4[i % size] for i in range(self.pointer_v4, self.pointer_v4 + ipv4_count) + ] + self.pointer_v4 = (self.pointer_v4 + ipv4_count) % size # mark where we left off # Append IPv6. 
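
get_peers_to_respond serves each query a rotating window over the reliable-peer list, using modulo indexing so the window wraps at the end of the list. The rotation on its own, with invented names and data:

from typing import List, Tuple

def next_window(items: List[str], pointer: int, count: int) -> Tuple[List[str], int]:
    size = len(items)
    if size <= count:
        return items, pointer  # fewer items than requested: hand back everything
    window = [items[i % size] for i in range(pointer, pointer + count)]
    return window, (pointer + count) % size  # remember where we left off

peers = ["a", "b", "c", "d", "e"]
first, pointer = next_window(peers, 0, 3)         # ['a', 'b', 'c']
second, pointer = next_window(peers, pointer, 3)  # ['d', 'e', 'a'], wrapping around
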
+ ipv6_peers: List[IPv6Address] = [] size = len(self.reliable_peers_v6) if ipv6_count > 0 and size <= ipv6_count: - peers = peers + self.reliable_peers_v6 + ipv6_peers = self.reliable_peers_v6 elif ipv6_count > 0: - peers = peers + [ + ipv6_peers = [ self.reliable_peers_v6[i % size] for i in range(self.pointer_v6, self.pointer_v6 + ipv6_count) ] - self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size - return peers + self.pointer_v6 = (self.pointer_v6 + ipv6_count) % size # mark where we left off + return PeerList(ipv4_peers, ipv6_peers) - async def dns_response(self, data): - try: - request = DNSRecord.parse(data) - IPs = [MX(D.mail), soa_record] + ns_records - ipv4_count = 0 - ipv6_count = 0 - if request.q.qtype == 1: - ipv4_count = 32 - elif request.q.qtype == 28: - ipv6_count = 32 - elif request.q.qtype == 255: - ipv4_count = 16 - ipv6_count = 16 - else: - ipv4_count = 32 - peers = await self.get_peers_to_respond(ipv4_count, ipv6_count) - if len(peers) == 0: - return None - for peer in peers: - ipv4 = True - try: - _ = ipaddress.IPv4Address(peer) - except ValueError: - ipv4 = False - if ipv4: - IPs.append(A(peer)) - else: - try: - _ = ipaddress.IPv6Address(peer) - except ValueError: - continue - IPs.append(AAAA(peer)) - reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=len(IPs), ra=1), q=request.q) + async def dns_response(self, request: DNSRecord) -> DNSRecord: + """ + This function is called when a DNS request is received, and it returns a DNS response. + It does not catch any errors as it is called from within a try-except block. + """ + reply = create_dns_reply(request) + dns_question: DNSQuestion = request.q # this is the question / request + question_type: int = dns_question.qtype # the type of the record being requested + qname = dns_question.qname # the name being queried / requested + # ADD EDNS0 to response if supported + if len(request.ar) > 0 and request.ar[0].rtype == QTYPE.OPT: # OPT Means EDNS + udp_len = min(4096, request.ar[0].edns_len) + edns_reply = EDNS0(udp_len=udp_len) + reply.add_ar(edns_reply) + # DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve + # transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00 + qname_str = str(qname).lower() + if qname_str != self.domain and not qname_str.endswith("." 
+ self.domain): + # we don't answer for other domains (we have the not recursive bit set) + log.warning(f"Invalid request for {qname_str}, returning REFUSED.") + reply.header.rcode = RCODE.REFUSED + return reply - records = { - D: IPs, - D.ns1: [A(IP)], # MX and NS records must never point to a CNAME alias (RFC 2181 section 10.3) - D.ns2: [A(IP)], - D.mail: [A(IP)], - D.andrei: [CNAME(D)], - } + ttl: int = self.ttl + # we add these to the list as it will allow us to respond to ns and soa requests + ips: List[RD] = [self.soa_record] + self.ns_records + ipv4_count = 0 + ipv6_count = 0 + if question_type is QTYPE.A: + ipv4_count = 32 + elif question_type is QTYPE.AAAA: + ipv6_count = 32 + elif question_type is QTYPE.ANY: + ipv4_count = 16 + ipv6_count = 16 + else: + ipv4_count = 32 + peers: PeerList = await self.get_peers_to_respond(ipv4_count, ipv6_count) + if peers.no_peers: + log.error("No peers found, returning SOA and NS records only.") + ttl = 60 # 1 minute as we should have some peers very soon + # we always return the SOA and NS records, so we continue even if there are no peers + ips.extend([A(str(peer)) for peer in peers.ipv4]) + ips.extend([AAAA(str(peer)) for peer in peers.ipv6]) - qname = request.q.qname - # DNS labels are mixed case with DNS resolvers that implement the use of bit 0x20 to improve - # transaction identity. See https://datatracker.ietf.org/doc/html/draft-vixie-dnsext-dns0x20-00 - qn = str(qname).lower() - qtype = request.q.qtype - qt = QTYPE[qtype] - if qn == D or qn.endswith("." + D): - for name, rrs in records.items(): - if name == qn: - for rdata in rrs: - rqt = rdata.__class__.__name__ - if qt in ["*", rqt] or (qt == "ANY" and (rqt == "A" or rqt == "AAAA")): - reply.add_answer( - RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata) - ) + records: Dict[DomainName, List[RD]] = { # this is where we can add other records we want to serve + self.domain: ips, + } - for rdata in ns_records: - reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata)) - - reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record)) - - return reply.pack() - except Exception as e: - log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.") + valid_domain = False + for domain_name, domain_responses in records.items(): + if domain_name == qname_str: # if the dns name is the same as the requested name + valid_domain = True + for response in domain_responses: + rqt: int = getattr(QTYPE, response.__class__.__name__) + if question_type == rqt or question_type == QTYPE.ANY: + reply.add_answer(RR(rname=qname, rtype=rqt, rclass=1, ttl=ttl, rdata=response)) + if not valid_domain and len(reply.rr) == 0: # if we didn't find any records to return + reply.header.rcode = RCODE.NXDOMAIN + # always put nameservers and the SOA records + for nameserver in self.ns_records: + reply.add_auth(RR(rname=self.domain, rtype=QTYPE.NS, rclass=1, ttl=ttl, rdata=nameserver)) + reply.add_auth(RR(rname=self.domain, rtype=QTYPE.SOA, rclass=1, ttl=ttl, rdata=self.soa_record)) + return reply -async def serve_dns(config: Dict, root_path: Path): - dns_server = DNSServer(config, root_path) - await dns_server.start() - - # TODO: Make this cleaner? - while True: - await asyncio.sleep(3600) +async def run_dns_server(dns_server: DNSServer) -> None: # pragma: no cover + async with dns_server.run(): + await dns_server.shutdown_event.wait() # this is released on SIGINT or SIGTERM or any unhandled exception -async def kill_processes(): - # TODO: implement. 
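
Everything dns_response receives has already been parsed into a DNSRecord, and everything it returns is packed back to wire format by the protocol classes. A quick round trip with dnslib showing both halves; the queried domain is an example value:

from dnslib import QTYPE, DNSRecord

# Build the wire-format query that a client such as `dig seeder.example.com A`
# would send.
query = DNSRecord.question("seeder.example.com.", "A")
packed = query.pack()  # the bytes the UDP and TCP protocol classes receive

parsed = DNSRecord.parse(packed)  # what parse_dns_request recovers
assert parsed.q.qtype == QTYPE.A
assert str(parsed.q.qname).lower() == "seeder.example.com."
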
- pass +def create_dns_server_service(config: Dict[str, Any], root_path: Path) -> DNSServer: + service_config = config[SERVICE_NAME] + + return DNSServer(service_config, root_path) -def signal_received(): - asyncio.create_task(kill_processes()) - - -async def async_main(config, root_path): - loop = asyncio.get_running_loop() - - try: - loop.add_signal_handler(signal.SIGINT, signal_received) - loop.add_signal_handler(signal.SIGTERM, signal_received) - except NotImplementedError: - log.info("signal handlers unsupported") - - await serve_dns(config, root_path) - - -def main(): +def main() -> None: # pragma: no cover + freeze_support() root_path = DEFAULT_ROOT_PATH - config = load_config(root_path, "config.yaml", SERVICE_NAME) - initialize_logging(SERVICE_NAME, config["logging"], root_path) - global D - global ns - global TTL - global soa_record - global ns_records - D = DomainName(config["domain_name"]) - ns = DomainName(config["nameserver"]) - TTL = config["ttl"] - soa_record = SOA( - mname=ns, # primary name server - rname=config["soa"]["rname"], # email of the domain administrator - times=( - config["soa"]["serial_number"], - config["soa"]["refresh"], - config["soa"]["retry"], - config["soa"]["expire"], - config["soa"]["minimum"], - ), - ) - ns_records = [NS(ns)] + # TODO: refactor to avoid the double load + config = load_config(DEFAULT_ROOT_PATH, "config.yaml") + service_config = load_config_cli(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME) + config[SERVICE_NAME] = service_config + initialize_service_logging(service_name=SERVICE_NAME, config=config) - asyncio.run(async_main(config=config, root_path=root_path)) + dns_server = create_dns_server_service(config, root_path) + asyncio.run(run_dns_server(dns_server)) if __name__ == "__main__": diff --git a/chia/seeder/peer_record.py b/chia/seeder/peer_record.py index dea90533875f..2bc97a3711d4 100644 --- a/chia/seeder/peer_record.py +++ b/chia/seeder/peer_record.py @@ -24,23 +24,19 @@ class PeerRecord(Streamable): handshake_time: uint64 tls_version: str - def update_version(self, version, now): + def update_version(self, version: str, now: uint64) -> None: if version != "undefined": object.__setattr__(self, "version", version) - object.__setattr__(self, "handshake_time", uint64(now)) + object.__setattr__(self, "handshake_time", now) +@dataclass class PeerStat: weight: float count: float reliability: float - def __init__(self, weight, count, reliability): - self.weight = weight - self.count = count - self.reliability = reliability - - def update(self, is_reachable: bool, age: int, tau: int): + def update(self, is_reachable: bool, age: int, tau: int) -> None: f = math.exp(-age / tau) self.reliability = self.reliability * f + (1.0 - f if is_reachable else 0.0) self.count = self.count * f + 1.0 @@ -81,7 +77,7 @@ class PeerReliability: stat_1m_reliability: float = 0.0, tries: int = 0, successes: int = 0, - ): + ) -> None: self.peer_id = peer_id self.ignore_till = ignore_till self.ban_till = ban_till @@ -132,7 +128,7 @@ class PeerReliability: return 2 * 3600 return 0 - def update(self, is_reachable: bool, age: int): + def update(self, is_reachable: bool, age: int) -> None: self.stat_2h.update(is_reachable, age, 2 * 3600) self.stat_8h.update(is_reachable, age, 8 * 3600) self.stat_1d.update(is_reachable, age, 24 * 3600) diff --git a/chia/seeder/start_crawler.py b/chia/seeder/start_crawler.py index ac28ca9393ca..48902c05ebff 100644 --- a/chia/seeder/start_crawler.py +++ b/chia/seeder/start_crawler.py @@ -4,7 +4,7 @@ import logging import pathlib import 
sys from multiprocessing import freeze_support -from typing import Dict, Optional +from typing import Any, Dict, Optional from chia.consensus.constants import ConsensusConstants from chia.consensus.default_constants import DEFAULT_CONSTANTS @@ -26,24 +26,25 @@ log = logging.getLogger(__name__) def create_full_node_crawler_service( root_path: pathlib.Path, - config: Dict, + config: Dict[str, Any], consensus_constants: ConsensusConstants, connect_to_daemon: bool = True, ) -> Service[Crawler, CrawlerAPI]: service_config = config[SERVICE_NAME] + crawler_config = service_config["crawler"] crawler = Crawler( service_config, root_path=root_path, - consensus_constants=consensus_constants, + constants=consensus_constants, ) api = CrawlerAPI(crawler) network_id = service_config["selected_network"] rpc_info: Optional[RpcInfo] = None - if service_config.get("start_rpc_server", True): - rpc_info = (CrawlerRpcApi, service_config.get("rpc_port", 8561)) + if crawler_config.get("start_rpc_server", True): + rpc_info = (CrawlerRpcApi, crawler_config.get("rpc_port", 8561)) return Service( root_path=root_path, diff --git a/chia/simulator/setup_services.py b/chia/simulator/setup_services.py index 66dea8f152eb..eb3af06fc5a7 100644 --- a/chia/simulator/setup_services.py +++ b/chia/simulator/setup_services.py @@ -24,6 +24,7 @@ from chia.introducer.introducer_api import IntroducerAPI from chia.protocols.shared_protocol import Capability, capabilities from chia.seeder.crawler import Crawler from chia.seeder.crawler_api import CrawlerAPI +from chia.seeder.dns_server import DNSServer, create_dns_server_service from chia.seeder.start_crawler import create_full_node_crawler_service from chia.server.start_farmer import create_farmer_service from chia.server.start_full_node import create_full_node_service @@ -32,15 +33,17 @@ from chia.server.start_introducer import create_introducer_service from chia.server.start_service import Service from chia.server.start_timelord import create_timelord_service from chia.server.start_wallet import create_wallet_service -from chia.simulator.block_tools import BlockTools +from chia.simulator.block_tools import BlockTools, test_constants from chia.simulator.keyring import TempKeyring +from chia.simulator.ssl_certs import get_next_nodes_certs_and_keys, get_next_private_ca_cert_and_key from chia.simulator.start_simulator import create_full_node_simulator_service +from chia.ssl.create_ssl import create_all_ssl from chia.timelord.timelord import Timelord from chia.timelord.timelord_api import TimelordAPI from chia.timelord.timelord_launcher import kill_processes, spawn_process from chia.types.peer_info import UnresolvedPeerInfo from chia.util.bech32m import encode_puzzle_hash -from chia.util.config import config_path_for_filename, lock_and_load_config, save_config +from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config from chia.util.ints import uint16 from chia.util.keychain import bytes_to_mnemonic from chia.util.lock import Lockfile @@ -175,24 +178,36 @@ async def setup_full_node( async def setup_crawler( - bt: BlockTools, + root_path_populated_with_config: Path, database_uri: str ) -> AsyncGenerator[Service[Crawler, CrawlerAPI], None]: - config = bt.config + create_all_ssl( + root_path=root_path_populated_with_config, + private_ca_crt_and_key=get_next_private_ca_cert_and_key().collateral.cert_and_key, + node_certs_and_keys=get_next_nodes_certs_and_keys().collateral.certs_and_keys, + ) + config = load_config(root_path_populated_with_config, 
"config.yaml") service_config = config["seeder"] + service_config["selected_network"] = "testnet0" service_config["port"] = 0 - service_config["start_rpc_server"] = False + service_config["crawler"]["start_rpc_server"] = False + service_config["other_peers_port"] = 58444 + service_config["crawler_db_path"] = database_uri + overrides = service_config["network_overrides"]["constants"][service_config["selected_network"]] - updated_constants = bt.constants.replace_str_to_bytes(**overrides) - bt.change_config(config) + updated_constants = test_constants.replace_str_to_bytes(**overrides) + service = create_full_node_crawler_service( - bt.root_path, + root_path_populated_with_config, config, updated_constants, connect_to_daemon=False, ) await service.start() + if not service_config["crawler"]["start_rpc_server"]: # otherwise the loops don't work. + service._node.state_changed_callback = lambda x, y: None + try: yield service finally: @@ -200,6 +215,24 @@ async def setup_crawler( await service.wait_closed() +async def setup_seeder(root_path_populated_with_config: Path, database_uri: str) -> AsyncGenerator[DNSServer, None]: + config = load_config(root_path_populated_with_config, "config.yaml") + service_config = config["seeder"] + + service_config["selected_network"] = "testnet0" + if service_config["domain_name"].endswith("."): # remove the trailing . so that we can test that logic. + service_config["domain_name"] = service_config["domain_name"][:-1] + service_config["dns_port"] = 0 + service_config["crawler_db_path"] = database_uri + + service = create_dns_server_service( + config, + root_path_populated_with_config, + ) + async with service.run(): + yield service + + # Note: convert these setup functions to fixtures, or push it one layer up, # keeping these usable independently? async def setup_wallet_node( diff --git a/chia/util/initial-config.yaml b/chia/util/initial-config.yaml index 35eb316d5845..4eff825b7e85 100644 --- a/chia/util/initial-config.yaml +++ b/chia/util/initial-config.yaml @@ -139,6 +139,8 @@ seeder: port: 8444 # Most full nodes on the network run on this port. (i.e. 8444 for mainnet, 58444 for testnet). other_peers_port: 8444 + # What port to run the DNS server on, (this is useful if you are already using port 53 for DNS). + dns_port: 53 # This will override the default full_node.peer_connect_timeout for the crawler full node peer_connect_timeout: 2 # Path to crawler DB. Defaults to $CHIA_ROOT/crawler.db @@ -154,7 +156,7 @@ seeder: nameserver: "example.com." ttl: 300 soa: - rname: "hostmaster.example.com." + rname: "hostmaster.example.com" # all @ symbols need to be replaced with . in dns records. 
serial_number: 1619105223 refresh: 10800 retry: 10800 diff --git a/mypy-exclusions.txt b/mypy-exclusions.txt index 1fe94345d31b..3551933c108d 100644 --- a/mypy-exclusions.txt +++ b/mypy-exclusions.txt @@ -16,11 +16,6 @@ chia.rpc.rpc_client chia.rpc.util chia.rpc.wallet_rpc_api chia.rpc.wallet_rpc_client -chia.seeder.crawl_store -chia.seeder.crawler -chia.seeder.dns_server -chia.seeder.peer_record -chia.seeder.start_crawler chia.simulator.full_node_simulator chia.simulator.keyring chia.simulator.setup_services diff --git a/tests/conftest.py b/tests/conftest.py index c455686c1573..b3bd6c019b99 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,6 +4,7 @@ from __future__ import annotations import datetime import multiprocessing import os +import random import sysconfig import tempfile from enum import Enum @@ -32,6 +33,7 @@ from chia.rpc.harvester_rpc_client import HarvesterRpcClient from chia.rpc.wallet_rpc_client import WalletRpcClient from chia.seeder.crawler import Crawler from chia.seeder.crawler_api import CrawlerAPI +from chia.seeder.dns_server import DNSServer from chia.server.server import ChiaServer from chia.server.start_service import Service from chia.simulator.full_node_simulator import FullNodeSimulator @@ -44,7 +46,7 @@ from chia.simulator.setup_nodes import ( setup_simulators_and_wallets_service, setup_two_nodes, ) -from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_timelord +from chia.simulator.setup_services import setup_crawler, setup_daemon, setup_introducer, setup_seeder, setup_timelord from chia.simulator.time_out_assert import time_out_assert from chia.simulator.wallet_tools import WalletTool from chia.types.peer_info import PeerInfo @@ -872,11 +874,19 @@ async def timelord_service(bt): @pytest_asyncio.fixture(scope="function") -async def crawler_service(bt: BlockTools) -> AsyncIterator[Service[Crawler, CrawlerAPI]]: - async for service in setup_crawler(bt): +async def crawler_service( + root_path_populated_with_config: Path, database_uri: str +) -> AsyncIterator[Service[Crawler, CrawlerAPI]]: + async for service in setup_crawler(root_path_populated_with_config, database_uri): yield service +@pytest_asyncio.fixture(scope="function") +async def seeder_service(root_path_populated_with_config: Path, database_uri: str) -> AsyncIterator[DNSServer]: + async for seeder in setup_seeder(root_path_populated_with_config, database_uri): + yield seeder + + @pytest.fixture(scope="function") def tmp_chia_root(tmp_path): """ @@ -980,3 +990,8 @@ async def harvester_farmer_environment( harvester_rpc_cl.close() await farmer_rpc_cl.await_closed() await harvester_rpc_cl.await_closed() + + +@pytest.fixture(name="database_uri") +def database_uri_fixture() -> str: + return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared" diff --git a/tests/core/data_layer/conftest.py b/tests/core/data_layer/conftest.py index 4d7a8c767c33..8858bfa1f7b8 100644 --- a/tests/core/data_layer/conftest.py +++ b/tests/core/data_layer/conftest.py @@ -2,7 +2,6 @@ from __future__ import annotations import os import pathlib -import random import sys import time from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator @@ -54,11 +53,6 @@ def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32] return request.param # type: ignore[no-any-return] -@pytest.fixture(name="database_uri") -def database_uri_fixture() -> str: - return f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared" - - 
@pytest.fixture(name="tree_id", scope="function") def tree_id_fixture() -> bytes32: base = b"a tree id" diff --git a/tests/core/test_crawler.py b/tests/core/test_crawler.py index 6c80b9d600b5..67b79516bf0c 100644 --- a/tests/core/test_crawler.py +++ b/tests/core/test_crawler.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +import time from typing import cast import pytest @@ -11,13 +12,14 @@ from chia.protocols.protocol_message_types import ProtocolMessageTypes from chia.protocols.wallet_protocol import RequestChildren from chia.seeder.crawler import Crawler from chia.seeder.crawler_api import CrawlerAPI +from chia.seeder.peer_record import PeerRecord, PeerReliability from chia.server.outbound_message import make_msg from chia.server.start_service import Service from chia.simulator.setup_nodes import SimulatorsAndWalletsServices from chia.simulator.time_out_assert import time_out_assert from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.peer_info import PeerInfo -from chia.util.ints import uint16, uint32, uint128 +from chia.util.ints import uint16, uint32, uint64, uint128 @pytest.mark.asyncio @@ -68,3 +70,43 @@ async def test_valid_message( ) assert await connection.send_message(msg) await time_out_assert(10, peer_added) + + +@pytest.mark.asyncio +async def test_crawler_to_db( + crawler_service: Service[Crawler, CrawlerAPI], one_node: SimulatorsAndWalletsServices +) -> None: + """ + This is a lot more of an integration test, but it tests the whole process. We add a node to the crawler, then we + save it to the db and validate. + """ + [full_node_service], _, _ = one_node + full_node = full_node_service._node + crawler = crawler_service._node + crawl_store = crawler.crawl_store + assert crawl_store is not None + peer_address = "127.0.0.1" + + # create peer records + peer_record = PeerRecord( + peer_address, + peer_address, + uint32(full_node.server.get_port()), + False, + uint64(0), + uint32(0), + uint64(0), + uint64(int(time.time())), + uint64(0), + "undefined", + uint64(0), + tls_version="unknown", + ) + peer_reliability = PeerReliability(peer_address, tries=1, successes=1) + + # add peer to the db & mark it as connected + await crawl_store.add_peer(peer_record, peer_reliability) + assert peer_record == crawl_store.host_to_records[peer_address] + + # validate the db data + await time_out_assert(20, crawl_store.get_good_peers, [peer_address]) diff --git a/tests/core/test_crawler_rpc.py b/tests/core/test_crawler_rpc.py index 701d9e547c3f..fc742f00db86 100644 --- a/tests/core/test_crawler_rpc.py +++ b/tests/core/test_crawler_rpc.py @@ -9,7 +9,7 @@ from chia.seeder.crawler import Crawler class TestCrawlerRpc: @pytest.mark.asyncio async def test_get_ips_after_timestamp(self, bt): - crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, consensus_constants=bt.constants) + crawler = Crawler(bt.config.get("seeder", {}), bt.root_path, constants=bt.constants) crawler_rpc_api = CrawlerRpcApi(crawler) # Should raise ValueError when `after` is not supplied diff --git a/tests/core/test_seeder.py b/tests/core/test_seeder.py new file mode 100644 index 000000000000..9458c5bac0b9 --- /dev/null +++ b/tests/core/test_seeder.py @@ -0,0 +1,275 @@ +from __future__ import annotations + +import time +from dataclasses import dataclass +from ipaddress import IPv4Address, IPv6Address +from socket import AF_INET6, SOCK_STREAM +from typing import List, Tuple, cast + +import dns +import pytest + +from chia.seeder.dns_server import DNSServer +from 
chia.seeder.peer_record import PeerRecord, PeerReliability
+from chia.simulator.time_out_assert import time_out_assert
+from chia.util.ints import uint32, uint64
+
+timeout = 0.5
+
+
+def generate_test_combs() -> List[Tuple[bool, str, dns.rdatatype.RdataType]]:
+    """
+    Generates all the combinations of tests we want to run.
+    """
+    output = []
+    use_tcp = [True, False]
+    target_address = ["::1", "127.0.0.1"]
+    request_types = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.ANY, dns.rdatatype.NS, dns.rdatatype.SOA]
+    # (use_tcp, target_address, request_type), i.e. both udp & tcp and ipv6 & ipv4 for every request type we support.
+    for addr in target_address:
+        for tcp in use_tcp:
+            for req in request_types:
+                output.append((tcp, addr, req))
+    return output
+
+
+all_test_combinations = generate_test_combs()
+
+
+@dataclass(frozen=True)
+class FakeDnsPacket:
+    real_packet: dns.message.Message
+
+    def to_wire(self) -> bytes:
+        return self.real_packet.to_wire()[23:]
+
+    def __getattr__(self, item: object) -> None:
+        # This is definitely cheating, but it works
+        return None
+
+
+async def make_dns_query(
+    use_tcp: bool, target_address: str, port: int, dns_message: dns.message.Message, d_timeout: float = timeout
+) -> dns.message.Message:
+    """
+    Sends the given DNS message to the target address, using either the TCP or UDP protocol.
+    """
+    if use_tcp:
+        return await dns.asyncquery.tcp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
+    return await dns.asyncquery.udp(q=dns_message, where=target_address, timeout=d_timeout, port=port)
+
+
+def get_addresses(num_subnets: int = 10) -> Tuple[List[IPv4Address], List[IPv6Address]]:
+    ipv4 = []
+    ipv6 = []
+    # generate 2500 ipv4 and 2500 ipv6 peers; they are only ever used as strings here.
+    for s in range(num_subnets):
+        for i in range(1, 251):  # stop at 250, as a subnet can only hold 255 hosts.
+            ipv4.append(IPv4Address(f"192.168.{s}.{i}"))
+            ipv6.append(IPv6Address(f"2001:db8::{s}:{i}"))
+    return ipv4, ipv6
+
+
+def assert_standard_results(
+    std_query_answer: List[dns.rrset.RRset], request_type: dns.rdatatype.RdataType, num_ns: int
+) -> None:
+    if request_type == dns.rdatatype.A:
+        assert len(std_query_answer) == 1  # only 1 kind of answer
+        a_answer = std_query_answer[0]
+        assert a_answer.rdtype == dns.rdatatype.A
+        assert len(a_answer) == 32  # 32 ipv4 addresses
+    elif request_type == dns.rdatatype.AAAA:
+        assert len(std_query_answer) == 1  # only 1 kind of answer
+        aaaa_answer = std_query_answer[0]
+        assert aaaa_answer.rdtype == dns.rdatatype.AAAA
+        assert len(aaaa_answer) == 32  # 32 ipv6 addresses
+    elif request_type == dns.rdatatype.ANY:
+        assert len(std_query_answer) == 4  # 4 kinds of answers
+        for answer in std_query_answer:
+            if answer.rdtype == dns.rdatatype.A:
+                assert len(answer) == 16
+            elif answer.rdtype == dns.rdatatype.AAAA:
+                assert len(answer) == 16
+            elif answer.rdtype == dns.rdatatype.NS:
+                assert len(answer) == num_ns
+            else:
+                assert len(answer) == 1
+    elif request_type == dns.rdatatype.NS:
+        assert len(std_query_answer) == 1  # only 1 kind of answer
+        ns_answer = std_query_answer[0]
+        assert ns_answer.rdtype == dns.rdatatype.NS
+        assert len(ns_answer) == num_ns  # ns records
+    else:
+        assert len(std_query_answer) == 1  # soa
+        soa_answer = std_query_answer[0]
+        assert soa_answer.rdtype == dns.rdatatype.SOA
+        assert len(soa_answer) == 1
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations)
+async def test_error_conditions(
+    seeder_service: DNSServer, use_tcp: bool, target_address: str, 
request_type: dns.rdatatype.RdataType
+) -> None:
+    """
+    We check the behavior with no peers, an invalid packet, an early EOF, and a packet followed by an EOF halfway
+    through (tcp only). We also check a dns record that does not exist, and a dns record outside the domain.
+    """
+    port = seeder_service.dns_port
+    domain = seeder_service.domain  # default is: seeder.example.com
+    num_ns = len(seeder_service.ns_records)
+
+    # No peers
+    no_peers = dns.message.make_query(domain, request_type)
+    no_peers_response = await make_dns_query(use_tcp, target_address, port, no_peers)
+    assert no_peers_response.rcode() == dns.rcode.NOERROR
+
+    if request_type == dns.rdatatype.A or request_type == dns.rdatatype.AAAA:
+        assert len(no_peers_response.answer) == 0  # no response, as expected
+    elif request_type == dns.rdatatype.ANY:  # ns + soa
+        assert len(no_peers_response.answer) == 2
+        for answer in no_peers_response.answer:
+            if answer.rdtype == dns.rdatatype.NS:
+                assert len(answer.items) == num_ns
+            else:
+                assert len(answer.items) == 1
+    elif request_type == dns.rdatatype.NS:
+        assert len(no_peers_response.answer) == 1  # ns
+        assert no_peers_response.answer[0].rdtype == dns.rdatatype.NS
+        assert len(no_peers_response.answer[0].items) == num_ns
+    else:
+        assert len(no_peers_response.answer) == 1  # soa
+        assert no_peers_response.answer[0].rdtype == dns.rdatatype.SOA
+        assert len(no_peers_response.answer[0].items) == 1
+    # Authority Records
+    assert len(no_peers_response.authority) == num_ns + 1  # ns + soa
+    for record_list in no_peers_response.authority:
+        if record_list.rdtype == dns.rdatatype.NS:
+            assert len(record_list.items) == num_ns
+        else:
+            assert len(record_list.items) == 1  # soa
+
+    # Invalid packet (this is kind of a pain to construct)
+    invalid_packet = cast(dns.message.Message, FakeDnsPacket(dns.message.make_query(domain, request_type)))
+    with pytest.raises(EOFError if use_tcp else dns.exception.Timeout):  # UDP will time out, TCP will EOF
+        await make_dns_query(use_tcp, target_address, port, invalid_packet)
+
+    # early EOF packet
+    if use_tcp:
+        backend = dns.asyncbackend.get_default_backend()
+        tcp_socket = await backend.make_socket(  # type: ignore[no-untyped-call]
+            AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
+        )
+        async with tcp_socket as socket:
+            await socket.close()
+            with pytest.raises(EOFError):
+                await dns.asyncquery.receive_tcp(tcp_socket, timeout)
+
+    # Packet then length header then EOF
+    if use_tcp:
+        backend = dns.asyncbackend.get_default_backend()
+        tcp_socket = await backend.make_socket(  # type: ignore[no-untyped-call]
+            AF_INET6, SOCK_STREAM, 0, ("::", 0), ("::1", port), timeout=timeout
+        )
+        async with tcp_socket as socket:  # send packet then length then eof
+            r = await dns.asyncquery.tcp(q=no_peers, where="::1", timeout=timeout, sock=socket)
+            assert r.answer == no_peers_response.answer
+            # send 120 as the 2-byte length prefix of the next packet, so that the server expects more data.
+            await socket.sendall(int(120).to_bytes(2, byteorder="big"), int(time.time() + timeout))
+            await socket.close()
+            with pytest.raises(EOFError):
+                await dns.asyncquery.receive_tcp(tcp_socket, timeout)
+
+    # Record does not exist
+    record_does_not_exist = dns.message.make_query("doesnotexist."
+ domain, request_type) + record_does_not_exist_response = await make_dns_query(use_tcp, target_address, port, record_does_not_exist) + assert record_does_not_exist_response.rcode() == dns.rcode.NXDOMAIN + assert len(record_does_not_exist_response.answer) == 0 + assert len(record_does_not_exist_response.authority) == num_ns + 1 # ns + soa + + # Record outside domain + record_outside_domain = dns.message.make_query("chia.net", request_type) + record_outside_domain_response = await make_dns_query(use_tcp, target_address, port, record_outside_domain) + assert record_outside_domain_response.rcode() == dns.rcode.REFUSED + assert len(record_outside_domain_response.answer) == 0 + assert len(record_outside_domain_response.authority) == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("use_tcp, target_address, request_type", all_test_combinations) +async def test_dns_queries( + seeder_service: DNSServer, use_tcp: bool, target_address: str, request_type: dns.rdatatype.RdataType +) -> None: + """ + We add 5000 peers directly, then try every kind of query many times over both the TCP and UDP protocols. + """ + port = seeder_service.dns_port + domain = seeder_service.domain # default is: seeder.example.com + num_ns = len(seeder_service.ns_records) + + # add 5000 peers (2500 ipv4, 2500 ipv6) + seeder_service.reliable_peers_v4, seeder_service.reliable_peers_v6 = get_addresses() + + # now we query for each type of record a lot of times and make sure we get the right number of responses + for i in range(150): + query = dns.message.make_query(domain, request_type, use_edns=True) # we need to generate a new request id. + std_query_response = await make_dns_query(use_tcp, target_address, port, query) + assert std_query_response.rcode() == dns.rcode.NOERROR + assert_standard_results(std_query_response.answer, request_type, num_ns) + + # Assert Authority Records + assert len(std_query_response.authority) == num_ns + 1 # ns + soa + for record_list in std_query_response.authority: + if record_list.rdtype == dns.rdatatype.NS: + assert len(record_list.items) == num_ns + else: + assert len(record_list.items) == 1 # soa + # Validate EDNS + e_query = dns.message.make_query(domain, dns.rdatatype.ANY, use_edns=False) + with pytest.raises(dns.query.BadResponse): # response is truncated without EDNS + await make_dns_query(False, target_address, port, e_query) + + +@pytest.mark.asyncio +async def test_db_processing(seeder_service: DNSServer) -> None: + """ + We add 1000 peers through the db, then try every kind of query over both the TCP and UDP protocols. + """ + port = seeder_service.dns_port + domain = seeder_service.domain # default is: seeder.example.com + num_ns = len(seeder_service.ns_records) + crawl_store = seeder_service.crawl_store + assert crawl_store is not None + + ipv4, ipv6 = get_addresses(2) # get 1000 addresses + # add peers to the in memory part of the db. + for peer in [str(peer) for pair in zip(ipv4, ipv6) for peer in pair]: + new_peer = PeerRecord( + peer, + peer, + uint32(58444), + False, + uint64(0), + uint32(0), + uint64(0), + uint64(int(time.time())), + uint64(0), + "undefined", + uint64(0), + tls_version="unknown", + ) + new_peer_reliability = PeerReliability(peer, tries=3, successes=3) # make sure the peer starts as reliable. + crawl_store.maybe_add_peer(new_peer, new_peer_reliability) # we don't fully add it because we don't need to. + + # Write these new peers to db. + await crawl_store.load_reliable_peers_to_db() + + # wait for the new db to be read. 
+    await time_out_assert(30, lambda: seeder_service.reliable_peers_v4 != [])
+
+    # now we check all the combinations, but only once each rather than many times over.
+    for use_tcp, target_address, request_type in all_test_combinations:
+        query = dns.message.make_query(domain, request_type, use_edns=True)  # we need to generate a new request id.
+        std_query_response = await make_dns_query(use_tcp, target_address, port, query)
+        assert std_query_response.rcode() == dns.rcode.NOERROR
+        assert_standard_results(std_query_response.answer, request_type, num_ns)
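
A note on the reliability math this patch carries through PeerStat.update: each stat window keeps an exponential moving average, discounting the previous estimate by exp(-age / tau) before folding in the new sample, so the 2h, 8h, and 1d (and longer) windows decay on independent time scales. A minimal sketch of that update, based only on the two assignment lines visible in the diff (the standalone function and the example numbers are illustrative, not part of the patch):

    import math
    from typing import Tuple

    def peer_stat_update(reliability: float, count: float, is_reachable: bool, age: int, tau: int) -> Tuple[float, float]:
        # Weight the old estimate down by how stale it is relative to the window tau.
        f = math.exp(-age / tau)
        # A reachable sample pulls reliability toward 1.0, an unreachable one toward 0.0.
        reliability = reliability * f + (1.0 - f if is_reachable else 0.0)
        count = count * f + 1.0
        return reliability, count

    # A success observed after a 2 hour gap, scored against the 8 hour window:
    r, c = peer_stat_update(0.9, 5.0, True, age=2 * 3600, tau=8 * 3600)
    assert 0.9 < r < 1.0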
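The database_uri fixture moved into tests/conftest.py produces URIs of the form file:db_N?mode=memory&cache=shared, and the shape matters: mode=memory keeps the database off disk, while cache=shared lets every connection that opens the same URI within the process see a single database, which is how the crawler and the DNS server can share peer data in these tests. A small stdlib-only illustration (the table and values are made up for the example):

    import sqlite3

    uri = "file:db_example?mode=memory&cache=shared"

    # Two independent connections to the same in-memory database; the database
    # lives for as long as at least one connection stays open.
    writer = sqlite3.connect(uri, uri=True)
    reader = sqlite3.connect(uri, uri=True)

    writer.execute("CREATE TABLE peers(host TEXT)")
    writer.execute("INSERT INTO peers VALUES ('127.0.0.1')")
    writer.commit()

    assert reader.execute("SELECT host FROM peers").fetchone() == ("127.0.0.1",)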
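generate_test_combs in the new test_seeder.py builds the (use_tcp, target_address, request_type) matrix with three nested loops; an equivalent formulation with itertools.product, shown purely as a cross-check of the expected matrix size:

    import itertools

    import dns.rdatatype

    request_types = [dns.rdatatype.A, dns.rdatatype.AAAA, dns.rdatatype.ANY, dns.rdatatype.NS, dns.rdatatype.SOA]
    combos = [
        (tcp, addr, req)
        for addr, tcp, req in itertools.product(["::1", "127.0.0.1"], [True, False], request_types)
    ]
    assert len(combos) == 20  # 2 transports x 2 address families x 5 request types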
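Finally, the EDNS assertion at the end of test_dns_queries leans on plain UDP DNS being limited to 512-byte responses: an ANY answer carrying 16 A and 16 AAAA records plus NS and SOA cannot fit, so a non-EDNS query cannot be answered in full (here the test asserts that dnspython rejects the reply as a BadResponse). For a server that truncates cleanly by setting the TC flag, a synchronous dnspython client would typically detect and recover like this (address, port, and domain are placeholders, not values from the patch):

    import dns.flags
    import dns.message
    import dns.query

    where, port, domain = "127.0.0.1", 53, "seeder.example.com."  # placeholders

    # Without EDNS the server must fit its answer in 512 bytes; ignore_trunc=True
    # lets us inspect a TC-flagged reply instead of raising immediately.
    q = dns.message.make_query(domain, "ANY", use_edns=False)
    reply = dns.query.udp(q, where, port=port, timeout=1, ignore_trunc=True)
    if reply.flags & dns.flags.TC:
        # Re-ask with EDNS0 advertising a 4096-byte UDP payload, or fall back to TCP.
        q = dns.message.make_query(domain, "ANY", use_edns=0, payload=4096)
        reply = dns.query.udp(q, where, port=port, timeout=1)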