Merge remote-tracking branch 'origin/main' into EL.trusted_peer_fix

This commit is contained in:
Earle Lowe 2023-03-02 13:07:23 -08:00
commit 3e13236eef
No known key found for this signature in database
214 changed files with 2877 additions and 2026 deletions

View File

@ -135,7 +135,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-arm64.tar.gz")).browser_download_url')
PRERELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-arm64.tar.gz")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit

View File

@ -135,7 +135,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-x86-64.tar.gz")).browser_download_url')
PRERELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-x86-64.tar.gz")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit

View File

@ -134,7 +134,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("centos-x86-64.tar.gz")).browser_download_url')
PRERELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("centos-x86-64.tar.gz")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit

View File

@ -145,7 +145,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
PRERELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
@ -156,7 +156,7 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
FULLRELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease | not)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
FULLRELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease | not)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $FULLRELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit

View File

@ -180,7 +180,7 @@ jobs:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
shell: bash
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("windows-x86-64.zip")).browser_download_url')
PRERELEASE_URL=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("windows-x86-64.zip")).browser_download_url')
mkdir $GITHUB_WORKSPACE\\bladebit
ls
echo $PRERELEASE_URL

View File

@ -49,7 +49,6 @@ REPETITIONS = 100
async def main(db_path: Path):
random.seed(0x213FB154)
async with aiosqlite.connect(db_path) as connection:

View File

@ -40,7 +40,6 @@ random.seed(123456789)
async def run_add_block_benchmark(version: int):
verbose: bool = "--verbose" in sys.argv
db_wrapper: DBWrapper2 = await setup_db("block-store-benchmark.db", version)
@ -73,7 +72,6 @@ async def run_add_block_benchmark(version: int):
print("profiling add_full_block", end="")
for height in range(block_height, block_height + NUM_ITERS):
is_transaction = transaction_block_counter == 0
fees = uint64(random.randint(0, 150000))
farmer_coin, pool_coin = rewards(uint32(height))

View File

@ -38,7 +38,6 @@ def make_coins(num: int) -> Tuple[List[Coin], List[bytes32]]:
async def run_new_block_benchmark(version: int):
verbose: bool = "--verbose" in sys.argv
db_wrapper: DBWrapper2 = await setup_db("coin-store-benchmark.db", version)
@ -56,7 +55,6 @@ async def run_new_block_benchmark(version: int):
print("Building database ", end="")
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
@ -94,7 +92,6 @@ async def run_new_block_benchmark(version: int):
if verbose:
print("Profiling mostly additions ", end="")
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
total_add += 2000
@ -193,7 +190,6 @@ async def run_new_block_benchmark(version: int):
total_remove = 0
total_time = 0
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
total_add += 2000

View File

@ -3,27 +3,25 @@ from __future__ import annotations
import asyncio
import cProfile
from contextlib import contextmanager
from dataclasses import dataclass
from subprocess import check_call
from time import monotonic
from typing import Iterator, List
from typing import Dict, Iterator, List, Optional, Tuple
from utils import setup_db
from chia.consensus.block_record import BlockRecord
from chia.consensus.coinbase import create_farmer_coin, create_pool_coin
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.coin_store import CoinStore
from chia.full_node.mempool_manager import MempoolManager
from chia.simulator.wallet_tools import WalletTool
from chia.types.blockchain_format.classgroup import ClassgroupElement
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32, bytes100
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.spend_bundle import SpendBundle
from chia.util.db_wrapper import DBWrapper2
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.types.spend_bundle_conditions import Spend, SpendBundleConditions
from chia.util.ints import uint32, uint64
NUM_ITERS = 100
NUM_ITERS = 200
NUM_PEERS = 5
@ -42,105 +40,109 @@ def enable_profiler(profile: bool, name: str) -> Iterator[None]:
check_call(["gprof2dot", "-f", "pstats", "-o", output_file + ".dot", output_file + ".profile"])
with open(output_file + ".png", "w+") as f:
check_call(["dot", "-T", "png", output_file + ".dot"], stdout=f)
print("output written to: %s.png" % output_file)
print(" output written to: %s.png" % output_file)
def fake_block_record(block_height: uint32, timestamp: uint64) -> BlockRecord:
return BlockRecord(
bytes32(b"a" * 32), # header_hash
bytes32(b"b" * 32), # prev_hash
block_height, # height
uint128(0), # weight
uint128(0), # total_iters
uint8(0), # signage_point_index
ClassgroupElement(bytes100(b"1" * 100)), # challenge_vdf_output
None, # infused_challenge_vdf_output
bytes32(b"f" * 32), # reward_infusion_new_challenge
bytes32(b"c" * 32), # challenge_block_info_hash
uint64(0), # sub_slot_iters
bytes32(b"d" * 32), # pool_puzzle_hash
bytes32(b"e" * 32), # farmer_puzzle_hash
uint64(0), # required_iters
uint8(0), # deficit
False, # overflow
uint32(block_height - 1), # prev_transaction_block_height
timestamp, # timestamp
None, # prev_transaction_block_hash
uint64(0), # fees
None, # reward_claims_incorporated
None, # finished_challenge_slot_hashes
None, # finished_infused_challenge_slot_hashes
None, # finished_reward_slot_hashes
None, # sub_epoch_summary_included
def make_hash(height: int) -> bytes32:
return bytes32(height.to_bytes(32, byteorder="big"))
@dataclass(frozen=True)
class BenchBlockRecord:
"""
This is a subset of BlockRecord that the mempool manager uses for peak.
"""
header_hash: bytes32
height: uint32
timestamp: Optional[uint64]
prev_transaction_block_height: uint32
prev_transaction_block_hash: Optional[bytes32]
@property
def is_transaction_block(self) -> bool:
return self.timestamp is not None
def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockRecord:
this_hash = make_hash(block_height)
prev_hash = make_hash(block_height - 1)
return BenchBlockRecord(
header_hash=this_hash,
height=block_height,
timestamp=timestamp,
prev_transaction_block_height=uint32(block_height - 1),
prev_transaction_block_hash=prev_hash,
)
async def run_mempool_benchmark(single_threaded: bool) -> None:
async def run_mempool_benchmark() -> None:
all_coins: Dict[bytes32, CoinRecord] = {}
suffix = "st" if single_threaded else "mt"
db_wrapper: DBWrapper2 = await setup_db(f"mempool-benchmark-coins-{suffix}.db", 2)
async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return all_coins.get(coin_id)
try:
coin_store = await CoinStore.create(db_wrapper)
mempool = MempoolManager(coin_store.get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
wt = WalletTool(DEFAULT_CONSTANTS)
wt = WalletTool(DEFAULT_CONSTANTS)
spend_bundles: List[List[SpendBundle]] = []
spend_bundles: List[List[SpendBundle]] = []
# these spend the same coins as spend_bundles but with a higher fee
replacement_spend_bundles: List[List[SpendBundle]] = []
timestamp = uint64(1631794488)
timestamp = uint64(1631794488)
height = uint32(1)
height = uint32(1)
print("Building SpendBundles")
for peer in range(NUM_PEERS):
print(f" peer {peer}")
print(" reward coins")
unspent: List[Coin] = []
for idx in range(NUM_ITERS):
height = uint32(height + 1)
# farm rewards
farmer_coin = create_farmer_coin(
height, wt.get_new_puzzlehash(), uint64(250000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
pool_coin = create_pool_coin(
height, wt.get_new_puzzlehash(), uint64(1750000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
unspent.extend([farmer_coin, pool_coin])
await coin_store.new_block(
height,
timestamp,
set([pool_coin, farmer_coin]),
[],
[],
)
bundles: List[SpendBundle] = []
print(" spend bundles")
for coin in unspent:
tx: SpendBundle = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin
)
bundles.append(tx)
spend_bundles.append(bundles)
print("Building SpendBundles")
for peer in range(NUM_PEERS):
print(f" peer {peer}")
print(" reward coins")
unspent: List[Coin] = []
for idx in range(NUM_ITERS):
height = uint32(height + 1)
# 19 seconds per block
timestamp = uint64(timestamp + 19)
if single_threaded:
print("Single-threaded")
else:
print("Multi-threaded")
print("Profiling add_spendbundle()")
# farm rewards
farmer_coin = create_farmer_coin(
height, wt.get_new_puzzlehash(), uint64(250000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
pool_coin = create_pool_coin(
height, wt.get_new_puzzlehash(), uint64(1750000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
all_coins[farmer_coin.name()] = CoinRecord(farmer_coin, height, uint32(0), True, timestamp)
all_coins[pool_coin.name()] = CoinRecord(pool_coin, height, uint32(0), True, timestamp)
unspent.extend([farmer_coin, pool_coin])
# the mempool only looks at:
# timestamp
# height
# is_transaction_block
# header_hash
print("initialize MempoolManager")
print(" spend bundles")
bundles: List[SpendBundle] = []
for coin in unspent:
tx: SpendBundle = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin, fee=peer + idx
)
bundles.append(tx)
spend_bundles.append(bundles)
bundles = []
print(" replacement spend bundles")
for coin in unspent:
tx = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin, fee=peer + idx + 10000000
)
bundles.append(tx)
replacement_spend_bundles.append(bundles)
start_height = height
for single_threaded in [False, True]:
if single_threaded:
print("\n== Single-threaded")
else:
print("\n== Multi-threaded")
mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
height = start_height
rec = fake_block_record(height, timestamp)
await mempool.new_peak(rec, None)
@ -153,6 +155,9 @@ async def run_mempool_benchmark(single_threaded: bool) -> None:
assert status == MempoolInclusionStatus.SUCCESS
assert error is None
suffix = "st" if single_threaded else "mt"
print("\nProfiling add_spend_bundle()")
total_bundles = 0
tasks = []
with enable_profiler(True, f"add-{suffix}"):
@ -162,20 +167,76 @@ async def run_mempool_benchmark(single_threaded: bool) -> None:
tasks.append(asyncio.create_task(add_spend_bundles(spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f"add_spendbundle time: {stop - start:0.4f}s")
print(f"{(stop - start) / total_bundles * 1000:0.2f}ms per add_spendbundle() call")
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms")
print("\nProfiling add_spend_bundle() with replace-by-fee")
total_bundles = 0
tasks = []
with enable_profiler(True, f"replace-{suffix}"):
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(replacement_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(replacement_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms")
print("\nProfiling create_bundle_from_mempool()")
with enable_profiler(True, f"create-{suffix}"):
start = monotonic()
for _ in range(2000):
mempool.create_bundle_from_mempool(bytes32(b"a" * 32))
for _ in range(500):
mempool.create_bundle_from_mempool(rec.header_hash)
stop = monotonic()
print(f"create_bundle_from_mempool time: {stop - start:0.4f}s")
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / 500 * 1000:0.2f}ms")
# TODO: add benchmark for new_peak()
print("\nProfiling new_peak() (optimized)")
blocks: List[Tuple[BenchBlockRecord, NPCResult]] = []
for coin_id in all_coins.keys():
height = uint32(height + 1)
timestamp = uint64(timestamp + 19)
rec = fake_block_record(height, timestamp)
npc_result = NPCResult(
None,
SpendBundleConditions(
[Spend(coin_id, bytes32(b" " * 32), None, 0, None, None, [], [], 0)], 0, 0, 0, None, None, [], 0
),
uint64(1000000000),
)
blocks.append((rec, npc_result))
finally:
await db_wrapper.close()
with enable_profiler(True, f"new-peak-{suffix}"):
start = monotonic()
for rec, npc_result in blocks:
await mempool.new_peak(rec, npc_result)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / len(blocks) * 1000:0.2f}ms")
print("\nProfiling new_peak() (reorg)")
blocks = []
for coin_id in all_coins.keys():
height = uint32(height + 2)
timestamp = uint64(timestamp + 28)
rec = fake_block_record(height, timestamp)
npc_result = NPCResult(
None,
SpendBundleConditions(
[Spend(coin_id, bytes32(b" " * 32), None, 0, None, None, [], [], 0)], 0, 0, 0, None, None, [], 0
),
uint64(1000000000),
)
blocks.append((rec, npc_result))
with enable_profiler(True, f"new-peak-reorg-{suffix}"):
start = monotonic()
for rec, npc_result in blocks:
await mempool.new_peak(rec, npc_result)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / len(blocks) * 1000:0.2f}ms")
if __name__ == "__main__":
@ -184,5 +245,4 @@ if __name__ == "__main__":
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.WARNING)
asyncio.run(run_mempool_benchmark(True))
asyncio.run(run_mempool_benchmark(False))
asyncio.run(run_mempool_benchmark())

View File

@ -5,7 +5,6 @@ from setuptools_scm import get_version
# example: 1.0b5.dev225
def main():
scm_full_version = get_version(root="..", relative_to=__file__)
# scm_full_version = "1.0.5.dev22"

@ -1 +1 @@
Subproject commit 9ea8e1a4ccb38947d2d645915cb9361dac01c014
Subproject commit 71995c6eb722e8d92008e2d3df38fd71f1231171

View File

@ -14,6 +14,7 @@ from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.bundle_tools import simple_solution_generator
from chia.full_node.coin_store import CoinStore
from chia.full_node.mempool import Mempool
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, get_puzzle_and_solution_for_coin
from chia.full_node.mempool_manager import MempoolManager
from chia.types.blockchain_format.coin import Coin
@ -128,7 +129,6 @@ _T_SpendSim = TypeVar("_T_SpendSim", bound="SpendSim")
class SpendSim:
db_wrapper: DBWrapper2
coin_store: CoinStore
mempool_manager: MempoolManager
@ -166,8 +166,7 @@ class SpendSim:
self.block_height = store_data.block_height
self.block_records = store_data.block_records
self.blocks = store_data.blocks
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
self.mempool_manager.peak = self.block_records[-1] # type: ignore[assignment]
self.mempool_manager.peak = self.block_records[-1]
else:
self.timestamp = uint64(1)
self.block_height = uint32(0)
@ -187,8 +186,7 @@ class SpendSim:
await self.db_wrapper.close()
async def new_peak(self) -> None:
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
await self.mempool_manager.new_peak(self.block_records[-1], None) # type: ignore[arg-type]
await self.mempool_manager.new_peak(self.block_records[-1], None)
def new_coin_record(self, coin: Coin, coinbase: bool = False) -> CoinRecord:
return CoinRecord(
@ -221,13 +219,12 @@ class SpendSim:
async def farm_block(
self,
puzzle_hash: bytes32 = bytes32(b"0" * 32),
item_inclusion_filter: Optional[Callable[[MempoolManager, MempoolItem], bool]] = None,
item_inclusion_filter: Optional[Callable[[bytes32], bool]] = None,
) -> Tuple[List[Coin], List[Coin]]:
# Fees get calculated
fees = uint64(0)
if self.mempool_manager.mempool.spends:
for _, item in self.mempool_manager.mempool.spends.items():
fees = uint64(fees + item.spend_bundle.fees())
for item in self.mempool_manager.mempool.all_spends():
fees = uint64(fees + item.fee)
# Rewards get created
next_block_height: uint32 = uint32(self.block_height + 1) if len(self.block_records) > 0 else self.block_height
@ -251,7 +248,7 @@ class SpendSim:
generator_bundle: Optional[SpendBundle] = None
return_additions: List[Coin] = []
return_removals: List[Coin] = []
if (len(self.block_records) > 0) and (self.mempool_manager.mempool.spends):
if (len(self.block_records) > 0) and (self.mempool_manager.mempool.size() > 0):
peak = self.mempool_manager.peak
if peak is not None:
result = self.mempool_manager.create_bundle_from_mempool(peak.header_hash, item_inclusion_filter)
@ -300,7 +297,8 @@ class SpendSim:
self.block_records = new_br_list
self.blocks = new_block_list
await self.coin_store.rollback_to_block(block_height)
self.mempool_manager.mempool.spends = {}
old_pool = self.mempool_manager.mempool
self.mempool_manager.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.block_height = block_height
if new_br_list:
self.timestamp = new_br_list[-1].timestamp
@ -429,12 +427,12 @@ class SimClient:
return CoinSpend(coin_record.coin, puzzle, solution)
async def get_all_mempool_tx_ids(self) -> List[bytes32]:
return list(self.service.mempool_manager.mempool.spends.keys())
return self.service.mempool_manager.mempool.all_spend_ids()
async def get_all_mempool_items(self) -> Dict[bytes32, MempoolItem]:
spends = {}
for tx_id, item in self.service.mempool_manager.mempool.spends.items():
spends[tx_id] = item
for item in self.service.mempool_manager.mempool.all_spends():
spends[item.name] = item
return spends
async def get_mempool_item_by_tx_id(self, tx_id: bytes32) -> Optional[Dict[str, Any]]:

View File

@ -0,0 +1,415 @@
from __future__ import annotations
import asyncio
import sys
from collections import defaultdict
from pathlib import Path
from sqlite3 import Row
from typing import Any, Dict, Iterable, List, Optional, Set
from chia.util.collection import find_duplicates
from chia.util.db_synchronous import db_synchronous_on
from chia.util.db_wrapper import DBWrapper2, execute_fetchone
from chia.util.pprint import print_compact_ranges
from chia.wallet.util.wallet_types import WalletType
# TODO: Check for missing paired wallets (eg. No DID wallet for an NFT)
# TODO: Check for missing DID Wallets
help_text = """
\b
The purpose of this command is find potential issues in Chia wallet databases.
The core chia client currently uses sqlite to store the wallet databases, one database per key.
\b
Guide to warning diagnostics:
----------------------------
"Missing Wallet IDs": A wallet was created and later deleted. By itself, this is okay because
the wallet does not reuse wallet IDs. However, this information may be useful
in conjunction with other information.
\b
Guide to error diagnostics:
--------------------------
Diagnostics in the error section indicate an error in the database structure.
In general, this does not indicate an error in on-chain data, nor does it mean that you have lost coins.
\b
An example is "Missing DerivationPath indexes" - a derivation path is a sub-key of your master key. Missing
derivation paths could cause your wallet to not "know" about transactions that happened on the blockchain.
\b
"""
def _validate_args_addresses_used(wallet_id: int, last_index: int, last_hardened: int, dp: DerivationPath) -> None:
if last_hardened:
if last_hardened != dp.hardened:
raise ValueError(f"Invalid argument: Mix of hardened and unhardened columns wallet_id={wallet_id}")
if last_index:
if last_index != dp.derivation_index:
raise ValueError(f"Invalid argument: noncontiguous derivation_index at {last_index} wallet_id={wallet_id}")
def check_addresses_used_contiguous(derivation_paths: List[DerivationPath]) -> List[str]:
"""
The used column for addresses in the derivation_paths table should be a
zero or greater run of 1's, followed by a zero or greater run of 0's.
There should be no used derivations after seeing a used derivation.
"""
errors: List[str] = []
for wallet_id, dps in dp_by_wallet_id(derivation_paths).items():
saw_unused = False
bad_used_values: Set[int] = set()
ordering_errors: List[str] = []
# last_index = None
# last_hardened = None
for dp in dps:
# _validate_args_addresses_used(wallet_id, last_index, last_hardened, dp)
if saw_unused and dp.used == 1 and ordering_errors == []:
ordering_errors.append(
f"Wallet {dp.wallet_id}: "
f"Used address after unused address at derivation index {dp.derivation_index}"
)
if dp.used == 1:
pass
elif dp.used == 0:
saw_unused = True
else:
bad_used_values.add(dp.used)
# last_hardened = dp.hardened
# last_index = dp.derivation_index
if len(bad_used_values) > 0:
errors.append(f"Wallet {wallet_id}: Bad values in 'used' column: {bad_used_values}")
if ordering_errors != []:
errors.extend(ordering_errors)
return errors
def check_for_gaps(array: List[int], start: int, end: int, *, data_type_plural: str = "Elements") -> List[str]:
"""
Check for compact sequence:
Check that every value from start to end is present in array, and no more.
start and end are values, not indexes
start and end should be included in array
array can be unsorted
"""
if start > end:
raise ValueError(f"{__name__} called with incorrect arguments: start={start} end={end} (start > end)")
errors: List[str] = []
if start == end and len(array) == 1:
return errors
expected_set = set(range(start, end + 1))
actual_set = set(array)
missing = expected_set.difference(actual_set)
extras = actual_set.difference(expected_set)
duplicates = find_duplicates(array)
if len(missing) > 0:
errors.append(f"Missing {data_type_plural}: {print_compact_ranges(list(missing))}")
if len(extras) > 0:
errors.append(f"Unexpected {data_type_plural}: {extras}")
if len(duplicates) > 0:
errors.append(f"Duplicate {data_type_plural}: {duplicates}")
return errors
class FromDB:
def __init__(self, row: Iterable[Any], fields: List[str]) -> None:
self.fields = fields
for field, value in zip(fields, row):
setattr(self, field, value)
def __repr__(self) -> str:
s = ""
for f in self.fields:
s += f"{f}={getattr(self, f)} "
return s
def wallet_type_name(
wallet_type: int,
) -> str:
if wallet_type in set(wt.value for wt in WalletType):
return f"{WalletType(wallet_type).name} ({wallet_type})"
else:
return f"INVALID_WALLET_TYPE ({wallet_type})"
def _cwr(row: Row) -> List[Any]:
r = []
for i, v in enumerate(row):
if i == 2:
r.append(wallet_type_name(v))
else:
r.append(v)
return r
# wallet_types_that_dont_need_derivations: See require_derivation_paths for each wallet type
wallet_types_that_dont_need_derivations = {WalletType.POOLING_WALLET, WalletType.NFT}
class DerivationPath(FromDB):
derivation_index: int
pubkey: str
puzzle_hash: str
wallet_type: WalletType
wallet_id: int
used: int # 1 or 0
hardened: int # 1 or 0
class Wallet(FromDB):
id: int # id >= 1
name: str
wallet_type: WalletType
data: str
def dp_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[DerivationPath]]:
d = defaultdict(list)
for derivation_path in derivation_paths:
d[derivation_path.wallet_id].append(derivation_path)
for k, v in d.items():
d[k] = sorted(v, key=lambda dp: dp.derivation_index)
return d
def derivation_indices_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[int]]:
d = dp_by_wallet_id(derivation_paths)
di = {}
for k, v in d.items():
di[k] = [dp.derivation_index for dp in v]
return di
def print_min_max_derivation_for_wallets(derivation_paths: List[DerivationPath]) -> None:
d = derivation_indices_by_wallet_id(derivation_paths)
print("Min, Max, Count of derivations for each wallet:")
for wallet_id, derivation_index_list in d.items():
# TODO: Fix count by separating hardened and unhardened
print(
f"Wallet ID {wallet_id:2} derivation index min: {derivation_index_list[0]} "
f"max: {derivation_index_list[-1]} count: {len(derivation_index_list)}"
)
class WalletDBReader:
db_wrapper: DBWrapper2 # TODO: Remove db_wrapper member
config = {"db_readers": 1}
sql_log_path = None
verbose = False
async def get_all_wallets(self) -> List[Wallet]:
wallet_fields = ["id", "name", "wallet_type", "data"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(wallet_fields)} FROM users_wallets""")
rows = await cursor.fetchall()
return [Wallet(r, wallet_fields) for r in rows]
async def get_derivation_paths(self) -> List[DerivationPath]:
fields = ["derivation_index", "pubkey", "puzzle_hash", "wallet_type", "wallet_id", "used", "hardened"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(fields)} FROM derivation_paths;""")
rows = await cursor.fetchall()
return [DerivationPath(row, fields) for row in rows]
async def show_tables(self) -> List[str]:
async with self.db_wrapper.reader_no_transaction() as reader:
cursor = await reader.execute("""SELECT name FROM sqlite_master WHERE type='table';""")
print("\nWallet DB Tables:")
print(*([r[0] for r in await cursor.fetchall()]), sep=",\n")
print("\nWallet Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('users_wallets')")).fetchall()), sep=",\n")
print("\nDerivationPath Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('derivation_paths')")).fetchall()), sep=",\n")
print()
return []
async def check_wallets(self) -> List[str]:
# id, name, wallet_type, data
# TODO: Move this SQL up a level
async with self.db_wrapper.reader_no_transaction() as reader:
errors = []
try:
main_wallet_id = 1
main_wallet_type = WalletType.STANDARD_WALLET
row = await execute_fetchone(reader, "SELECT * FROM users_wallets WHERE id=?", (main_wallet_id,))
if row is None:
errors.append(f"There is no wallet with ID {main_wallet_id} in table users_wallets")
elif row[2] != main_wallet_type:
errors.append(
f"We expect wallet {main_wallet_id} to have type {wallet_type_name(main_wallet_type)}, "
f"but it has {wallet_type_name(row[2])}"
)
except Exception as e:
errors.append(f"Exception while trying to access wallet {main_wallet_id} from users_wallets: {e}")
max_id_row = await execute_fetchone(reader, "SELECT MAX(id) FROM users_wallets")
if max_id_row is None:
errors.append("Error fetching max wallet ID from table users_wallets. No wallets ?!?")
else:
cursor = await reader.execute("""SELECT * FROM users_wallets""")
rows = await cursor.fetchall()
max_id = max_id_row[0]
errors.extend(check_for_gaps([r[0] for r in rows], main_wallet_id, max_id, data_type_plural="Wallet IDs"))
if self.verbose:
print("\nWallets:")
print(*[_cwr(r) for r in rows], sep=",\n")
# Check for invalid wallet types in users_wallets
invalid_wallet_types = set()
for row in rows:
if row[2] not in set(wt.value for wt in WalletType):
invalid_wallet_types.add(row[2])
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid Wallet Types found in table users_wallets: {invalid_wallet_types}")
return errors
def check_wallets_missing_derivations(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
p = []
d = derivation_indices_by_wallet_id(derivation_paths) # TODO: calc this once, pass in
for w in wallets:
if w.wallet_type not in wallet_types_that_dont_need_derivations and w.id not in d:
p.append(w.id)
if len(p) > 0:
return [f"Wallet IDs with no derivations that require them: {p}"]
return []
def check_derivations_are_compact(self, wallets: List[Wallet], derivation_paths: List[DerivationPath]) -> List[str]:
errors = []
"""
Gaps in derivation index
Missing hardened or unhardened derivations
TODO: Gaps in used derivations
"""
for wallet_id in [w.id for w in wallets]:
for hardened in [0, 1]:
dps = list(filter(lambda x: x.wallet_id == wallet_id and x.hardened == hardened, derivation_paths))
if len(dps) < 1:
continue
dpi = [x.derivation_index for x in dps]
dpi.sort()
max_id = dpi[-1]
h = [" hardened", "unhardened"][hardened]
errors.extend(
check_for_gaps(
dpi, 0, max_id, data_type_plural=f"DerivationPath indexes for {h} wallet_id={wallet_id}"
)
)
return errors
def check_unexpected_derivation_entries(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
"""
Check for unexpected derivation path entries
Invalid Wallet Type
Wallet IDs not in table 'users_wallets'
Wallet ID with different wallet_type
"""
errors = []
wallet_id_to_type = {w.id: w.wallet_type for w in wallets}
invalid_wallet_types = []
missing_wallet_ids = []
wrong_type = defaultdict(list)
for d in derivation_paths:
if d.wallet_type not in set(wt.value for wt in WalletType):
invalid_wallet_types.append(d.wallet_type)
if d.wallet_id not in wallet_id_to_type:
missing_wallet_ids.append(d.wallet_id)
elif d.wallet_type != wallet_id_to_type[d.wallet_id]:
wrong_type[(d.hardened, d.wallet_id, d.wallet_type, wallet_id_to_type[d.wallet_id])].append(
d.derivation_index
)
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid wallet_types in derivation_paths table: {invalid_wallet_types}")
if len(missing_wallet_ids) > 0:
errors.append(
f"Wallet IDs found in derivation_paths table, but not in users_wallets table: {missing_wallet_ids}"
)
for k, v in wrong_type.items():
errors.append(
f"""{[" ", "un"][int(k[0])]}hardened Wallet ID {k[1]} uses type {wallet_type_name(k[2])} in """
f"derivation_paths, but type {wallet_type_name(k[3])} in wallet table at these derivation indices: {v}"
)
return errors
async def scan(self, db_path: Path) -> int:
"""Returns number of lines of error output (not warnings)"""
self.db_wrapper = await DBWrapper2.create(
database=db_path,
reader_count=self.config.get("db_readers", 4),
log_path=self.sql_log_path,
synchronous=db_synchronous_on("auto"),
)
# TODO: Pass down db_wrapper
wallets = await self.get_all_wallets()
derivation_paths = await self.get_derivation_paths()
errors = []
warnings = []
try:
if self.verbose:
await self.show_tables()
print_min_max_derivation_for_wallets(derivation_paths)
warnings.extend(await self.check_wallets())
errors.extend(self.check_wallets_missing_derivations(wallets, derivation_paths))
errors.extend(self.check_unexpected_derivation_entries(wallets, derivation_paths))
errors.extend(self.check_derivations_are_compact(wallets, derivation_paths))
errors.extend(check_addresses_used_contiguous(derivation_paths))
if len(warnings) > 0:
print(f" ---- Warnings Found for {db_path.name} ----")
print("\n".join(warnings))
if len(errors) > 0:
print(f" ---- Errors Found for {db_path.name}----")
print("\n".join(errors))
finally:
await self.db_wrapper.close()
return len(errors)
async def scan(root_path: str, db_path: Optional[str] = None, *, verbose: bool = False) -> None:
if db_path is None:
wallet_db_path = Path(root_path) / "wallet" / "db"
wallet_db_paths = list(wallet_db_path.glob("blockchain_wallet_*.sqlite"))
else:
wallet_db_paths = [Path(db_path)]
num_errors = 0
for wallet_db_path in wallet_db_paths:
w = WalletDBReader()
w.verbose = verbose
print(f"Reading {wallet_db_path}")
num_errors += await w.scan(Path(wallet_db_path))
if num_errors > 0:
sys.exit(2)
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(scan("", sys.argv[1]))

View File

@ -4,7 +4,7 @@ import logging
import traceback
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from aiohttp import ClientConnectorError
@ -31,6 +31,17 @@ NODE_TYPES: Dict[str, Type[RpcClient]] = {
"data_layer": DataLayerRpcClient,
}
node_config_section_names: Dict[Type[RpcClient], str] = {
FarmerRpcClient: "farmer",
WalletRpcClient: "wallet",
FullNodeRpcClient: "full_node",
HarvesterRpcClient: "harvester",
DataLayerRpcClient: "data_layer",
}
_T_RpcClient = TypeVar("_T_RpcClient", bound=RpcClient)
def transaction_submitted_msg(tx: TransactionRecord) -> str:
sent_to = [MempoolSubmissionStatus(s[0], s[1], s[2]).to_json_dict_convenience() for s in tx.sent_to]
@ -49,7 +60,6 @@ async def validate_client_connection(
fingerprint: Optional[int],
login_to_wallet: bool,
) -> Optional[int]:
try:
await rpc_client.healthz()
if type(rpc_client) == WalletRpcClient and login_to_wallet:
@ -66,28 +76,29 @@ async def validate_client_connection(
@asynccontextmanager
async def get_any_service_client(
node_type: str,
client_type: Type[_T_RpcClient],
rpc_port: Optional[int] = None,
root_path: Path = DEFAULT_ROOT_PATH,
fingerprint: Optional[int] = None,
login_to_wallet: bool = True,
) -> AsyncIterator[Tuple[Optional[Any], Dict[str, Any], Optional[int]]]:
) -> AsyncIterator[Tuple[Optional[_T_RpcClient], Dict[str, Any], Optional[int]]]:
"""
Yields a tuple with a RpcClient for the applicable node type a dictionary of the node's configuration,
and a fingerprint if applicable. However, if connecting to the node fails then we will return None for
the RpcClient.
"""
if node_type not in NODE_TYPES.keys():
node_type = node_config_section_names.get(client_type)
if node_type is None:
# Click already checks this, so this should never happen
raise ValueError(f"Invalid node type: {node_type}")
raise ValueError(f"Invalid client type requested: {client_type.__name__}")
# load variables from config file
config = load_config(root_path, "config.yaml", fill_missing_services=node_type == "data_layer")
config = load_config(root_path, "config.yaml", fill_missing_services=issubclass(client_type, DataLayerRpcClient))
self_hostname = config["self_hostname"]
if rpc_port is None:
rpc_port = config[node_type]["rpc_port"]
# select node client type based on string
node_client = await NODE_TYPES[node_type].create(self_hostname, uint16(rpc_port), root_path, config)
node_client = await client_type.create(self_hostname, uint16(rpc_port), root_path, config)
try:
# check if we can connect to node, and if we can then validate
# fingerprint access, otherwise return fingerprint and shutdown client
@ -111,89 +122,90 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin
keychain_proxy: Optional[KeychainProxy] = None
all_keys: List[KeyData] = []
if fingerprint is not None:
selected_fingerprint = fingerprint
else:
keychain_proxy = await connect_to_keychain_and_validate(root_path, log=logging.getLogger(__name__))
if keychain_proxy is None:
raise RuntimeError("Failed to connect to keychain")
# we're only interested in the fingerprints and labels
all_keys = await keychain_proxy.get_keys(include_secrets=False)
# we don't immediately close the keychain proxy connection because it takes a noticeable amount of time
fingerprints = [key.fingerprint for key in all_keys]
if len(fingerprints) == 0:
print("No keys loaded. Run 'chia keys generate' or import a key")
elif len(fingerprints) == 1:
# if only a single key is available, select it automatically
selected_fingerprint = fingerprints[0]
try:
if fingerprint is not None:
selected_fingerprint = fingerprint
else:
keychain_proxy = await connect_to_keychain_and_validate(root_path, log=logging.getLogger(__name__))
if keychain_proxy is None:
raise RuntimeError("Failed to connect to keychain")
# we're only interested in the fingerprints and labels
all_keys = await keychain_proxy.get_keys(include_secrets=False)
# we don't immediately close the keychain proxy connection because it takes a noticeable amount of time
fingerprints = [key.fingerprint for key in all_keys]
if len(fingerprints) == 0:
print("No keys loaded. Run 'chia keys generate' or import a key")
elif len(fingerprints) == 1:
# if only a single key is available, select it automatically
selected_fingerprint = fingerprints[0]
if selected_fingerprint is None and len(all_keys) > 0:
logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint()
logged_in_key: Optional[KeyData] = None
if logged_in_fingerprint is not None:
logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None)
current_sync_status: str = ""
indent = " "
if logged_in_key is not None:
if await wallet_client.get_synced():
current_sync_status = "Synced"
elif await wallet_client.get_sync_status():
current_sync_status = "Syncing"
else:
current_sync_status = "Not Synced"
print()
print("Active Wallet Key (*):")
print(f"{indent}{'-Fingerprint:'.ljust(23)} {logged_in_key.fingerprint}")
if logged_in_key.label is not None:
print(f"{indent}{'-Label:'.ljust(23)} {logged_in_key.label}")
print(f"{indent}{'-Sync Status:'.ljust(23)} {current_sync_status}")
max_key_index_width = 5 # e.g. "12) *", "1) *", or "2) "
max_fingerprint_width = 10 # fingerprint is a 32-bit number
print()
print("Wallet Keys:")
for i, key in enumerate(all_keys):
key_index_str = f"{(str(i + 1) + ')'):<4}"
key_index_str += "*" if key.fingerprint == logged_in_fingerprint else " "
print(
f"{key_index_str:<{max_key_index_width}} "
f"{key.fingerprint:<{max_fingerprint_width}}"
f"{(indent + key.label) if key.label else ''}"
)
val = None
prompt: str = (
f"Choose a wallet key [1-{len(fingerprints)}] ('q' to quit, or Enter to use {logged_in_fingerprint}): "
)
while val is None:
val = input(prompt)
if val == "q":
break
elif val == "" and logged_in_fingerprint is not None:
fingerprint = logged_in_fingerprint
break
elif not val.isdigit():
val = None
else:
index = int(val) - 1
if index < 0 or index >= len(fingerprints):
print("Invalid value")
val = None
continue
if selected_fingerprint is None and len(all_keys) > 0:
logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint()
logged_in_key: Optional[KeyData] = None
if logged_in_fingerprint is not None:
logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None)
current_sync_status: str = ""
indent = " "
if logged_in_key is not None:
if await wallet_client.get_synced():
current_sync_status = "Synced"
elif await wallet_client.get_sync_status():
current_sync_status = "Syncing"
else:
fingerprint = fingerprints[index]
current_sync_status = "Not Synced"
selected_fingerprint = fingerprint
print()
print("Active Wallet Key (*):")
print(f"{indent}{'-Fingerprint:'.ljust(23)} {logged_in_key.fingerprint}")
if logged_in_key.label is not None:
print(f"{indent}{'-Label:'.ljust(23)} {logged_in_key.label}")
print(f"{indent}{'-Sync Status:'.ljust(23)} {current_sync_status}")
max_key_index_width = 5 # e.g. "12) *", "1) *", or "2) "
max_fingerprint_width = 10 # fingerprint is a 32-bit number
print()
print("Wallet Keys:")
for i, key in enumerate(all_keys):
key_index_str = f"{(str(i + 1) + ')'):<4}"
key_index_str += "*" if key.fingerprint == logged_in_fingerprint else " "
print(
f"{key_index_str:<{max_key_index_width}} "
f"{key.fingerprint:<{max_fingerprint_width}}"
f"{(indent + key.label) if key.label else ''}"
)
val = None
prompt: str = (
f"Choose a wallet key [1-{len(fingerprints)}] ('q' to quit, or Enter to use {logged_in_fingerprint}): "
)
while val is None:
val = input(prompt)
if val == "q":
break
elif val == "" and logged_in_fingerprint is not None:
fingerprint = logged_in_fingerprint
break
elif not val.isdigit():
val = None
else:
index = int(val) - 1
if index < 0 or index >= len(fingerprints):
print("Invalid value")
val = None
continue
else:
fingerprint = fingerprints[index]
if selected_fingerprint is not None:
log_in_response = await wallet_client.log_in(selected_fingerprint)
selected_fingerprint = fingerprint
if log_in_response["success"] is False:
print(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}")
selected_fingerprint = None
if selected_fingerprint is not None:
log_in_response = await wallet_client.log_in(selected_fingerprint)
# Closing the keychain proxy takes a moment, so we wait until after the login is complete
if keychain_proxy is not None:
await keychain_proxy.close()
if log_in_response["success"] is False:
print(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}")
selected_fingerprint = None
finally:
# Closing the keychain proxy takes a moment, so we wait until after the login is complete
if keychain_proxy is not None:
await keychain_proxy.close()
return selected_fingerprint
@ -204,8 +216,11 @@ async def execute_with_wallet(
extra_params: Dict[str, Any],
function: Callable[[Dict[str, Any], WalletRpcClient, int], Awaitable[None]],
) -> None:
wallet_client: Optional[WalletRpcClient]
async with get_any_service_client("wallet", wallet_rpc_port, fingerprint=fingerprint) as (wallet_client, _, new_fp):
async with get_any_service_client(WalletRpcClient, wallet_rpc_port, fingerprint=fingerprint) as (
wallet_client,
_,
new_fp,
):
if wallet_client is not None:
assert new_fp is not None # wallet only sanity check
await function(extra_params, wallet_client, new_fp)

View File

@ -24,7 +24,7 @@ def configure(
crawler_minimum_version_count: Optional[int],
seeder_domain_name: str,
seeder_nameserver: str,
):
) -> None:
config_yaml = "config.yaml"
with lock_and_load_config(root_path, config_yaml, fill_missing_services=True) as config:
config.update(load_defaults_for_missing_services(config=config, config_name=config_yaml))
@ -269,22 +269,22 @@ def configure(
)
@click.pass_context
def configure_cmd(
ctx,
set_farmer_peer,
set_node_introducer,
set_fullnode_port,
set_harvester_port,
set_log_level,
enable_upnp,
set_outbound_peer_count,
set_peer_count,
testnet,
set_peer_connect_timeout,
crawler_db_path,
crawler_minimum_version_count,
seeder_domain_name,
seeder_nameserver,
):
ctx: click.Context,
set_farmer_peer: str,
set_node_introducer: str,
set_fullnode_port: str,
set_harvester_port: str,
set_log_level: str,
enable_upnp: str,
set_outbound_peer_count: str,
set_peer_count: str,
testnet: str,
set_peer_connect_timeout: str,
crawler_db_path: str,
crawler_minimum_version_count: int,
seeder_domain_name: str,
seeder_nameserver: str,
) -> None:
configure(
ctx.obj["root_path"],
set_farmer_peer,

View File

@ -6,6 +6,7 @@ from typing import Dict, List, Optional
from chia.cmds.cmds_util import get_any_service_client
from chia.cmds.units import units
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint64
@ -13,7 +14,7 @@ from chia.util.ints import uint64
async def create_data_store_cmd(rpc_port: Optional[int], fee: Optional[str]) -> None:
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.create_data_store(fee=final_fee)
print(res)
@ -23,7 +24,7 @@ async def get_value_cmd(rpc_port: Optional[int], store_id: str, key: str, root_h
store_id_bytes = bytes32.from_hexstr(store_id)
key_bytes = hexstr_to_bytes(key)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_value(store_id=store_id_bytes, key=key_bytes, root_hash=root_hash_bytes)
print(res)
@ -37,7 +38,7 @@ async def update_data_store_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.update_data_store(store_id=store_id_bytes, changelist=changelist, fee=final_fee)
print(res)
@ -50,7 +51,7 @@ async def get_keys_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_keys(store_id=store_id_bytes, root_hash=root_hash_bytes)
print(res)
@ -63,7 +64,7 @@ async def get_keys_values_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_keys_values(store_id=store_id_bytes, root_hash=root_hash_bytes)
print(res)
@ -74,7 +75,7 @@ async def get_root_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_root(store_id=store_id_bytes)
print(res)
@ -86,7 +87,7 @@ async def subscribe_cmd(
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.subscribe(store_id=store_id_bytes, urls=urls)
print(res)
@ -97,7 +98,7 @@ async def unsubscribe_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.unsubscribe(store_id=store_id_bytes)
print(res)
@ -109,7 +110,7 @@ async def remove_subscriptions_cmd(
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.remove_subscriptions(store_id=store_id_bytes, urls=urls)
print(res)
@ -124,7 +125,7 @@ async def get_kv_diff_cmd(
store_id_bytes = bytes32.from_hexstr(store_id)
hash_1_bytes = bytes32.from_hexstr(hash_1)
hash_2_bytes = bytes32.from_hexstr(hash_2)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_kv_diff(store_id=store_id_bytes, hash_1=hash_1_bytes, hash_2=hash_2_bytes)
print(res)
@ -135,7 +136,7 @@ async def get_root_history_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_root_history(store_id=store_id_bytes)
print(res)
@ -144,7 +145,7 @@ async def get_root_history_cmd(
async def add_missing_files_cmd(
rpc_port: Optional[int], ids: Optional[List[str]], overwrite: bool, foldername: Optional[Path]
) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.add_missing_files(
store_ids=(None if ids is None else [bytes32.from_hexstr(id) for id in ids]),
@ -159,7 +160,7 @@ async def add_mirror_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.add_mirror(
store_id=store_id_bytes,
@ -173,7 +174,7 @@ async def add_mirror_cmd(
async def delete_mirror_cmd(rpc_port: Optional[int], coin_id: str, fee: Optional[str]) -> None:
coin_id_bytes = bytes32.from_hexstr(coin_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.delete_mirror(
coin_id=coin_id_bytes,
@ -184,21 +185,21 @@ async def delete_mirror_cmd(rpc_port: Optional[int], coin_id: str, fee: Optional
async def get_mirrors_cmd(rpc_port: Optional[int], store_id: str) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_mirrors(store_id=store_id_bytes)
print(res)
async def get_subscriptions_cmd(rpc_port: Optional[int]) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_subscriptions()
print(res)
async def get_owned_stores_cmd(rpc_port: Optional[int]) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_owned_stores()
print(res)
@ -209,7 +210,7 @@ async def get_sync_status_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, _, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_sync_status(store_id=store_id_bytes)
print(res)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
import click
@ -15,8 +16,8 @@ def db_cmd() -> None:
@db_cmd.command("upgrade", short_help="upgrade a v1 database to v2")
@click.option("--input", default=None, type=click.Path(), help="specify input database file")
@click.option("--output", default=None, type=click.Path(), help="specify output database file")
@click.option("--input", "in_db_path", default=None, type=click.Path(), help="specify input database file")
@click.option("--output", "out_db_path", default=None, type=click.Path(), help="specify output database file")
@click.option(
"--no-update-config",
default=False,
@ -31,11 +32,14 @@ def db_cmd() -> None:
help="force conversion despite warnings",
)
@click.pass_context
def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kwargs) -> None:
def db_upgrade_cmd(
ctx: click.Context,
in_db_path: Optional[str],
out_db_path: Optional[str],
no_update_config: bool,
force: bool,
) -> None:
try:
in_db_path = kwargs.get("input")
out_db_path = kwargs.get("output")
db_upgrade_func(
Path(ctx.obj["root_path"]),
None if in_db_path is None else Path(in_db_path),
@ -48,7 +52,7 @@ def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kw
@db_cmd.command("validate", short_help="validate the (v2) blockchain database. Does not verify proofs")
@click.option("--db", default=None, type=click.Path(), help="Specifies which database file to validate")
@click.option("--db", "in_db_path", default=None, type=click.Path(), help="Specifies which database file to validate")
@click.option(
"--validate-blocks",
default=False,
@ -56,9 +60,8 @@ def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kw
help="validate consistency of properties of the encoded blocks and block records",
)
@click.pass_context
def db_validate_cmd(ctx: click.Context, validate_blocks: bool, **kwargs) -> None:
def db_validate_cmd(ctx: click.Context, in_db_path: Optional[str], validate_blocks: bool) -> None:
try:
in_db_path = kwargs.get("db")
db_validate_func(
Path(ctx.obj["root_path"]),
None if in_db_path is None else Path(in_db_path),
@ -69,12 +72,11 @@ def db_validate_cmd(ctx: click.Context, validate_blocks: bool, **kwargs) -> None
@db_cmd.command("backup", short_help="backup the blockchain database using VACUUM INTO command")
@click.option("--backup_file", default=None, type=click.Path(), help="Specifies the backup file")
@click.option("--backup_file", "db_backup_file", default=None, type=click.Path(), help="Specifies the backup file")
@click.option("--no_indexes", default=False, is_flag=True, help="Create backup without indexes")
@click.pass_context
def db_backup_cmd(ctx: click.Context, no_indexes: bool, **kwargs) -> None:
def db_backup_cmd(ctx: click.Context, db_backup_file: Optional[str], no_indexes: bool) -> None:
try:
db_backup_file = kwargs.get("backup_file")
db_backup_func(
Path(ctx.obj["root_path"]),
None if db_backup_file is None else Path(db_backup_file),

View File

@ -7,7 +7,7 @@ import sys
import textwrap
from pathlib import Path
from time import time
from typing import Dict, Optional
from typing import Any, Dict, Optional
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.config import load_config, lock_and_load_config, save_config
@ -27,10 +27,9 @@ def db_upgrade_func(
no_update_config: bool = False,
force: bool = False,
) -> None:
update_config: bool = in_db_path is None and out_db_path is None and not no_update_config
config: Dict
config: Dict[str, Any]
selected_network: str
db_pattern: str
if in_db_path is None or out_db_path is None:
@ -81,7 +80,6 @@ def db_upgrade_func(
except RuntimeError as e:
print(f"conversion failed with error: {e}.")
except Exception as e:
print(
textwrap.dedent(
f"""\
@ -167,7 +165,7 @@ def convert_v1_to_v2(in_path: Path, out_path: Path) -> None:
"block_record blob)"
)
out_db.execute(
"CREATE TABLE sub_epoch_segments_v3(" "ses_block_hash blob PRIMARY KEY," "challenge_segments blob)"
"CREATE TABLE sub_epoch_segments_v3(ses_block_hash blob PRIMARY KEY, challenge_segments blob)"
)
out_db.execute("CREATE TABLE current_peak(key int PRIMARY KEY, hash blob)")
@ -202,10 +200,8 @@ def convert_v1_to_v2(in_path: Path, out_path: Path) -> None:
"SELECT header_hash, height, is_fully_compactified, block FROM full_blocks ORDER BY height DESC"
)
) as cursor_2:
out_db.execute("begin transaction")
for row in cursor:
header_hash = bytes.fromhex(row[0])
if header_hash != hh:
continue

View File

@ -41,7 +41,6 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
print(f"opening file for reading: {in_path}")
with closing(sqlite3.connect(in_path)) as in_db:
# read the database version
try:
with closing(in_db.execute("SELECT * FROM database_version")) as cursor:
@ -91,9 +90,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
"FROM full_blocks ORDER BY height DESC"
)
) as cursor:
for row in cursor:
hh = row[0]
prev = row[1]
height = row[2]
@ -111,7 +108,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
actual_prev_hash = block.prev_header_hash
if actual_header_hash != hh:
raise RuntimeError(
f"Block {hh.hex()} has a blob with mismatching " f"hash: {actual_header_hash.hex()}"
f"Block {hh.hex()} has a blob with mismatching hash: {actual_header_hash.hex()}"
)
if block_record.header_hash != hh:
raise RuntimeError(
@ -130,7 +127,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
)
if block.height != height:
raise RuntimeError(
f"Block {hh.hex()} has a mismatching " f"height: {block.height} expected {height}"
f"Block {hh.hex()} has a mismatching height: {block.height} expected {height}"
)
if height != current_height:
@ -146,7 +143,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
if hh == expect_hash:
if next_hash is not None:
raise RuntimeError(f"Database has multiple blocks with hash {hh.hex()}, " f"at height {height}")
raise RuntimeError(f"Database has multiple blocks with hash {hh.hex()}, at height {height}")
if not in_main_chain:
raise RuntimeError(
f"block {hh.hex()} (height: {height}) is part of the main chain, "
@ -168,9 +165,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
else:
if in_main_chain:
raise RuntimeError(
f"block {hh.hex()} (height: {height}) is orphaned, " "but in_main_chain is set"
)
raise RuntimeError(f"block {hh.hex()} (height: {height}) is orphaned, but in_main_chain is set")
num_orphans += 1
print("")

View File

@ -44,9 +44,7 @@ def farm_cmd() -> None:
@click.option(
"-fp",
"--farmer-rpc-port",
help=(
"Set the port where the Farmer is hosting the RPC interface. " "See the rpc_port under farmer in config.yaml"
),
help=("Set the port where the Farmer is hosting the RPC interface. See the rpc_port under farmer in config.yaml"),
type=int,
default=None,
show_default=True,

View File

@ -15,8 +15,7 @@ SECONDS_PER_BLOCK = (24 * 3600) / 4608
async def get_harvesters_summary(farmer_rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer", farmer_rpc_port) as node_config_fp:
async with get_any_service_client(FarmerRpcClient, farmer_rpc_port) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
return await farmer_client.get_harvesters_summary()
@ -24,8 +23,7 @@ async def get_harvesters_summary(farmer_rpc_port: Optional[int]) -> Optional[Dic
async def get_blockchain_state(rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
return await client.get_blockchain_state()
@ -33,8 +31,7 @@ async def get_blockchain_state(rpc_port: Optional[int]) -> Optional[Dict[str, An
async def get_average_block_time(rpc_port: Optional[int]) -> float:
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
blocks_to_compare = 500
@ -58,8 +55,7 @@ async def get_average_block_time(rpc_port: Optional[int]) -> float:
async def get_wallets_stats(wallet_rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
wallet_client: Optional[WalletRpcClient]
async with get_any_service_client("wallet", wallet_rpc_port, login_to_wallet=False) as node_config_fp:
async with get_any_service_client(WalletRpcClient, wallet_rpc_port, login_to_wallet=False) as node_config_fp:
wallet_client, _, _ = node_config_fp
if wallet_client is not None:
return await wallet_client.get_farmed_amount()
@ -67,8 +63,7 @@ async def get_wallets_stats(wallet_rpc_port: Optional[int]) -> Optional[Dict[str
async def get_challenges(farmer_rpc_port: Optional[int]) -> Optional[List[Dict[str, Any]]]:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer", farmer_rpc_port) as node_config_fp:
async with get_any_service_client(FarmerRpcClient, farmer_rpc_port) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
return await farmer_client.get_signage_points()

View File

@ -24,7 +24,14 @@ import click
help="Initialize the blockchain database in v1 format (compatible with older versions of the full node)",
)
@click.pass_context
def init_cmd(ctx: click.Context, create_certs: str, fix_ssl_permissions: bool, testnet: bool, v1_db: bool, **kwargs):
def init_cmd(
ctx: click.Context,
create_certs: str,
fix_ssl_permissions: bool,
testnet: bool,
set_passphrase: bool,
v1_db: bool,
) -> None:
"""
Create a new configuration or migrate from previous versions to current
@ -43,7 +50,6 @@ def init_cmd(ctx: click.Context, create_certs: str, fix_ssl_permissions: bool, t
from .init_funcs import init
set_passphrase = kwargs.get("set_passphrase")
if set_passphrase:
initialize_passphrase()

View File

@ -41,7 +41,7 @@ from chia.wallet.derive_keys import (
)
def dict_add_new_default(updated: Dict, default: Dict, do_not_migrate_keys: Dict[str, Any]):
def dict_add_new_default(updated: Dict[str, Any], default: Dict[str, Any], do_not_migrate_keys: Dict[str, Any]) -> None:
for k in do_not_migrate_keys:
if k in updated and do_not_migrate_keys[k] == "":
updated.pop(k)
@ -155,7 +155,7 @@ def check_keys(new_root: Path, keychain: Optional[Keychain] = None) -> None:
save_config(new_root, "config.yaml", config)
def copy_files_rec(old_path: Path, new_path: Path):
def copy_files_rec(old_path: Path, new_path: Path) -> None:
if old_path.is_file():
print(f"{new_path}")
new_path.parent.mkdir(parents=True, exist_ok=True)
@ -171,7 +171,7 @@ def migrate_from(
new_root: Path,
manifest: List[str],
do_not_migrate_settings: List[str],
):
) -> int:
"""
Copy all the files in "manifest" to the new config directory.
"""
@ -193,7 +193,7 @@ def migrate_from(
with lock_and_load_config(new_root, "config.yaml") as config:
config_str: str = initial_config_file("config.yaml")
default_config: Dict = yaml.safe_load(config_str)
default_config: Dict[str, Any] = yaml.safe_load(config_str)
flattened_keys = unflatten_properties({k: "" for k in do_not_migrate_settings})
dict_add_new_default(config, default_config, flattened_keys)
@ -204,7 +204,7 @@ def migrate_from(
return 1
def copy_cert_files(cert_path: Path, new_path: Path):
def copy_cert_files(cert_path: Path, new_path: Path) -> None:
for old_path_child in cert_path.glob("*.crt"):
new_path_child = new_path / old_path_child.name
copy_files_rec(old_path_child, new_path_child)
@ -222,7 +222,7 @@ def init(
fix_ssl_permissions: bool = False,
testnet: bool = False,
v1_db: bool = False,
):
) -> Optional[int]:
if create_certs is not None:
if root_path.exists():
if os.path.isdir(create_certs):
@ -255,6 +255,8 @@ def init(
else:
return chia_init(root_path, fix_ssl_permissions=fix_ssl_permissions, testnet=testnet, v1_db=v1_db)
return None
def chia_version_number() -> Tuple[str, str, str, str]:
scm_full_version = __version__
@ -316,7 +318,7 @@ def chia_init(
fix_ssl_permissions: bool = False,
testnet: bool = False,
v1_db: bool = False,
):
) -> int:
"""
Standard first run initialization or migration steps. Handles config creation,
generation of SSL certs, and setting target addresses (via check_keys).
@ -383,7 +385,7 @@ def chia_init(
if should_check_keys:
check_keys(root_path)
config: Dict
config: Dict[str, Any]
db_path_replaced: str
if v1_db:

View File

@ -246,7 +246,7 @@ def derive_sk_from_hd_path(master_sk: PrivateKey, hd_path_root: str) -> Tuple[Pr
current_sk: PrivateKey = master_sk
# Derive keys along the path
for (current_index, derivation_type) in index_and_derivation_types:
for current_index, derivation_type in index_and_derivation_types:
if derivation_type == DerivationType.NONOBSERVER:
current_sk = _derive_path(current_sk, [current_index])
elif derivation_type == DerivationType.OBSERVER:
@ -384,7 +384,7 @@ def _search_derived(
if len(found_items) > 0 and show_progress:
print()
for (term, found_item, found_item_type) in found_items:
for term, found_item, found_item_type in found_items:
# Update remaining_search_terms and found_search_terms
del remaining_search_terms[term]
found_search_terms.append(term)

View File

@ -12,8 +12,7 @@ async def netstorge_async(rpc_port: Optional[int], delta_block_height: str, star
"""
Calculates the estimated space on the network given two block header hashes.
"""
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
if delta_block_height:

View File

@ -22,7 +22,7 @@ from chia.cmds.peer_funcs import peer_async
@click.option(
"-c", "--connections", help="List connections to the specified service", is_flag=True, type=bool, default=False
)
@click.option("-a", "--add-connection", help="Connect to another remote Chia service by ip:port", type=str, default="")
@click.option("-a", "--add-connection", help="Connect specified Chia service to ip:port", type=str, default="")
@click.option(
"-r", "--remove-connection", help="Remove a Node by the first 8 characters of NodeID", type=str, default=""
)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
from chia.cmds.cmds_util import get_any_service_client
from chia.cmds.cmds_util import NODE_TYPES, get_any_service_client
from chia.rpc.rpc_client import RpcClient
@ -113,8 +113,8 @@ async def peer_async(
add_connection: str,
remove_connection: str,
) -> None:
rpc_client: Optional[RpcClient]
async with get_any_service_client(node_type, rpc_port, root_path) as node_config_fp:
client_type = NODE_TYPES[node_type]
async with get_any_service_client(client_type, rpc_port, root_path) as node_config_fp:
rpc_client, config, _ = node_config_fp
if rpc_client is not None:
# Check or edit node connections

View File

@ -173,8 +173,7 @@ async def pprint_pool_wallet_state(
async def show(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer") as node_config_fp:
async with get_any_service_client(FarmerRpcClient) as node_config_fp:
farmer_client, config, _ = node_config_fp
if farmer_client is not None:
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
@ -223,8 +222,7 @@ async def show(args: dict, wallet_client: WalletRpcClient, fingerprint: int) ->
async def get_login_link(launcher_id_str: str) -> None:
launcher_id: bytes32 = bytes32.from_hexstr(launcher_id_str)
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer") as node_config_fp:
async with get_any_service_client(FarmerRpcClient) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
login_link: Optional[str] = await farmer_client.get_pool_login_link(launcher_id)

View File

@ -181,7 +181,7 @@ async def print_fee_info(node_client: FullNodeRpcClient) -> None:
print("\nFee Rate Estimates:")
max_name_len = max(len(name) for name in target_times_names)
for (n, e) in zip(target_times_names, res["estimates"]):
for n, e in zip(target_times_names, res["estimates"]):
print(f" {n:>{max_name_len}}: {e:.3f} mojo per CLVM cost")
print("")
@ -196,8 +196,7 @@ async def show_async(
) -> None:
from chia.cmds.cmds_util import get_any_service_client
node_client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port, root_path) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port, root_path) as node_config_fp:
node_client, config, _ = node_config_fp
if node_client is not None:
# Check State
@ -210,7 +209,7 @@ async def show_async(
if block_header_hash_by_height != "":
block_header = await node_client.get_block_record_by_height(block_header_hash_by_height)
if block_header is not None:
print(f"Header hash of block {block_header_hash_by_height}: " f"{block_header.header_hash.hex()}")
print(f"Header hash of block {block_header_hash_by_height}: {block_header.header_hash.hex()}")
else:
print("Block height", block_header_hash_by_height, "not found")
if block_by_header_hash != "":

View File

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
import click
from chia.cmds.check_wallet_db import help_text as check_help_text
from chia.cmds.cmds_util import execute_with_wallet
from chia.cmds.coins import coins_cmd
from chia.cmds.plotnft import validate_fee
@ -515,6 +516,21 @@ def cancel_offer_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: str,
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, cancel_offer))
@wallet_cmd.command("check", short_help="Check wallet DB integrity", help=check_help_text)
@click.option("-v", "--verbose", help="Print more information", is_flag=True)
@click.option("--db-path", help="The path to a wallet DB. Default is to scan all active wallet DBs.")
@click.pass_context
# TODO: accept multiple dbs on commandline
# TODO: Convert to Path earlier
def check_wallet_cmd(ctx: click.Context, db_path: str, verbose: bool) -> None:
"""check, scan, diagnose, fsck Chia Wallet DBs"""
import asyncio
from chia.cmds.check_wallet_db import scan
asyncio.run(scan(ctx.obj["root_path"], db_path, verbose=verbose))
@wallet_cmd.group("did", short_help="DID related actions")
def did_cmd():
pass

View File

@ -767,7 +767,7 @@ async def print_balances(args: dict, wallet_client: WalletRpcClient, fingerprint
print()
print(f"{summary['name']}:")
print(f"{indent}{'-Total Balance:'.ljust(23)} {total_balance}")
print(f"{indent}{'-Pending Total Balance:'.ljust(23)} " f"{unconfirmed_wallet_balance}")
print(f"{indent}{'-Pending Total Balance:'.ljust(23)} {unconfirmed_wallet_balance}")
print(f"{indent}{'-Spendable:'.ljust(23)} {spendable_balance}")
print(f"{indent}{'-Type:'.ljust(23)} {typ.name}")
if typ == WalletType.DECENTRALIZED_ID:

View File

@ -319,6 +319,7 @@ async def validate_block_body(
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=False,
height=curr.height,
constants=constants,
)
removals_in_curr, additions_in_curr = tx_removals_and_additions(curr_npc_result.conds)
else:

View File

@ -388,7 +388,7 @@ def create_unfinished_block(
additions = []
if removals is None:
removals = []
(foliage, foliage_transaction_block, transactions_info,) = create_foliage(
(foliage, foliage_transaction_block, transactions_info) = create_foliage(
constants,
rc_block,
block_generator,

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from typing_extensions import Protocol
from chia.consensus.constants import ConsensusConstants
from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters
from chia.types.blockchain_format.classgroup import ClassgroupElement
@ -13,6 +15,32 @@ from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.streamable import Streamable, streamable
class BlockRecordProtocol(Protocol):
@property
def header_hash(self) -> bytes32:
...
@property
def height(self) -> uint32:
...
@property
def timestamp(self) -> Optional[uint64]:
...
@property
def prev_transaction_block_height(self) -> uint32:
...
@property
def prev_transaction_block_hash(self) -> Optional[bytes32]:
...
@property
def is_transaction_block(self) -> bool:
return self.timestamp is not None
@streamable
@dataclass(frozen=True)
class BlockRecord(Streamable):

View File

@ -428,7 +428,6 @@ class Blockchain(BlockchainInterface):
async def get_tx_removals_and_additions(
self, block: FullBlock, npc_result: Optional[NPCResult] = None
) -> Tuple[List[bytes32], List[Coin], Optional[NPCResult]]:
if not block.is_transaction_block():
return [], [], None

View File

@ -61,8 +61,12 @@ class ConsensusConstants:
MAX_GENERATOR_SIZE: uint32
MAX_GENERATOR_REF_LIST_SIZE: uint32
POOL_SUB_SLOT_ITERS: uint64
# soft fork initiated in 1.7.0 release
SOFT_FORK_HEIGHT: uint32
# soft fork initiated in 1.8.0 release
SOFT_FORK2_HEIGHT: uint32
def replace(self, **changes: object) -> "ConsensusConstants":
return dataclasses.replace(self, **changes)

View File

@ -56,6 +56,7 @@ default_kwargs = {
"MAX_GENERATOR_REF_LIST_SIZE": 512, # Number of references allowed in the block generator ref list
"POOL_SUB_SLOT_ITERS": 37600000000, # iters limit * NUM_SPS
"SOFT_FORK_HEIGHT": 3630000,
"SOFT_FORK2_HEIGHT": 3830000,
}

View File

@ -27,7 +27,6 @@ def block_to_block_record(
header_block: Optional[HeaderBlock],
sub_slot_iters: Optional[uint64] = None,
) -> BlockRecord:
if full_block is None:
assert header_block is not None
block: Union[HeaderBlock, FullBlock] = header_block
@ -99,7 +98,6 @@ def header_block_to_sub_block_record(
prev_transaction_block_height: uint32,
ses: Optional[SubEpochSummary],
) -> BlockRecord:
reward_claims_incorporated = (
block.transactions_info.reward_claims_incorporated if block.transactions_info is not None else None
)

View File

@ -94,6 +94,7 @@ def batch_pre_validate_blocks(
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=False,
height=block.height,
constants=constants,
)
removals, tx_additions = tx_removals_and_additions(npc_result.conds)
if npc_result is not None and npc_result.error is not None:
@ -116,7 +117,6 @@ def batch_pre_validate_blocks(
successfully_validated_signatures = False
# If we failed CLVM, no need to validate signature, the block is already invalid
if error_int is None:
# If this is False, it means either we don't have a signature (not a tx block) or we have an invalid
# signature (which also puts in an error) or we didn't validate the signature because we want to
# validate it later. receive_block will attempt to validate the signature later.

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
@ -95,23 +95,22 @@ class DeleteLabelRequest(Streamable):
return EmptyResponse()
@dataclass
class KeychainServer:
"""
Implements a remote keychain service for clients to perform key operations on
"""
def __init__(self):
self._default_keychain = Keychain()
self._alt_keychains = {}
_default_keychain: Keychain = field(default_factory=Keychain)
_alt_keychains: Dict[str, Keychain] = field(default_factory=dict)
def get_keychain_for_request(self, request: Dict[str, Any]):
def get_keychain_for_request(self, request: Dict[str, Any]) -> Keychain:
"""
Keychain instances can have user and service strings associated with them.
The keychain backends ultimately point to the same data stores, but the user
and service strings are used to partition those data stores. We attempt to
maintain a mapping of user/service pairs to their corresponding Keychain.
"""
keychain = None
user = request.get("kc_user", self._default_keychain.user)
service = request.get("kc_service", self._default_keychain.service)
if user == self._default_keychain.user and service == self._default_keychain.service:

View File

@ -13,9 +13,10 @@ import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO, Tuple
from typing import Any, AsyncIterator, Dict, List, Optional, TextIO, Tuple
from chia import __version__
from chia.cmds.init_funcs import check_keys, chia_full_version_str, chia_init
@ -128,7 +129,6 @@ class WebSocketServer:
ca_key_path: Path,
crt_path: Path,
key_path: Path,
shutdown_event: asyncio.Event,
run_check_keys_on_unlock: bool = False,
):
self.root_path = root_path
@ -147,9 +147,10 @@ class WebSocketServer:
self.ssl_context = ssl_context_for_server(ca_crt_path, ca_key_path, crt_path, key_path, log=self.log)
self.keychain_server = KeychainServer()
self.run_check_keys_on_unlock = run_check_keys_on_unlock
self.shutdown_event = shutdown_event
self.shutdown_event = asyncio.Event()
async def start(self) -> None:
@asynccontextmanager
async def run(self) -> AsyncIterator[None]:
self.log.info("Starting Daemon Server")
# Note: the minimum_version has been already set to TLSv1_2
@ -180,6 +181,12 @@ class WebSocketServer:
ssl_context=self.ssl_context,
logger=self.log,
)
try:
yield
finally:
if not self.shutdown_event.is_set():
await self.stop()
await self.exit()
async def setup_process_global_state(self) -> None:
try:
@ -213,7 +220,7 @@ class WebSocketServer:
if stop_service_jobs:
await asyncio.wait(stop_service_jobs)
self.services.clear()
asyncio.create_task(self.exit())
self.shutdown_event.set()
log.info(f"Daemon Server stopping, Services stopped: {service_names}")
return {"success": True, "services_stopped": service_names}
@ -1123,7 +1130,6 @@ class WebSocketServer:
if self.webserver is not None:
self.webserver.close()
await self.webserver.await_closed()
self.shutdown_event.set()
log.info("chia daemon exiting")
async def register_service(self, websocket: WebSocketResponse, request: Dict[str, Any]) -> Dict[str, Any]:
@ -1350,20 +1356,17 @@ async def async_run_daemon(root_path: Path, wait_for_unlock: bool = False) -> in
beta_metrics = BetaMetricsLogger(root_path)
beta_metrics.start_logging()
shutdown_event = asyncio.Event()
ws_server = WebSocketServer(
root_path,
ca_crt_path,
ca_key_path,
crt_path,
key_path,
shutdown_event,
run_check_keys_on_unlock=wait_for_unlock,
)
await ws_server.setup_process_global_state()
await ws_server.start()
await shutdown_event.wait()
async with ws_server.run():
await ws_server.shutdown_event.wait()
if beta_metrics is not None:
await beta_metrics.stop_logging()

View File

@ -117,7 +117,6 @@ class DataLayerServer:
async def async_start(root_path: Path) -> int:
shutdown_event = asyncio.Event()
dl_config = load_config(

View File

@ -78,7 +78,7 @@ async def _dot_dump(data_store: DataStore, store_id: bytes32, root_hash: bytes32
dot_connections.append(f"""node_{hash} -> node_{left} [label="L"];""")
dot_connections.append(f"""node_{hash} -> node_{right} [label="R"];""")
dot_pair_boxes.append(
f"node [shape = box]; " f"{{rank = same; node_{left}->node_{right}[style=invis]; rankdir = LR}}"
f"node [shape = box]; {{rank = same; node_{left}->node_{right}[style=invis]; rankdir = LR}}"
)
lines = [

View File

@ -31,7 +31,7 @@ from chia.protocols.pool_protocol import (
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.server.outbound_message import NodeType, make_msg
from chia.server.server import ssl_context_for_root
from chia.server.server import ChiaServer, ssl_context_for_root
from chia.server.ws_connection import WSChiaConnection
from chia.ssl.create_ssl import get_mozilla_ca_crt
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
@ -41,7 +41,7 @@ from chia.util.byte_types import hexstr_to_bytes
from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config
from chia.util.errors import KeychainProxyConnectionFailure
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.ints import uint8, uint16, uint64
from chia.util.keychain import Keychain
from chia.util.logging import TimedDuplicateFilter
from chia.wallet.derive_keys import (
@ -70,8 +70,8 @@ class Farmer:
def __init__(
self,
root_path: Path,
farmer_config: Dict,
pool_config: Dict,
farmer_config: Dict[str, Any],
pool_config: Dict[str, Any],
consensus_constants: ConsensusConstants,
local_keychain: Optional[Keychain] = None,
):
@ -98,8 +98,8 @@ class Farmer:
self.plot_sync_receivers: Dict[bytes32, Receiver] = {}
self.cache_clear_task: Optional[asyncio.Task] = None
self.update_pool_state_task: Optional[asyncio.Task] = None
self.cache_clear_task: Optional[asyncio.Task[None]] = None
self.update_pool_state_task: Optional[asyncio.Task[None]] = None
self.constants = consensus_constants
self._shut_down = False
self.server: Any = None
@ -109,16 +109,18 @@ class Farmer:
self.log.addFilter(TimedDuplicateFilter("No pool specific difficulty has been set.*", 60 * 10))
self.started = False
self.harvester_handshake_task: Optional[asyncio.Task] = None
self.harvester_handshake_task: Optional[asyncio.Task[None]] = None
# From p2_singleton_puzzle_hash to pool state dict
self.pool_state: Dict[bytes32, Dict] = {}
self.pool_state: Dict[bytes32, Dict[str, Any]] = {}
# From p2_singleton to auth PrivateKey
self.authentication_keys: Dict[bytes32, PrivateKey] = {}
# Last time we updated pool_state based on the config file
self.last_config_access_time: uint64 = uint64(0)
self.last_config_access_time: float = 0
self.all_root_sks: List[PrivateKey] = []
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
return default_get_connections(server=self.server, request_node_type=request_node_type)
@ -133,14 +135,14 @@ class Farmer:
raise KeychainProxyConnectionFailure()
return self.keychain_proxy
async def get_all_private_keys(self):
async def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]:
keychain_proxy = await self.ensure_keychain_proxy()
return await keychain_proxy.get_all_private_keys()
async def setup_keys(self) -> bool:
no_keys_error_str = "No keys exist. Please run 'chia keys generate' or open the UI."
try:
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
self.all_root_sks = [sk for sk, _ in await self.get_all_private_keys()]
except KeychainProxyConnectionFailure:
return False
@ -170,9 +172,7 @@ class Farmer:
# This is the self pooling configuration, which is only used for original self-pooled plots
self.pool_target_encoded = self.pool_config["xch_target_address"]
self.pool_target = decode_puzzle_hash(self.pool_target_encoded)
self.pool_sks_map: Dict = {}
for key in self.get_private_keys():
self.pool_sks_map[bytes(key.get_g1())] = key
self.pool_sks_map = {bytes(key.get_g1()): key for key in self.get_private_keys()}
assert len(self.farmer_target) == 32
assert len(self.pool_target) == 32
@ -182,8 +182,8 @@ class Farmer:
return True
async def _start(self):
async def start_task():
async def _start(self) -> None:
async def start_task() -> None:
# `Farmer.setup_keys` returns `False` if there are no keys setup yet. In this case we just try until it
# succeeds or until we need to shut down.
while not self._shut_down:
@ -197,10 +197,10 @@ class Farmer:
asyncio.create_task(start_task())
def _close(self):
def _close(self) -> None:
self._shut_down = True
async def _await_closed(self, shutting_down: bool = True):
async def _await_closed(self, shutting_down: bool = True) -> None:
if self.cache_clear_task is not None:
await self.cache_clear_task
if self.update_pool_state_task is not None:
@ -215,10 +215,10 @@ class Farmer:
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback
async def on_connect(self, peer: WSChiaConnection):
async def on_connect(self, peer: WSChiaConnection) -> None:
self.state_changed("add_connection", {})
async def handshake_task():
async def handshake_task() -> None:
# Wait until the task in `Farmer._start` is done so that we have keys available for the handshake. Bail out
# early if we need to shut down or if the harvester is not longer connected.
while not self.started and not self._shut_down and peer in self.server.get_connections():
@ -247,20 +247,20 @@ class Farmer:
self.plot_sync_receivers[peer.peer_node_id] = Receiver(peer, self.plot_sync_callback)
self.harvester_handshake_task = asyncio.create_task(handshake_task())
def set_server(self, server):
def set_server(self, server: ChiaServer) -> None:
self.server = server
def state_changed(self, change: str, data: Dict[str, Any]):
def state_changed(self, change: str, data: Dict[str, Any]) -> None:
if self.state_changed_callback is not None:
self.state_changed_callback(change, data)
def handle_failed_pool_response(self, p2_singleton_puzzle_hash: bytes32, error_message: str):
def handle_failed_pool_response(self, p2_singleton_puzzle_hash: bytes32, error_message: str) -> None:
self.log.error(error_message)
self.pool_state[p2_singleton_puzzle_hash]["pool_errors_24h"].append(
ErrorResponse(uint16(PoolErrorCode.REQUEST_FAILED.value), error_message).to_json_dict()
)
def on_disconnect(self, connection: WSChiaConnection):
def on_disconnect(self, connection: WSChiaConnection) -> None:
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self.state_changed("close_connection", {})
if connection.connection_type is NodeType.HARVESTER:
@ -274,14 +274,14 @@ class Farmer:
if receiver.initial_sync() or harvester_updated:
self.state_changed("harvester_update", receiver.to_dict(True))
async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict]:
async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict[str, Any]]:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
f"{pool_config.pool_url}/pool_info", ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log)
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
self.log.info(f"GET /pool_info response: {response}")
return response
else:
@ -299,7 +299,7 @@ class Farmer:
async def _pool_get_farmer(
self, pool_config: PoolWalletConfig, authentication_token_timeout: uint8, authentication_sk: PrivateKey
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
authentication_token = get_current_authentication_token(authentication_token_timeout)
message: bytes32 = std_hash(
AuthenticationPayload(
@ -320,7 +320,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -340,7 +340,7 @@ class Farmer:
async def _pool_post_farmer(
self, pool_config: PoolWalletConfig, authentication_token_timeout: uint8, owner_sk: PrivateKey
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
auth_sk: Optional[PrivateKey] = self.get_authentication_sk(pool_config)
assert auth_sk is not None
post_farmer_payload: PostFarmerPayload = PostFarmerPayload(
@ -362,7 +362,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -404,7 +404,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -428,7 +428,7 @@ class Farmer:
self.authentication_keys[pool_config.p2_singleton_puzzle_hash] = auth_sk
return auth_sk
async def update_pool_state(self):
async def update_pool_state(self) -> None:
config = load_config(self._root_path, "config.yaml")
pool_config_list: List[PoolWalletConfig] = load_pool_config(self._root_path)
@ -514,9 +514,7 @@ class Farmer:
farmer_info, error_code = await update_pool_farmer_info()
if error_code == PoolErrorCode.FARMER_NOT_KNOWN:
# Make the farmer known on the pool with a POST /farmer
owner_sk_and_index: Optional[Tuple[PrivateKey, uint32]] = find_owner_sk(
self.all_root_sks, pool_config.owner_public_key
)
owner_sk_and_index = find_owner_sk(self.all_root_sks, pool_config.owner_public_key)
assert owner_sk_and_index is not None
post_response = await self._pool_post_farmer(
pool_config, authentication_token_timeout, owner_sk_and_index[0]
@ -538,9 +536,7 @@ class Farmer:
and pool_config.payout_instructions.lower() != farmer_info.payout_instructions.lower()
)
if payout_instructions_update_required or error_code == PoolErrorCode.INVALID_SIGNATURE:
owner_sk_and_index: Optional[Tuple[PrivateKey, uint32]] = find_owner_sk(
self.all_root_sks, pool_config.owner_public_key
)
owner_sk_and_index = find_owner_sk(self.all_root_sks, pool_config.owner_public_key)
assert owner_sk_and_index is not None
await self._pool_put_farmer(
pool_config, authentication_token_timeout, owner_sk_and_index[0]
@ -555,13 +551,13 @@ class Farmer:
tb = traceback.format_exc()
self.log.error(f"Exception in update_pool_state for {pool_config.pool_url}, {e} {tb}")
def get_public_keys(self):
def get_public_keys(self) -> List[G1Element]:
return [child_sk.get_g1() for child_sk in self._private_keys]
def get_private_keys(self):
def get_private_keys(self) -> List[PrivateKey]:
return self._private_keys
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict:
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict[str, Any]:
if search_for_private_key:
all_sks = await self.get_all_private_keys()
have_farmer_sk, have_pool_sk = False, False
@ -591,7 +587,7 @@ class Farmer:
"pool_target": self.pool_target_encoded,
}
def set_reward_targets(self, farmer_target_encoded: Optional[str], pool_target_encoded: Optional[str]):
def set_reward_targets(self, farmer_target_encoded: Optional[str], pool_target_encoded: Optional[str]) -> None:
with lock_and_load_config(self._root_path, "config.yaml") as config:
if farmer_target_encoded is not None:
self.farmer_target_encoded = farmer_target_encoded
@ -603,7 +599,7 @@ class Farmer:
config["pool"]["xch_target_address"] = pool_target_encoded
save_config(self._root_path, "config.yaml", config)
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str):
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str) -> None:
for p2_singleton_puzzle_hash, pool_state_dict in self.pool_state.items():
if launcher_id == pool_state_dict["pool_config"].launcher_id:
with lock_and_load_config(self._root_path, "config.yaml") as config:
@ -627,7 +623,6 @@ class Farmer:
for pool_state in self.pool_state.values():
pool_config: PoolWalletConfig = pool_state["pool_config"]
if pool_config.launcher_id == launcher_id:
authentication_sk: Optional[PrivateKey] = self.get_authentication_sk(pool_config)
if authentication_sk is None:
self.log.error(f"Could not find authentication sk for {pool_config.p2_singleton_puzzle_hash}")
@ -655,8 +650,8 @@ class Farmer:
return None
async def get_harvesters(self, counts_only: bool = False) -> Dict:
harvesters: List = []
async def get_harvesters(self, counts_only: bool = False) -> Dict[str, Any]:
harvesters: List[Dict[str, Any]] = []
for connection in self.server.get_connections(NodeType.HARVESTER):
self.log.debug(f"get_harvesters host: {connection.peer_host}, node_id: {connection.peer_node_id}")
receiver = self.plot_sync_receivers.get(connection.peer_node_id)
@ -675,26 +670,26 @@ class Farmer:
raise KeyError(f"Receiver missing for {node_id}")
return receiver
async def _periodically_update_pool_state_task(self):
time_slept: uint64 = uint64(0)
async def _periodically_update_pool_state_task(self) -> None:
time_slept = 0
config_path: Path = config_path_for_filename(self._root_path, "config.yaml")
while not self._shut_down:
# Every time the config file changes, read it to check the pool state
stat_info = config_path.stat()
if stat_info.st_mtime > self.last_config_access_time:
# If we detect the config file changed, refresh private keys first just in case
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
self.all_root_sks = [sk for sk, _ in await self.get_all_private_keys()]
self.last_config_access_time = stat_info.st_mtime
await self.update_pool_state()
time_slept = uint64(0)
time_slept = 0
elif time_slept > 60:
await self.update_pool_state()
time_slept = uint64(0)
time_slept = 0
time_slept += 1
await asyncio.sleep(1)
async def _periodically_clear_cache_and_refresh_task(self):
time_slept: uint64 = uint64(0)
async def _periodically_clear_cache_and_refresh_task(self) -> None:
time_slept = 0
refresh_slept = 0
while not self._shut_down:
try:
@ -710,7 +705,7 @@ class Farmer:
removed_keys.append(key)
for key in removed_keys:
self.cache_add_time.pop(key, None)
time_slept = uint64(0)
time_slept = 0
log.debug(
f"Cleared farmer cache. Num sps: {len(self.sps)} {len(self.proofs_of_space)} "
f"{len(self.quality_str_to_identifiers)} {len(self.number_of_responses)}"

View File

@ -1,13 +1,12 @@
from __future__ import annotations
from chia.full_node.fee_estimate_store import FeeStore
from chia.full_node.fee_estimation import EmptyFeeMempoolInfo, FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import EmptyFeeMempoolInfo, FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator import SmartFeeEstimator
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.fee_tracker import FeeTracker
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32, uint64
@ -35,11 +34,11 @@ class BitcoinFeeEstimator(FeeEstimatorInterface):
self.block_height = block_info.block_height
self.tracker.process_block(block_info.block_height, block_info.included_items)
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
self.last_mempool_info = mempool_info
self.tracker.add_tx(mempool_item)
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
self.last_mempool_info = mempool_info
self.tracker.remove_tx(mempool_item)

View File

@ -157,7 +157,6 @@ class BlockHeightMap:
# time until we hit a match in the existing map, at which point we can
# assume all previous blocks have already been populated
async def _load_blocks_from(self, height: uint32, prev_hash: bytes32) -> None:
while height > 0:
# load 5000 blocks at a time
window_end = max(0, height - 5000)
@ -175,7 +174,6 @@ class BlockHeightMap:
async with self.db.reader_no_transaction() as conn:
async with conn.execute(query, (window_end, height)) as cursor:
# maps block-hash -> (height, prev-hash, sub-epoch-summary)
ordered: Dict[bytes32, Tuple[uint32, bytes32, Optional[bytes]]] = {}
@ -195,7 +193,6 @@ class BlockHeightMap:
assert height == entry[0] + 1
height = entry[0]
if entry[2] is not None:
if (
self.get_hash(height) == prev_hash
and height in self.__sub_epoch_summaries

View File

@ -34,10 +34,8 @@ class BlockStore:
self = cls(LRUCache(1000), db_wrapper, LRUCache(50))
async with self.db_wrapper.writer_maybe_transaction() as conn:
log.info("DB: Creating block store tables and indexes.")
if self.db_wrapper.db_version == 2:
# TODO: most data in block is duplicated in block_record. The only
# reason for this is that our parsing of a FullBlock is so slow,
# it's faster to store duplicate data to parse less when we just
@ -84,7 +82,6 @@ class BlockStore:
)
else:
await conn.execute(
"CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
" is_block tinyint, is_fully_compactified tinyint, block blob)"
@ -168,7 +165,6 @@ class BlockStore:
raise RuntimeError(f"The blockchain database is corrupt. All of {header_hashes} should exist")
async def replace_proof(self, header_hash: bytes32, block: FullBlock) -> None:
assert header_hash == block.header_hash
block_bytes: bytes
@ -193,7 +189,6 @@ class BlockStore:
self.block_cache.put(header_hash, block)
if self.db_wrapper.db_version == 2:
ses: Optional[bytes] = (
None
if block_record.sub_epoch_summary_included is None
@ -331,7 +326,6 @@ class BlockStore:
return ret
async def get_block_info(self, header_hash: bytes32) -> Optional[GeneratorBlockInfo]:
cached = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
@ -362,7 +356,6 @@ class BlockStore:
)
async def get_generator(self, header_hash: bytes32) -> Optional[SerializedProgram]:
cached = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
@ -521,9 +514,7 @@ class BlockStore:
return ret
async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT block_record FROM full_blocks WHERE header_hash=?",
@ -556,7 +547,6 @@ class BlockStore:
ret: Dict[bytes32, BlockRecord] = {}
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND height <= ?",
@ -567,7 +557,6 @@ class BlockStore:
ret[header_hash] = BlockRecord.from_bytes(row[1])
else:
formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"
async with self.db_wrapper.reader_no_transaction() as conn:
@ -601,7 +590,6 @@ class BlockStore:
return [maybe_decompress_blob(row[0]) for row in rows]
async def get_peak(self) -> Optional[Tuple[bytes32, uint32]]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute("SELECT hash FROM current_peak WHERE key = 0") as cursor:
@ -636,7 +624,6 @@ class BlockStore:
ret: Dict[bytes32, BlockRecord] = {}
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ?",
@ -683,7 +670,6 @@ class BlockStore:
return bool(row[0])
async def get_random_not_compactified(self, number: int) -> List[int]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(

View File

@ -36,10 +36,8 @@ class CoinStore:
self = CoinStore(db_wrapper, LRUCache(100))
async with self.db_wrapper.writer_maybe_transaction() as conn:
log.info("DB: Creating coin store tables and indexes.")
if self.db_wrapper.db_version == 2:
# the coin_name is unique in this table because the CoinStore always
# only represent a single peak
await conn.execute(
@ -55,7 +53,6 @@ class CoinStore:
)
else:
# the coin_name is unique in this table because the CoinStore always
# only represent a single peak
await conn.execute(
@ -273,7 +270,6 @@ class CoinStore:
start_height: uint32 = uint32(0),
end_height: uint32 = uint32((2**32) - 1),
) -> List[CoinRecord]:
coins = set()
async with self.db_wrapper.reader_no_transaction() as conn:
@ -284,7 +280,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
(self.maybe_to_hex(puzzle_hash), start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -316,7 +311,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
puzzle_hashes_db + (start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -348,7 +342,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
names_db + (start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -427,7 +420,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
parent_ids_db + (start_height, end_height),
) as cursor:
async for row in cursor:
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -513,7 +505,6 @@ class CoinStore:
# Store CoinRecord in DB
async def _add_coin_records(self, records: List[CoinRecord]) -> None:
if self.db_wrapper.db_version == 2:
values2 = []
for record in records:
@ -560,7 +551,6 @@ class CoinStore:
# Update coin_record to be spent in DB
async def _set_spent(self, coin_names: List[bytes32], index: uint32) -> None:
assert len(coin_names) == 0 or index > 0
if len(coin_names) == 0:

View File

@ -6,11 +6,26 @@ from typing import List
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRate
from chia.types.mempool_item import MempoolItem
from chia.types.mojos import Mojos
from chia.util.ints import uint32, uint64
@dataclass(frozen=True)
class MempoolItemInfo:
"""
The information the fee estimator is passed for each mempool item that's
added, removed from the mempool and included in blocks
"""
cost: int
fee: int
height_added_to_mempool: uint32
@property
def fee_per_cost(self) -> float:
return self.fee / self.cost
@dataclass(frozen=True)
class MempoolInfo:
"""
@ -75,4 +90,4 @@ class FeeBlockInfo: # See BlockRecord
"""
block_height: uint32
included_items: List[MempoolItem]
included_items: List[MempoolItemInfo]

View File

@ -3,11 +3,10 @@ from __future__ import annotations
from typing import Any, Dict, List
from chia.full_node.fee_estimate import FeeEstimateV2
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint64
MIN_MOJO_PER_COST = 5
@ -31,10 +30,10 @@ class FeeEstimatorExample(FeeEstimatorInterface):
def new_block(self, block_info: FeeBlockInfo) -> None:
pass
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
pass
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
pass
def estimate_fee_rate(self, *, time_offset_seconds: int) -> FeeRateV2:

View File

@ -2,10 +2,9 @@ from __future__ import annotations
from typing_extensions import Protocol
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32
@ -18,11 +17,11 @@ class FeeEstimatorInterface(Protocol):
"""A new transaction block has been added to the blockchain"""
pass
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
"""A MempoolItem (transaction and associated info) has been added to the mempool"""
pass
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
"""A MempoolItem (transaction and associated info) has been removed from the mempool"""
pass

View File

@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple
from chia.full_node.fee_estimate_store import FeeStore
from chia.full_node.fee_estimation import MempoolItemInfo
from chia.full_node.fee_estimator_constants import (
FEE_ESTIMATOR_VERSION,
INFINITE_FEE_RATE,
@ -26,7 +27,6 @@ from chia.full_node.fee_estimator_constants import (
SUFFICIENT_FEE_TXS,
)
from chia.full_node.fee_history import FeeStatBackup, FeeTrackerBackup
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint8, uint32, uint64
@ -129,7 +129,7 @@ class FeeStat: # TxConfirmStats
self.old_unconfirmed_txs = [0 for _ in range(0, len(buckets))]
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItem) -> None:
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItemInfo) -> None:
if blocks_to_confirm < 1:
raise ValueError("tx_confirmed called with < 1 block to confirm")
@ -164,7 +164,7 @@ class FeeStat: # TxConfirmStats
self.unconfirmed_txs[block_index][bucket_index] += 1
return bucket_index
def remove_tx(self, latest_seen_height: uint32, item: MempoolItem, bucket_index: int) -> None:
def remove_tx(self, latest_seen_height: uint32, item: MempoolItemInfo, bucket_index: int) -> None:
if item.height_added_to_mempool is None:
return
block_ago = latest_seen_height - item.height_added_to_mempool
@ -475,7 +475,7 @@ class FeeTracker:
)
self.fee_store.store_fee_data(backup)
def process_block(self, block_height: uint32, items: List[MempoolItem]) -> None:
def process_block(self, block_height: uint32, items: List[MempoolItemInfo]) -> None:
"""A new block has been farmed and these transactions have been included in that block"""
if block_height <= self.latest_seen_height:
# Ignore reorgs
@ -498,7 +498,7 @@ class FeeTracker:
self.first_recorded_height = block_height
self.log.info(f"Fee Estimator first recorded height: {self.first_recorded_height}")
def process_block_tx(self, current_height: uint32, item: MempoolItem) -> None:
def process_block_tx(self, current_height: uint32, item: MempoolItemInfo) -> None:
if item.height_added_to_mempool is None:
raise ValueError("process_block_tx called with item.height_added_to_mempool=None")
@ -510,8 +510,7 @@ class FeeTracker:
self.med_horizon.tx_confirmed(blocks_to_confirm, item)
self.long_horizon.tx_confirmed(blocks_to_confirm, item)
def add_tx(self, item: MempoolItem) -> None:
def add_tx(self, item: MempoolItemInfo) -> None:
if item.height_added_to_mempool < self.latest_seen_height:
self.log.info(f"Processing Item from pending pool: cost={item.cost} fee={item.fee}")
@ -522,7 +521,7 @@ class FeeTracker:
self.med_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
self.long_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
def remove_tx(self, item: MempoolItem) -> None:
def remove_tx(self, item: MempoolItemInfo) -> None:
bucket_index = get_bucket_index(self.buckets, item.fee_per_cost * 1000)
self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index)

View File

@ -115,7 +115,7 @@ class FullNode:
_transaction_queue: Optional[TransactionQueue]
_compact_vdf_sem: Optional[LimitedSemaphore]
_new_peak_sem: Optional[LimitedSemaphore]
_respond_transaction_semaphore: Optional[asyncio.Semaphore]
_add_transaction_semaphore: Optional[asyncio.Semaphore]
_db_wrapper: Optional[DBWrapper2]
_hint_store: Optional[HintStore]
transaction_responses: List[Tuple[bytes32, MempoolInclusionStatus, Optional[Err]]]
@ -180,7 +180,7 @@ class FullNode:
self._transaction_queue = None
self._compact_vdf_sem = None
self._new_peak_sem = None
self._respond_transaction_semaphore = None
self._add_transaction_semaphore = None
self._db_wrapper = None
self._hint_store = None
self.transaction_responses = []
@ -231,9 +231,9 @@ class FullNode:
return self._coin_store
@property
def respond_transaction_semaphore(self) -> asyncio.Semaphore:
assert self._respond_transaction_semaphore is not None
return self._respond_transaction_semaphore
def add_transaction_semaphore(self) -> asyncio.Semaphore:
assert self._add_transaction_semaphore is not None
return self._add_transaction_semaphore
@property
def transaction_queue(self) -> TransactionQueue:
@ -308,7 +308,7 @@ class FullNode:
self._new_peak_sem = LimitedSemaphore.create(active_limit=2, waiting_limit=20)
# These many respond_transaction tasks can be active at any point in time
self._respond_transaction_semaphore = asyncio.Semaphore(200)
self._add_transaction_semaphore = asyncio.Semaphore(200)
sql_log_path: Optional[Path] = None
if self.config.get("log_sqlite_cmds", False):
@ -442,7 +442,7 @@ class FullNode:
async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None:
peer = entry.peer
try:
inc_status, err = await self.respond_transaction(entry.transaction, entry.spend_name, peer, entry.test)
inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test)
self.transaction_responses.append((entry.spend_name, inc_status, err))
if len(self.transaction_responses) > 50:
self.transaction_responses = self.transaction_responses[1:]
@ -455,14 +455,14 @@ class FullNode:
if peer is not None:
await peer.close()
finally:
self.respond_transaction_semaphore.release()
self.add_transaction_semaphore.release()
async def _handle_transactions(self) -> None:
try:
while not self._shut_down:
# We use a semaphore to make sure we don't send more than 200 concurrent calls of respond_transaction.
# However, doing them one at a time would be slow, because they get sent to other processes.
await self.respond_transaction_semaphore.acquire()
await self.add_transaction_semaphore.acquire()
item: TransactionQueueEntry = await self.transaction_queue.pop()
asyncio.create_task(self._handle_one_transaction(item))
except asyncio.CancelledError:
@ -578,7 +578,7 @@ class FullNode:
raise ValueError(f"Error short batch syncing, invalid/no response for {height}-{end_height}")
async with self._blockchain_lock_high_priority:
state_change_summary: Optional[StateChangeSummary]
success, state_change_summary = await self.receive_block_batch(response.blocks, peer, None)
success, state_change_summary = await self.add_block_batch(response.blocks, peer, None)
if not success:
raise ValueError(f"Error short batch syncing, failed to validate blocks {height}-{end_height}")
if state_change_summary is not None:
@ -629,7 +629,7 @@ class FullNode:
unfinished_block: Optional[UnfinishedBlock] = self.full_node_store.get_unfinished_block(target_unf_hash)
curr_height: int = target_height
found_fork_point = False
responses = []
blocks = []
while curr_height > peak_height - 5:
# If we already have the unfinished block, don't fetch the transactions. In the normal case, we will
# already have the unfinished block, from when it was broadcast, so we just need to download the header,
@ -644,14 +644,14 @@ class FullNode:
raise ValueError(
f"Failed to fetch block {curr_height} from {peer.get_peer_logging()}, wrong type {type(curr)}"
)
responses.append(curr)
blocks.append(curr.block)
if self.blockchain.contains_block(curr.block.prev_header_hash) or curr_height == 0:
found_fork_point = True
break
curr_height -= 1
if found_fork_point:
for response in reversed(responses):
await self.respond_block(response, peer)
for block in reversed(blocks):
await self.add_block(block, peer)
except (asyncio.CancelledError, Exception):
self.sync_store.backtrack_syncing[peer.peer_node_id] -= 1
raise
@ -988,7 +988,7 @@ class FullNode:
self.log.info(f"Total of {len(peers_with_peak)} peers with peak {target_peak.height}")
weight_proof_peer: WSChiaConnection = random.choice(peers_with_peak)
self.log.info(
f"Requesting weight proof from peer {weight_proof_peer.peer_host} up to height" f" {target_peak.height}"
f"Requesting weight proof from peer {weight_proof_peer.peer_host} up to height {target_peak.height}"
)
cur_peak: Optional[BlockRecord] = self.blockchain.get_peak()
if cur_peak is not None and target_peak.weight <= cur_peak.weight:
@ -1110,7 +1110,7 @@ class FullNode:
peer, blocks = res
start_height = blocks[0].height
end_height = blocks[-1].height
success, state_change_summary = await self.receive_block_batch(
success, state_change_summary = await self.add_block_batch(
blocks, peer, None if advanced_peak else uint32(fork_point_height), summaries
)
if success is False:
@ -1211,7 +1211,7 @@ class FullNode:
msg = make_msg(ProtocolMessageTypes.coin_state_update, state)
await ws_peer.send_message(msg)
async def receive_block_batch(
async def add_block_batch(
self,
all_blocks: List[FullBlock],
peer: WSChiaConnection,
@ -1581,16 +1581,15 @@ class FullNode:
await self.server.send_to_all([msg], NodeType.WALLET)
self._state_changed("new_peak")
async def respond_block(
async def add_block(
self,
respond_block: full_node_protocol.RespondBlock,
block: FullBlock,
peer: Optional[WSChiaConnection] = None,
raise_on_disconnected: bool = False,
) -> Optional[Message]:
"""
Receive a full block from a peer full node (or ourselves).
Add a full block from a peer full node (or ourselves).
"""
block: FullBlock = respond_block.block
if self.sync_store.get_sync_mode():
return None
@ -1649,7 +1648,7 @@ class FullNode:
f"same farmer with the same pospace."
)
# This recursion ends here, we cannot recurse again because transactions_generator is not None
return await self.respond_block(block_response, peer)
return await self.add_block(new_block, peer)
state_change_summary: Optional[StateChangeSummary] = None
ppp_result: Optional[PeakPostProcessingResult] = None
async with self._blockchain_lock_high_priority:
@ -1707,7 +1706,7 @@ class FullNode:
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
self.log.info(
f"Received orphan block of height {block.height} rh " f"{block.reward_chain_block.get_hash()}"
f"Received orphan block of height {block.height} rh {block.reward_chain_block.get_hash()}"
)
else:
# Should never reach here, all the cases are covered
@ -1786,19 +1785,18 @@ class FullNode:
self._segment_task = asyncio.create_task(self.weight_proof_handler.create_prev_sub_epoch_segments())
return None
async def respond_unfinished_block(
async def add_unfinished_block(
self,
respond_unfinished_block: full_node_protocol.RespondUnfinishedBlock,
block: UnfinishedBlock,
peer: Optional[WSChiaConnection],
farmed_block: bool = False,
block_bytes: Optional[bytes] = None,
) -> None:
"""
We have received an unfinished block, either created by us, or from another peer.
We can validate it and if it's a good block, propagate it to other peers and
We can validate and add it and if it's a good block, propagate it to other peers and
timelords.
"""
block = respond_unfinished_block.unfinished_block
receive_time = time.time()
if block.prev_header_hash != self.constants.GENESIS_CHALLENGE and not self.blockchain.contains_block(
@ -2094,7 +2092,7 @@ class FullNode:
self.log.warning("Trying to make a pre-farm block but height is not 0")
return None
try:
await self.respond_block(full_node_protocol.RespondBlock(block), raise_on_disconnected=True)
await self.add_block(block, raise_on_disconnected=True)
except Exception as e:
self.log.warning(f"Consensus error validating block: {e}")
if timelord_peer is not None:
@ -2102,11 +2100,10 @@ class FullNode:
await self.send_peak_to_timelords(peer=timelord_peer)
return None
async def respond_end_of_sub_slot(
self, request: full_node_protocol.RespondEndOfSubSlot, peer: WSChiaConnection
async def add_end_of_sub_slot(
self, end_of_slot_bundle: EndOfSubSlotBundle, peer: WSChiaConnection
) -> Tuple[Optional[Message], bool]:
fetched_ss = self.full_node_store.get_sub_slot(request.end_of_slot_bundle.challenge_chain.get_hash())
fetched_ss = self.full_node_store.get_sub_slot(end_of_slot_bundle.challenge_chain.get_hash())
# We are not interested in sub-slots which have the same challenge chain but different reward chain. If there
# is a reorg, we will find out through the broadcast of blocks instead.
@ -2116,16 +2113,16 @@ class FullNode:
async with self.timelord_lock:
fetched_ss = self.full_node_store.get_sub_slot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
)
if (
(fetched_ss is None)
and request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
and end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
!= self.constants.GENESIS_CHALLENGE
):
# If we don't have the prev, request the prev instead
full_node_request = full_node_protocol.RequestSignagePointOrEndOfSubSlot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
uint8(0),
bytes32([0] * 32),
)
@ -2144,7 +2141,7 @@ class FullNode:
# Adds the sub slot and potentially get new infusions
new_infusions = self.full_node_store.new_finished_sub_slot(
request.end_of_slot_bundle,
end_of_slot_bundle,
self.blockchain,
peak,
await self.blockchain.get_full_peak(),
@ -2153,19 +2150,19 @@ class FullNode:
if new_infusions is not None:
self.log.info(
f"⏲️ Finished sub slot, SP {self.constants.NUM_SPS_SUB_SLOT}/{self.constants.NUM_SPS_SUB_SLOT}, "
f"{request.end_of_slot_bundle.challenge_chain.get_hash()}, "
f"{end_of_slot_bundle.challenge_chain.get_hash()}, "
f"number of sub-slots: {len(self.full_node_store.finished_sub_slots)}, "
f"RC hash: {request.end_of_slot_bundle.reward_chain.get_hash()}, "
f"Deficit {request.end_of_slot_bundle.reward_chain.deficit}"
f"RC hash: {end_of_slot_bundle.reward_chain.get_hash()}, "
f"Deficit {end_of_slot_bundle.reward_chain.deficit}"
)
# Reset farmer response timer for sub slot (SP 0)
self.signage_point_times[0] = time.time()
# Notify full nodes of the new sub-slot
broadcast = full_node_protocol.NewSignagePointOrEndOfSubSlot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
request.end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
end_of_slot_bundle.challenge_chain.get_hash(),
uint8(0),
request.end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge,
end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge,
)
msg = make_msg(ProtocolMessageTypes.new_signage_point_or_end_of_sub_slot, broadcast)
await self.server.send_to_all([msg], NodeType.FULL_NODE, peer.peer_node_id)
@ -2175,9 +2172,9 @@ class FullNode:
# Notify farmers of the new sub-slot
broadcast_farmer = farmer_protocol.NewSignagePoint(
request.end_of_slot_bundle.challenge_chain.get_hash(),
request.end_of_slot_bundle.challenge_chain.get_hash(),
request.end_of_slot_bundle.reward_chain.get_hash(),
end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.reward_chain.get_hash(),
next_difficulty,
next_sub_slot_iters,
uint8(0),
@ -2188,11 +2185,11 @@ class FullNode:
else:
self.log.info(
f"End of slot not added CC challenge "
f"{request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge}"
f"{end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge}"
)
return None, False
async def respond_transaction(
async def add_transaction(
self,
transaction: SpendBundle,
spend_name: bytes32,
@ -2238,8 +2235,8 @@ class FullNode:
if status == MempoolInclusionStatus.SUCCESS:
self.log.debug(
f"Added transaction to mempool: {spend_name} mempool size: "
f"{self.mempool_manager.mempool.total_mempool_cost} normalized "
f"{self.mempool_manager.mempool.total_mempool_cost / 5000000}"
f"{self.mempool_manager.mempool.total_mempool_cost()} normalized "
f"{self.mempool_manager.mempool.total_mempool_cost() / 5000000}"
)
# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
@ -2263,9 +2260,7 @@ class FullNode:
await self.simulator_transaction_callback(spend_name) # pylint: disable=E1102
else:
self.mempool_manager.remove_seen(spend_name)
self.log.debug(
f"Wasn't able to add transaction with id {spend_name}, " f"status {status} error: {error}"
)
self.log.debug(f"Wasn't able to add transaction with id {spend_name}, status {status} error: {error}")
return status, error
async def _needs_compact_proof(
@ -2359,7 +2354,6 @@ class FullNode:
header_hash: bytes32,
field_vdf: CompressibleVDFField,
) -> bool:
block = await self.block_store.get_full_block(header_hash)
if block is None:
return False
@ -2407,7 +2401,7 @@ class FullNode:
)
raise
async def respond_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
async def add_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
field_vdf = CompressibleVDFField(int(request.field_vdf))
if not await self._can_accept_compact_proof(
request.vdf_info, request.vdf_proof, request.height, request.header_hash, field_vdf
@ -2442,7 +2436,7 @@ class FullNode:
)
response = await peer.call_api(FullNodeAPI.request_compact_vdf, peer_request, timeout=10)
if response is not None and isinstance(response, full_node_protocol.RespondCompactVDF):
await self.respond_compact_vdf(response, peer)
await self.add_compact_vdf(response, peer)
async def request_compact_vdf(self, request: full_node_protocol.RequestCompactVDF, peer: WSChiaConnection) -> None:
header_block = await self.blockchain.get_header_block_by_height(
@ -2488,7 +2482,7 @@ class FullNode:
msg = make_msg(ProtocolMessageTypes.respond_compact_vdf, compact_vdf)
await peer.send_message(msg)
async def respond_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
async def add_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
field_vdf = CompressibleVDFField(int(request.field_vdf))
if not await self._can_accept_compact_proof(
request.vdf_info, request.vdf_proof, request.height, request.header_hash, field_vdf
@ -2525,7 +2519,6 @@ class FullNode:
self.log.info("Heights found for bluebox to compact: [%s]" % ", ".join(map(str, heights)))
for h in heights:
headers = await self.blockchain.get_header_blocks_in_range(h, h, tx_filter=False)
records: Dict[bytes32, BlockRecord] = {}
if sanitize_weight_proof_only:
@ -2616,7 +2609,6 @@ class FullNode:
async def node_next_block_check(
peer: WSChiaConnection, potential_peek: uint32, blockchain: BlockchainInterface
) -> bool:
block_response: Optional[Any] = await peer.call_api(
FullNodeAPI.request_block, full_node_protocol.RequestBlock(potential_peek, True)
)

View File

@ -465,8 +465,8 @@ class FullNodeAPI:
) -> Optional[Message]:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_unfinished_block(
respond_unfinished_block, peer, block_bytes=respond_unfinished_block_bytes
await self.full_node.add_unfinished_block(
respond_unfinished_block.unfinished_block, peer, block_bytes=respond_unfinished_block_bytes
)
return None
@ -558,7 +558,6 @@ class FullNodeAPI:
async def request_signage_point_or_end_of_sub_slot(
self, request: full_node_protocol.RequestSignagePointOrEndOfSubSlot
) -> Optional[Message]:
if request.index_from_challenge == 0:
sub_slot: Optional[Tuple[EndOfSubSlotBundle, int, uint128]] = self.full_node.full_node_store.get_sub_slot(
request.challenge_hash
@ -619,7 +618,6 @@ class FullNodeAPI:
return None
peak = self.full_node.blockchain.get_peak()
if peak is not None and peak.height > self.full_node.constants.MAX_SUB_SLOT_BLOCKS:
next_sub_slot_iters = self.full_node.blockchain.get_next_slot_iters(peak.header_hash, True)
sub_slots_for_peak = await self.full_node.blockchain.get_sp_and_ip_sub_slots(peak.header_hash)
assert sub_slots_for_peak is not None
@ -658,7 +656,7 @@ class FullNodeAPI:
) -> Optional[Message]:
if self.full_node.sync_store.get_sync_mode():
return None
msg, _ = await self.full_node.respond_end_of_sub_slot(request, peer)
msg, _ = await self.full_node.add_end_of_sub_slot(request.end_of_slot_bundle, peer)
return msg
@api_request(peer_required=True)
@ -1008,12 +1006,11 @@ class FullNodeAPI:
return None
# Propagate to ourselves (which validates and does further propagations)
request = full_node_protocol.RespondUnfinishedBlock(new_candidate)
try:
await self.full_node.respond_unfinished_block(request, None, True)
await self.full_node.add_unfinished_block(new_candidate, None, True)
except Exception as e:
# If we have an error with this block, try making an empty block
self.full_node.log.error(f"Error farming block {e} {request}")
self.full_node.log.error(f"Error farming block {e} {new_candidate}")
candidate_tuple = self.full_node.full_node_store.get_candidate_block(
farmer_request.quality_string, backup=True
)
@ -1071,8 +1068,7 @@ class FullNodeAPI:
):
return None
# Calls our own internal message to handle the end of sub slot, and potentially broadcasts to other peers.
full_node_message = full_node_protocol.RespondEndOfSubSlot(request.end_of_sub_slot_bundle)
msg, added = await self.full_node.respond_end_of_sub_slot(full_node_message, peer)
msg, added = await self.full_node.add_end_of_sub_slot(request.end_of_sub_slot_bundle, peer)
if not added:
self.log.error(
f"Was not able to add end of sub-slot: "
@ -1098,7 +1094,6 @@ class FullNodeAPI:
tx_additions: List[Coin] = []
if block.transactions_generator is not None:
block_generator: Optional[BlockGenerator] = await self.full_node.blockchain.get_block_generator(block)
# get_block_generator() returns None in case the block we specify
# does not have a generator (i.e. is not a transaction block).
@ -1411,7 +1406,7 @@ class FullNodeAPI:
async def respond_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_compact_proof_of_time(request)
await self.full_node.add_compact_proof_of_time(request)
return None
@api_request(peer_required=True, bytes_required=True, execute_task=True)
@ -1452,14 +1447,13 @@ class FullNodeAPI:
async def respond_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_compact_vdf(request, peer)
await self.full_node.add_compact_vdf(request, peer)
return None
@api_request(peer_required=True)
async def register_interest_in_puzzle_hash(
self, request: wallet_protocol.RegisterForPhUpdates, peer: WSChiaConnection
) -> Message:
trusted = self.is_trusted(peer)
if trusted:
max_subscriptions = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
@ -1532,7 +1526,6 @@ class FullNodeAPI:
async def register_interest_in_coin(
self, request: wallet_protocol.RegisterForCoinUpdates, peer: WSChiaConnection
) -> Message:
if self.is_trusted(peer):
max_subscriptions = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
max_items = self.full_node.config.get("trusted_max_subscribe_response_items", 500000)

View File

@ -805,6 +805,6 @@ class FullNodeStore:
found_last_challenge = True
break
if not found_last_challenge:
log.warning(f"Did not find hash {last_challenge_to_add} connected to " f"{challenge_in_chain}")
log.warning(f"Did not find hash {last_challenge_to_add} connected to {challenge_in_chain}")
return None
return collected_sub_slots

View File

@ -1,13 +1,12 @@
from __future__ import annotations
import logging
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from sortedcontainers import SortedDict
from chia.full_node.fee_estimation import FeeMempoolInfo, MempoolInfo
from chia.full_node.fee_estimation import FeeMempoolInfo, MempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
@ -24,14 +23,46 @@ class MempoolRemoveReason(Enum):
class Mempool:
def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterface):
self.log: logging.Logger = logging.getLogger(__name__)
self.spends: Dict[bytes32, MempoolItem] = {}
self.sorted_spends: SortedDict = SortedDict()
self._spends: Dict[bytes32, MempoolItem] = {}
self._sorted_spends: SortedDict = SortedDict()
self.mempool_info: MempoolInfo = mempool_info
self.fee_estimator: FeeEstimatorInterface = fee_estimator
self.removal_coin_id_to_spendbundle_ids: Dict[bytes32, List[bytes32]] = {}
self.total_mempool_cost: CLVMCost = CLVMCost(uint64(0))
self.total_mempool_fees: int = 0
self._removal_coin_id_to_spendbundle_ids: Dict[bytes32, List[bytes32]] = {}
self._total_mempool_cost: CLVMCost = CLVMCost(uint64(0))
self._total_mempool_fees: int = 0
def total_mempool_fees(self) -> int:
return self._total_mempool_fees
def total_mempool_cost(self) -> CLVMCost:
return self._total_mempool_cost
def all_spends(self) -> List[MempoolItem]:
return list(self._spends.values())
def all_spend_ids(self) -> List[bytes32]:
return list(self._spends.keys())
def spends_by_feerate(self) -> List[MempoolItem]:
ret: List[MempoolItem] = []
for spends_with_fpc in reversed(self._sorted_spends.values()):
ret.extend(spends_with_fpc.values())
return ret
def size(self) -> int:
return len(self._spends)
def get_spend_by_id(self, spend_bundle_id: bytes32) -> Optional[MempoolItem]:
return self._spends.get(spend_bundle_id, None)
def get_spends_by_coin_id(self, spent_coin_id: bytes32) -> List[MempoolItem]:
spend_bundle_ids = self._removal_coin_id_to_spendbundle_ids.get(spent_coin_id)
if spend_bundle_ids is None:
return []
ret: List[MempoolItem] = []
for spend_bundle_id in spend_bundle_ids:
ret.append(self._spends[spend_bundle_id])
return ret
def get_min_fee_rate(self, cost: int) -> float:
"""
@ -39,11 +70,11 @@ class Mempool:
"""
if self.at_full_capacity(cost):
current_cost = self.total_mempool_cost
current_cost = self._total_mempool_cost
# Iterates through all spends in increasing fee per cost
fee_per_cost: float
for fee_per_cost, spends_with_fpc in self.sorted_spends.items():
for fee_per_cost, spends_with_fpc in self._sorted_spends.items():
for spend_name, item in spends_with_fpc.items():
current_cost -= item.cost
# Removing one at a time, until our transaction of size cost fits
@ -60,61 +91,65 @@ class Mempool:
Removes an item from the mempool.
"""
for spend_bundle_id in items:
item: Optional[MempoolItem] = self.spends.get(spend_bundle_id)
item: Optional[MempoolItem] = self._spends.get(spend_bundle_id)
if item is None:
continue
assert item.name == spend_bundle_id
removals: List[Coin] = item.removals
for rem in removals:
rem_name: bytes32 = rem.name()
self.removal_coin_id_to_spendbundle_ids[rem_name].remove(spend_bundle_id)
if len(self.removal_coin_id_to_spendbundle_ids[rem_name]) == 0:
del self.removal_coin_id_to_spendbundle_ids[rem_name]
del self.spends[item.name]
del self.sorted_spends[item.fee_per_cost][item.name]
dic = self.sorted_spends[item.fee_per_cost]
self._removal_coin_id_to_spendbundle_ids[rem_name].remove(spend_bundle_id)
if len(self._removal_coin_id_to_spendbundle_ids[rem_name]) == 0:
del self._removal_coin_id_to_spendbundle_ids[rem_name]
del self._spends[item.name]
del self._sorted_spends[item.fee_per_cost][item.name]
dic = self._sorted_spends[item.fee_per_cost]
if len(dic.values()) == 0:
del self.sorted_spends[item.fee_per_cost]
self.total_mempool_cost = CLVMCost(uint64(self.total_mempool_cost - item.cost))
self.total_mempool_fees = self.total_mempool_fees - item.fee
assert self.total_mempool_cost >= 0
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost, self.total_mempool_fees, datetime.now())
del self._sorted_spends[item.fee_per_cost]
self._total_mempool_cost = CLVMCost(uint64(self._total_mempool_cost - item.cost))
self._total_mempool_fees = self._total_mempool_fees - item.fee
assert self._total_mempool_cost >= 0
info = FeeMempoolInfo(self.mempool_info, self._total_mempool_cost, self._total_mempool_fees, datetime.now())
if reason != MempoolRemoveReason.BLOCK_INCLUSION:
self.fee_estimator.remove_mempool_item(info, item)
self.fee_estimator.remove_mempool_item(
info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool)
)
def add_to_pool(self, item: MempoolItem) -> None:
"""
Adds an item to the mempool by kicking out transactions (if it doesn't fit), in order of increasing fee per cost
"""
assert item.npc_result.conds is not None
while self.at_full_capacity(item.cost):
# Val is Dict[hash, MempoolItem]
fee_per_cost, val = self.sorted_spends.peekitem(index=0)
fee_per_cost, val = self._sorted_spends.peekitem(index=0)
to_remove: MempoolItem = list(val.values())[0]
self.remove_from_pool([to_remove.name], MempoolRemoveReason.POOL_FULL)
self.spends[item.name] = item
self._spends[item.name] = item
# sorted_spends is Dict[float, Dict[bytes32, MempoolItem]]
if item.fee_per_cost not in self.sorted_spends:
self.sorted_spends[item.fee_per_cost] = {}
# _sorted_spends is Dict[float, Dict[bytes32, MempoolItem]]
if item.fee_per_cost not in self._sorted_spends:
self._sorted_spends[item.fee_per_cost] = {}
self.sorted_spends[item.fee_per_cost][item.name] = item
self._sorted_spends[item.fee_per_cost][item.name] = item
for coin in item.removals:
coin_id = coin.name()
if coin_id not in self.removal_coin_id_to_spendbundle_ids:
self.removal_coin_id_to_spendbundle_ids[coin_id] = []
self.removal_coin_id_to_spendbundle_ids[coin_id].append(item.name)
if coin_id not in self._removal_coin_id_to_spendbundle_ids:
self._removal_coin_id_to_spendbundle_ids[coin_id] = []
self._removal_coin_id_to_spendbundle_ids[coin_id].append(item.name)
self.total_mempool_cost = CLVMCost(uint64(self.total_mempool_cost + item.cost))
self.total_mempool_fees = self.total_mempool_fees + item.fee
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost, self.total_mempool_fees, datetime.now())
self.fee_estimator.add_mempool_item(info, item)
self._total_mempool_cost = CLVMCost(uint64(self._total_mempool_cost + item.cost))
self._total_mempool_fees = self._total_mempool_fees + item.fee
info = FeeMempoolInfo(self.mempool_info, self._total_mempool_cost, self._total_mempool_fees, datetime.now())
self.fee_estimator.add_mempool_item(info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
def at_full_capacity(self, cost: int) -> bool:
"""
Checks whether the mempool is at full capacity and cannot accept a transaction with size cost.
"""
return self.total_mempool_cost + cost > self.mempool_info.max_size_in_cost
return self._total_mempool_cost + cost > self.mempool_info.max_size_in_cost

View File

@ -3,11 +3,12 @@ from __future__ import annotations
import logging
from typing import Dict, List, Optional, Tuple
from chia_rs import LIMIT_STACK, MEMPOOL_MODE
from chia_rs import ENABLE_ASSERT_BEFORE, LIMIT_STACK, MEMPOOL_MODE
from chia_rs import get_puzzle_and_solution_for_coin as get_puzzle_and_solution_for_coin_rust
from chia_rs import run_block_generator, run_chia_program
from clvm.casts import int_from_bytes
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.types.blockchain_format.coin import Coin
@ -33,7 +34,13 @@ log = logging.getLogger(__name__)
def get_name_puzzle_conditions(
generator: BlockGenerator, max_cost: int, *, cost_per_byte: int, mempool_mode: bool, height: Optional[uint32] = None
generator: BlockGenerator,
max_cost: int,
*,
cost_per_byte: int,
mempool_mode: bool,
height: Optional[uint32] = None,
constants: ConsensusConstants = DEFAULT_CONSTANTS,
) -> NPCResult:
# in mempool mode, the height doesn't matter, because it's always strict.
# But otherwise, height must be specified to know which rules to apply
@ -41,7 +48,9 @@ def get_name_puzzle_conditions(
if mempool_mode:
flags = MEMPOOL_MODE
elif height is not None and height >= DEFAULT_CONSTANTS.SOFT_FORK_HEIGHT:
elif height is not None and height >= constants.SOFT_FORK2_HEIGHT:
flags = LIMIT_STACK | ENABLE_ASSERT_BEFORE
elif height is not None and height >= constants.SOFT_FORK_HEIGHT:
flags = LIMIT_STACK
else:
flags = 0

View File

@ -11,12 +11,12 @@ from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
from blspy import GTElement
from chiabip158 import PyBIP158
from chia.consensus.block_record import BlockRecord
from chia.consensus.block_record import BlockRecordProtocol
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
from chia.full_node.bitcoin_fee_estimator import create_bitcoin_fee_estimator
from chia.full_node.bundle_tools import simple_solution_generator
from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.mempool import Mempool, MempoolRemoveReason
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, mempool_check_time_locks
@ -115,7 +115,7 @@ class MempoolManager:
# cache of MempoolItems with height conditions making them not valid yet
_pending_cache: PendingTxCache
seen_cache_size: int
peak: Optional[BlockRecord]
peak: Optional[BlockRecordProtocol]
mempool: Mempool
def __init__(
@ -157,7 +157,7 @@ class MempoolManager:
)
# The mempool will correspond to a certain peak
self.peak: Optional[BlockRecord] = None
self.peak: Optional[BlockRecordProtocol] = None
self.fee_estimator: FeeEstimatorInterface = create_bitcoin_fee_estimator(self.max_block_clvm_cost)
mempool_info = MempoolInfo(
CLVMCost(uint64(self.mempool_max_total_cost)),
@ -170,34 +170,30 @@ class MempoolManager:
self.pool.shutdown(wait=True)
def process_mempool_items(
self, item_inclusion_filter: Callable[[MempoolManager, MempoolItem], bool]
self, item_inclusion_filter: Callable[[bytes32], bool]
) -> Tuple[List[SpendBundle], uint64, List[Coin], List[Coin]]:
cost_sum = 0 # Checks that total cost does not exceed block maximum
fee_sum = 0 # Checks that total fees don't exceed 64 bits
spend_bundles: List[SpendBundle] = []
removals: List[Coin] = []
additions: List[Coin] = []
for dic in reversed(self.mempool.sorted_spends.values()):
for item in dic.values():
if not item_inclusion_filter(self, item):
continue
log.info(f"Cumulative cost: {cost_sum}, fee per cost: {item.fee / item.cost}")
if (
item.cost + cost_sum > self.max_block_clvm_cost
or item.fee + fee_sum > self.constants.MAX_COIN_AMOUNT
):
return (spend_bundles, uint64(cost_sum), additions, removals)
spend_bundles.append(item.spend_bundle)
cost_sum += item.cost
fee_sum += item.fee
removals.extend(item.removals)
additions.extend(item.additions)
for item in self.mempool.spends_by_feerate():
if not item_inclusion_filter(item.name):
continue
log.info(f"Cumulative cost: {cost_sum}, fee per cost: {item.fee / item.cost}")
if item.cost + cost_sum > self.max_block_clvm_cost or item.fee + fee_sum > self.constants.MAX_COIN_AMOUNT:
return (spend_bundles, uint64(cost_sum), additions, removals)
spend_bundles.append(item.spend_bundle)
cost_sum += item.cost
fee_sum += item.fee
removals.extend(item.removals)
additions.extend(item.additions)
return (spend_bundles, uint64(cost_sum), additions, removals)
def create_bundle_from_mempool(
self,
last_tb_header_hash: bytes32,
item_inclusion_filter: Optional[Callable[[MempoolManager, MempoolItem], bool]] = None,
item_inclusion_filter: Optional[Callable[[bytes32], bool]] = None,
) -> Optional[Tuple[SpendBundle, List[Coin], List[Coin]]]:
"""
Returns aggregated spendbundle that can be used for creating new block,
@ -208,7 +204,7 @@ class MempoolManager:
if item_inclusion_filter is None:
def always(mm: MempoolManager, mi: MempoolItem) -> bool:
def always(bundle_name: bytes32) -> bool:
return True
item_inclusion_filter = always
@ -227,7 +223,7 @@ class MempoolManager:
def get_filter(self) -> bytes:
all_transactions: Set[bytes32] = set()
byte_array_list = []
for key, _ in self.mempool.spends.items():
for key in self.mempool.all_spend_ids():
if key not in all_transactions:
all_transactions.add(key)
byte_array_list.append(bytearray(key))
@ -366,10 +362,9 @@ class MempoolManager:
"""
# Skip if already added
if spend_name in self.mempool.spends:
cost: Optional[uint64] = self.mempool.spends[spend_name].cost
assert cost is not None
return uint64(cost), MempoolInclusionStatus.SUCCESS, None
existing_item = self.mempool.get_spend_by_id(spend_name)
if existing_item is not None:
return existing_item.cost, MempoolInclusionStatus.SUCCESS, None
err, item, remove_items = await self.validate_spend_bundle(
new_spend, npc_result, spend_name, first_added_height
@ -517,9 +512,7 @@ class MempoolManager:
if tl_error:
assert_height = compute_assert_height(removal_record_dict, npc_result.conds)
potential = MempoolItem(
new_spend, uint64(fees), npc_result, cost, spend_name, additions, first_added_height, assert_height
)
potential = MempoolItem(new_spend, uint64(fees), npc_result, spend_name, first_added_height, assert_height)
if tl_error:
if tl_error is Err.ASSERT_HEIGHT_ABSOLUTE_FAILED or tl_error is Err.ASSERT_HEIGHT_RELATIVE_FAILED:
@ -529,9 +522,8 @@ class MempoolManager:
if fail_reason is Err.MEMPOOL_CONFLICT:
for conflicting in conflicts:
for c_sb_id in self.mempool.removal_coin_id_to_spendbundle_ids[conflicting.name()]:
sb: MempoolItem = self.mempool.spends[c_sb_id]
conflicting_pool_items[sb.name] = sb
for item in self.mempool.get_spends_by_coin_id(conflicting.name()):
conflicting_pool_items[item.name] = item
log.debug(f"Replace attempted. number of MempoolItems: {len(conflicting_pool_items)}")
if not self.can_replace(conflicting_pool_items, removal_record_dict, fees, fees_per_cost):
return Err.MEMPOOL_CONFLICT, potential, []
@ -562,7 +554,8 @@ class MempoolManager:
if record.spent:
return Err.DOUBLE_SPEND, []
# 2. Checks if there's a mempool conflict
if removal.name() in self.mempool.removal_coin_id_to_spendbundle_ids:
items: List[MempoolItem] = self.mempool.get_spends_by_coin_id(removal.name())
if len(items) > 0:
conflicts.append(removal)
if len(conflicts) > 0:
@ -572,8 +565,9 @@ class MempoolManager:
def get_spendbundle(self, bundle_hash: bytes32) -> Optional[SpendBundle]:
"""Returns a full SpendBundle if it's inside one the mempools"""
if bundle_hash in self.mempool.spends:
return self.mempool.spends[bundle_hash].spend_bundle
item: Optional[MempoolItem] = self.mempool.get_spend_by_id(bundle_hash)
if item is not None:
return item.spend_bundle
return None
def get_mempool_item(self, bundle_hash: bytes32, include_pending: bool = False) -> Optional[MempoolItem]:
@ -583,7 +577,7 @@ class MempoolManager:
If include_pending is specified, also check the PENDING cache.
"""
item = self.mempool.spends.get(bundle_hash, None)
item = self.mempool.get_spend_by_id(bundle_hash)
if not item and include_pending:
# no async lock needed since we're not mutating the pending_cache
item = self._pending_cache.get(bundle_hash)
@ -593,7 +587,7 @@ class MempoolManager:
return item
async def new_peak(
self, new_peak: Optional[BlockRecord], last_npc_result: Optional[NPCResult]
self, new_peak: Optional[BlockRecordProtocol], last_npc_result: Optional[NPCResult]
) -> List[Tuple[SpendBundle, NPCResult, bytes32]]:
"""
Called when a new peak is available, we try to recreate a mempool for the new tip.
@ -606,7 +600,7 @@ class MempoolManager:
return []
assert new_peak.timestamp is not None
self.fee_estimator.new_block_height(new_peak.height)
included_items = []
included_items: List[MempoolItemInfo] = []
use_optimization: bool = self.peak is not None and new_peak.prev_transaction_block_hash == self.peak.header_hash
self.peak = new_peak
@ -614,24 +608,19 @@ class MempoolManager:
if use_optimization and last_npc_result is not None:
# We don't reinitialize a mempool, just kick removed items
if last_npc_result.conds is not None:
spendbundle_ids_to_remove = []
spendbundle_ids_to_remove: List[bytes32] = []
for spend in last_npc_result.conds.spends:
if spend.coin_id in self.mempool.removal_coin_id_to_spendbundle_ids:
spendbundle_ids: List[bytes32] = self.mempool.removal_coin_id_to_spendbundle_ids[
bytes32(spend.coin_id)
]
spendbundle_ids_to_remove.extend(spendbundle_ids)
for spendbundle_id in spendbundle_ids:
item = self.mempool.spends.get(spendbundle_id)
if item:
included_items.append(item)
self.remove_seen(spendbundle_id)
items: List[MempoolItem] = self.mempool.get_spends_by_coin_id(bytes32(spend.coin_id))
for item in items:
included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
self.remove_seen(item.name)
spendbundle_ids_to_remove.append(item.name)
self.mempool.remove_from_pool(spendbundle_ids_to_remove, MempoolRemoveReason.BLOCK_INCLUSION)
else:
old_pool = self.mempool
self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.seen_bundle_hashes = {}
for item in old_pool.spends.values():
for item in old_pool.all_spends():
_, result, err = await self.add_spend_bundle(
item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool
)
@ -643,7 +632,7 @@ class MempoolManager:
if result == MempoolInclusionStatus.FAILED and err == Err.DOUBLE_SPEND:
# Item was in mempool, but after the new block it's a double spend.
# Item is most likely included in the block.
included_items.append(item)
included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
potential_txs = self._pending_cache.drain(new_peak.height)
potential_txs.update(self._conflict_cache.drain())
@ -655,8 +644,8 @@ class MempoolManager:
if status == MempoolInclusionStatus.SUCCESS:
txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name))
log.info(
f"Size of mempool: {len(self.mempool.spends)} spends, "
f"cost: {self.mempool.total_mempool_cost} "
f"Size of mempool: {self.mempool.size()} spends, "
f"cost: {self.mempool.total_mempool_cost()} "
f"minimum fee rate (in FPC) to get in for 5M cost tx: {self.mempool.get_min_fee_rate(5000000)}"
)
self.mempool.fee_estimator.new_block(FeeBlockInfo(new_peak.height, included_items))
@ -668,11 +657,11 @@ class MempoolManager:
assert limit > 0
# Send 100 with the highest fee per cost
for dic in reversed(self.mempool.sorted_spends.values()):
for item in dic.values():
if len(items) == limit:
return items
if mempool_filter.Match(bytearray(item.spend_bundle_name)):
continue
items.append(item.spend_bundle)
for item in self.mempool.spends_by_feerate():
if len(items) >= limit:
return items
if mempool_filter.Match(bytearray(item.spend_bundle_name)):
continue
items.append(item.spend_bundle)
return items

View File

@ -34,7 +34,6 @@ class ConflictTxCache:
self._cache_cost += item.cost
while self._cache_cost > self._cache_max_total_cost or len(self._txs) > self._cache_max_size:
first_in = list(self._txs.keys())[0]
self._cache_cost -= self._txs[first_in].cost
self._txs.pop(first_in)
@ -77,7 +76,6 @@ class PendingTxCache:
self._by_height.setdefault(item.assert_height, {})[name] = item
while self._cache_cost > self._cache_max_total_cost or len(self._txs) > self._cache_max_size:
# we start removing items with the highest assert_height first
to_evict = self._by_height.items()[-1]
if to_evict[1] == {}:

View File

@ -112,7 +112,6 @@ class PeerSubscriptions:
break
def remove_peer(self, peer_id: bytes32) -> None:
counter = 0
puzzle_hashes = self._peer_puzzle_hash.get(peer_id)
if puzzle_hashes is not None:

View File

@ -54,7 +54,6 @@ def _create_shutdown_file() -> IO:
class WeightProofHandler:
LAMBDA_L = 100
C = 0.5
MAX_SAMPLES = 20
@ -74,7 +73,6 @@ class WeightProofHandler:
self.multiprocessing_context = multiprocessing_context
async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]:
tip_rec = self.blockchain.try_block_record(tip)
if tip_rec is None:
log.error("unknown tip")
@ -527,7 +525,7 @@ class WeightProofHandler:
assert curr.reward_chain_block.challenge_chain_sp_vdf
cc_sp_vdf_info = curr.reward_chain_block.challenge_chain_sp_vdf
if not curr.challenge_chain_sp_proof.normalized_to_identity:
(_, _, _, _, cc_vdf_iters, _,) = get_signage_point_vdf_info(
(_, _, _, _, cc_vdf_iters, _) = get_signage_point_vdf_info(
self.constants,
curr.finished_sub_slots,
block_record.overflow,
@ -732,7 +730,7 @@ async def _challenge_block_vdfs(
block_rec: BlockRecord,
sub_blocks: Dict[bytes32, BlockRecord],
):
(_, _, _, _, cc_vdf_iters, _,) = get_signage_point_vdf_info(
(_, _, _, _, cc_vdf_iters, _) = get_signage_point_vdf_info(
constants,
header_block.finished_sub_slots,
block_rec.overflow,
@ -859,7 +857,6 @@ def _validate_sub_epoch_summaries(
constants: ConsensusConstants,
weight_proof: WeightProof,
) -> Tuple[Optional[List[SubEpochSummary]], Optional[List[uint128]]]:
last_ses_hash, last_ses_sub_height = _get_last_ses_hash(constants, weight_proof.recent_chain_data)
if last_ses_hash is None:
log.warning("could not find last ses block")
@ -1074,7 +1071,6 @@ def _validate_sub_slot_data(
sub_slots: List[SubSlotData],
ssi: uint64,
) -> Tuple[bool, List[Tuple[VDFProof, ClassgroupElement, VDFInfo]]]:
sub_slot_data = sub_slots[sub_slot_idx]
assert sub_slot_idx > 0
prev_ssd = sub_slots[sub_slot_idx - 1]
@ -1381,7 +1377,6 @@ def __get_rc_sub_slot(
summaries: List[SubEpochSummary],
curr_ssi: uint64,
) -> RewardChainSubSlot:
ses = summaries[uint32(segment.sub_epoch_n - 1)]
# find first challenge in sub epoch
first_idx = None
@ -1646,7 +1641,6 @@ def _validate_vdf_batch(
vdf_list: List[Tuple[bytes, bytes, bytes]],
shutdown_file_path: Optional[pathlib.Path] = None,
):
for vdf_proof_bytes, class_group_bytes, info in vdf_list:
vdf = VDFProof.from_bytes(vdf_proof_bytes)
class_group = ClassgroupElement.from_bytes(class_group_bytes)

View File

@ -8,6 +8,8 @@ from concurrent.futures.thread import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Literal
from chia.consensus.constants import ConsensusConstants
from chia.plot_sync.sender import Sender
from chia.plotting.manager import PlotManager
@ -35,7 +37,6 @@ class Harvester:
_shut_down: bool
executor: ThreadPoolExecutor
state_changed_callback: Optional[StateChangedProtocol] = None
cached_challenges: List
constants: ConsensusConstants
_refresh_lock: asyncio.Lock
event_loop: asyncio.events.AbstractEventLoop
@ -50,7 +51,7 @@ class Harvester:
return self._server
def __init__(self, root_path: Path, config: Dict, constants: ConsensusConstants):
def __init__(self, root_path: Path, config: Dict[str, Any], constants: ConsensusConstants):
self.log = log
self.root_path = root_path
# TODO, remove checks below later after some versions / time
@ -76,38 +77,37 @@ class Harvester:
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"])
self._server = None
self.constants = constants
self.cached_challenges = []
self.state_changed_callback: Optional[StateChangedProtocol] = None
self.parallel_read: bool = config.get("parallel_read", True)
async def _start(self):
async def _start(self) -> None:
self._refresh_lock = asyncio.Lock()
self.event_loop = asyncio.get_running_loop()
def _close(self):
def _close(self) -> None:
self._shut_down = True
self.executor.shutdown(wait=True)
self.plot_manager.stop_refreshing()
self.plot_manager.reset()
self.plot_sync_sender.stop()
async def _await_closed(self):
async def _await_closed(self) -> None:
await self.plot_sync_sender.await_closed()
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
return default_get_connections(server=self.server, request_node_type=request_node_type)
async def on_connect(self, connection: WSChiaConnection):
async def on_connect(self, connection: WSChiaConnection) -> None:
self.state_changed("add_connection")
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback
def state_changed(self, change: str, change_data: Dict[str, Any] = None):
def state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
if self.state_changed_callback is not None:
self.state_changed_callback(change, change_data)
def _plot_refresh_callback(self, event: PlotRefreshEvents, update_result: PlotRefreshResult):
def _plot_refresh_callback(self, event: PlotRefreshEvents, update_result: PlotRefreshResult) -> None:
log_function = self.log.debug if event == PlotRefreshEvents.batch_processed else self.log.info
log_function(
f"_plot_refresh_callback: event {event.name}, loaded {len(update_result.loaded)}, "
@ -123,16 +123,16 @@ class Harvester:
if event == PlotRefreshEvents.done:
self.plot_sync_sender.sync_done(update_result.removed, update_result.duration)
def on_disconnect(self, connection: WSChiaConnection):
def on_disconnect(self, connection: WSChiaConnection) -> None:
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self.state_changed("close_connection")
self.plot_sync_sender.stop()
asyncio.run_coroutine_threadsafe(self.plot_sync_sender.await_closed(), asyncio.get_running_loop())
self.plot_manager.stop_refreshing()
def get_plots(self) -> Tuple[List[Dict], List[str], List[str]]:
def get_plots(self) -> Tuple[List[Dict[str, Any]], List[str], List[str]]:
self.log.debug(f"get_plots prover items: {self.plot_manager.plot_count()}")
response_plots: List[Dict] = []
response_plots: List[Dict[str, Any]] = []
with self.plot_manager:
for path, plot_info in self.plot_manager.plots.items():
prover = plot_info.prover
@ -159,7 +159,7 @@ class Harvester:
[str(s) for s in self.plot_manager.no_key_filenames],
)
def delete_plot(self, str_path: str):
def delete_plot(self, str_path: str) -> Literal[True]:
remove_plot(Path(str_path))
self.plot_manager.trigger_refresh()
self.state_changed("plots")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import time
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
from blspy import AugSchemeMPL, G1Element, G2Element
@ -14,7 +14,7 @@ from chia.protocols import harvester_protocol
from chia.protocols.farmer_protocol import FarmingInfo
from chia.protocols.harvester_protocol import Plot, PlotSyncResponse
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.server.outbound_message import Message, make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.proof_of_space import (
ProofOfSpace,
@ -37,7 +37,7 @@ class HarvesterAPI:
@api_request(peer_required=True)
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection
):
) -> None:
"""
Handshake between the harvester and farmer. The harvester receives the pool public keys,
as well as the farmer pks, which must be put into the plots, before the plotting process begins.
@ -53,7 +53,7 @@ class HarvesterAPI:
@api_request(peer_required=True)
async def new_signage_point_harvester(
self, new_challenge: harvester_protocol.NewSignagePointHarvester, peer: WSChiaConnection
):
) -> None:
"""
The harvester receives a new signage point from the farmer, this happens at the start of each slot.
The harvester does a few things:
@ -246,7 +246,7 @@ class HarvesterAPI:
)
@api_request()
async def request_signatures(self, request: harvester_protocol.RequestSignatures):
async def request_signatures(self, request: harvester_protocol.RequestSignatures) -> Optional[Message]:
"""
The farmer requests a signature on the header hash, for one of the proofs that we found.
A signature is created on the header hash using the harvester private key. This can also
@ -295,7 +295,7 @@ class HarvesterAPI:
return make_msg(ProtocolMessageTypes.respond_signatures, response)
@api_request()
async def request_plots(self, _: harvester_protocol.RequestPlots):
async def request_plots(self, _: harvester_protocol.RequestPlots) -> Message:
plots_response = []
plots, failed_to_open_filenames, no_key_filenames = self.harvester.get_plots()
for plot in plots:
@ -316,5 +316,5 @@ class HarvesterAPI:
return make_msg(ProtocolMessageTypes.respond_plots, response)
@api_request()
async def plot_sync_response(self, response: PlotSyncResponse):
async def plot_sync_response(self, response: PlotSyncResponse) -> None:
self.harvester.plot_sync_sender.set_response(response)

View File

@ -118,8 +118,7 @@ class Sender:
await self.await_closed()
if self._task is None:
self._task = asyncio.create_task(self._run())
# TODO, Add typing in PlotManager
if not self._plot_manager.initial_refresh() or self._sync_id != 0: # type:ignore[no-untyped-call]
if not self._plot_manager.initial_refresh() or self._sync_id != 0:
self._reset()
else:
raise AlreadyStartedError()
@ -173,7 +172,7 @@ class Sender:
return False
if response.identifier.sync_id != self._response.identifier.sync_id:
log.warning(
"set_response unexpected sync-id: " f"{response.identifier.sync_id}/{self._response.identifier.sync_id}"
"set_response unexpected sync-id: {response.identifier.sync_id}/{self._response.identifier.sync_id}"
)
return False
if response.identifier.message_id != self._response.identifier.message_id:
@ -184,7 +183,7 @@ class Sender:
return False
if response.message_type != int16(self._response.message_type.value):
log.warning(
"set_response unexpected message-type: " f"{response.message_type}/{self._response.message_type.value}"
"set_response unexpected message-type: {response.message_type}/{self._response.message_type.value}"
)
return False
log.debug(f"set_response valid {response}")

View File

@ -4,7 +4,7 @@ import logging
from collections import Counter
from pathlib import Path
from time import sleep, time
from typing import List
from typing import List, Optional
from blspy import G1Element
from chiapos import Verifier
@ -21,20 +21,28 @@ from chia.plotting.util import (
from chia.util.bech32m import encode_puzzle_hash
from chia.util.config import load_config
from chia.util.hash import std_hash
from chia.util.ints import uint32
from chia.util.keychain import Keychain
from chia.wallet.derive_keys import master_sk_to_farmer_sk, master_sk_to_local_sk
log = logging.getLogger(__name__)
def plot_refresh_callback(event: PlotRefreshEvents, refresh_result: PlotRefreshResult):
def plot_refresh_callback(event: PlotRefreshEvents, refresh_result: PlotRefreshResult) -> None:
log.info(f"event: {event.name}, loaded {len(refresh_result.loaded)} plots, {refresh_result.remaining} remaining")
def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, debug_show_memo):
def check_plots(
root_path: Path,
num: Optional[int],
challenge_start: Optional[int],
grep_string: str,
list_duplicates: bool,
debug_show_memo: bool,
) -> None:
config = load_config(root_path, "config.yaml")
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
plot_refresh_parameter: PlotsRefreshParameter = PlotsRefreshParameter(batch_sleep_milliseconds=0)
plot_refresh_parameter: PlotsRefreshParameter = PlotsRefreshParameter(batch_sleep_milliseconds=uint32(0))
plot_manager: PlotManager = PlotManager(
root_path,
match_str=grep_string,
@ -97,7 +105,7 @@ def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, d
log.info("")
log.info("")
log.info(f"Starting to test each plot with {num} challenges each\n")
total_good_plots: Counter = Counter()
total_good_plots: Counter[str] = Counter()
total_size = 0
bad_plots_list: List[Path] = []
@ -179,7 +187,7 @@ def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, d
log.info("Summary")
total_plots: int = sum(list(total_good_plots.values()))
log.info(f"Found {total_plots} valid plots, total size {total_size / (1024 * 1024 * 1024 * 1024):.5f} TiB")
for (k, count) in sorted(dict(total_good_plots).items()):
for k, count in sorted(dict(total_good_plots).items()):
log.info(f"{count} plots of size {k}")
grand_total_bad = len(bad_plots_list) + len(plot_manager.failed_to_open_filenames)
if grand_total_bad > 0:

View File

@ -66,30 +66,33 @@ class PlotKeysResolver:
if self.resolved_keys is not None:
return self.resolved_keys
if self.connect_to_daemon:
keychain_proxy: Optional[KeychainProxy] = await connect_to_keychain_and_validate(self.root_path, self.log)
else:
keychain_proxy = wrap_local_keychain(Keychain(), log=self.log)
keychain_proxy: Optional[KeychainProxy] = None
try:
if self.connect_to_daemon:
keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
else:
keychain_proxy = wrap_local_keychain(Keychain(), log=self.log)
farmer_public_key: G1Element
if self.farmer_public_key is not None:
farmer_public_key = G1Element.from_bytes(bytes.fromhex(self.farmer_public_key))
else:
farmer_public_key = await self.get_farmer_public_key(keychain_proxy)
farmer_public_key: G1Element
if self.farmer_public_key is not None:
farmer_public_key = G1Element.from_bytes(bytes.fromhex(self.farmer_public_key))
else:
farmer_public_key = await self.get_farmer_public_key(keychain_proxy)
pool_public_key: Optional[G1Element] = None
if self.pool_public_key is not None:
if self.pool_contract_address is not None:
raise RuntimeError("Choose one of pool_contract_address and pool_public_key")
pool_public_key = G1Element.from_bytes(bytes.fromhex(self.pool_public_key))
else:
if self.pool_contract_address is None:
# If nothing is set, farms to the provided key (or the first key)
pool_public_key = await self.get_pool_public_key(keychain_proxy)
pool_public_key: Optional[G1Element] = None
if self.pool_public_key is not None:
if self.pool_contract_address is not None:
raise RuntimeError("Choose one of pool_contract_address and pool_public_key")
pool_public_key = G1Element.from_bytes(bytes.fromhex(self.pool_public_key))
else:
if self.pool_contract_address is None:
# If nothing is set, farms to the provided key (or the first key)
pool_public_key = await self.get_pool_public_key(keychain_proxy)
self.resolved_keys = PlotKeys(farmer_public_key, pool_public_key, self.pool_contract_address)
if keychain_proxy is not None:
await keychain_proxy.close()
self.resolved_keys = PlotKeys(farmer_public_key, pool_public_key, self.pool_contract_address)
finally:
if keychain_proxy is not None:
await keychain_proxy.close()
return self.resolved_keys
async def get_sk(self, keychain_proxy: Optional[KeychainProxy] = None) -> Optional[Tuple[PrivateKey, bytes]]:

View File

@ -73,7 +73,7 @@ class PlotManager:
def __exit__(self, exc_type, exc_value, exc_traceback):
self._lock.release()
def reset(self):
def reset(self) -> None:
with self:
self.last_refresh_time = time.time()
self.plots.clear()
@ -89,11 +89,11 @@ class PlotManager:
self.farmer_public_keys = farmer_public_keys
self.pool_public_keys = pool_public_keys
def initial_refresh(self):
def initial_refresh(self) -> bool:
return self._initial
def public_keys_available(self):
return len(self.farmer_public_keys) and len(self.pool_public_keys)
def public_keys_available(self) -> bool:
return len(self.farmer_public_keys) > 0 and len(self.pool_public_keys) > 0
def plot_count(self) -> int:
with self:

View File

@ -62,7 +62,7 @@ def load_pool_config(root_path: Path) -> List[PoolWalletConfig]:
# TODO: remove this a few versions after 1.3, since authentication_public_key is deprecated. This is here to support
# downgrading to versions older than 1.3.
def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1Element):
def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1Element) -> None:
with lock_and_load_config(root_path, "config.yaml") as config:
pool_list = config["pool"].get("pool_list", [])
updated = False
@ -82,7 +82,7 @@ def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1El
save_config(root_path, "config.yaml", config)
async def update_pool_config(root_path: Path, pool_config_list: List[PoolWalletConfig]):
async def update_pool_config(root_path: Path, pool_config_list: List[PoolWalletConfig]) -> None:
with lock_and_load_config(root_path, "config.yaml") as full_config:
full_config["pool"]["pool_list"] = [c.to_json_dict() for c in pool_config_list]
save_config(root_path, "config.yaml", full_config)

View File

@ -892,9 +892,9 @@ class PoolWallet:
if self.target_state is None:
return
if self.target_state == pool_wallet_info.current.state:
if self.target_state == pool_wallet_info.current:
self.target_state = None
raise ValueError("Internal error")
raise ValueError(f"Internal error. Pool wallet {self.wallet_id} state: {pool_wallet_info.current}")
if (
self.target_state.state in [FARMING_TO_POOL, SELF_POOLING]

View File

@ -173,5 +173,5 @@ def get_current_authentication_token(timeout: uint8) -> uint64:
# Validate a given authentication token against our local time
def validate_authentication_token(token: uint64, timeout: uint8):
def validate_authentication_token(token: uint64, timeout: uint8) -> bool:
return abs(token - get_current_authentication_token(timeout)) <= timeout

View File

@ -30,7 +30,7 @@ class CrawlerRpcApi:
return payloads
async def get_peer_counts(self, _request: Dict) -> EndpointResult:
async def get_peer_counts(self, _request: Dict[str, Any]) -> EndpointResult:
ipv6_addresses_count = 0
for host in self.service.best_timestamp_per_peer.keys():
try:
@ -54,7 +54,7 @@ class CrawlerRpcApi:
}
return data
async def get_ips_after_timestamp(self, _request: Dict) -> EndpointResult:
async def get_ips_after_timestamp(self, _request: Dict[str, Any]) -> EndpointResult:
after = _request.get("after", None)
if after is None:
raise ValueError("`after` is required and must be a unix timestamp")

View File

@ -11,71 +11,62 @@ from chia.util.ints import uint64
class DataLayerRpcClient(RpcClient):
async def create_data_store(self, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("create_data_store", {"fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_value(self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex(), "key": key.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_value", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def update_data_store(
self, store_id: bytes32, changelist: List[Dict[str, str]], fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("batch_update", {"id": store_id.hex(), "changelist": changelist, "fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_keys_values(self, store_id: bytes32, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_keys_values", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_keys(self, store_id: bytes32, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_keys", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_ancestors(self, store_id: bytes32, hash: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_ancestors", {"id": store_id.hex(), "hash": hash})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_local_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_local_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_roots(self, store_ids: List[bytes32]) -> Dict[str, Any]:
response = await self.fetch("get_roots", {"ids": store_ids})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def subscribe(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("subscribe", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
return response
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("remove_subscriptions", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
return response
async def unsubscribe(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("unsubscribe", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def add_missing_files(
self, store_ids: Optional[List[bytes32]], overwrite: Optional[bool], foldername: Optional[Path]
@ -88,40 +79,40 @@ class DataLayerRpcClient(RpcClient):
if foldername is not None:
request["foldername"] = str(foldername)
response = await self.fetch("add_missing_files", request)
return response # type: ignore[no-any-return]
return response
async def get_kv_diff(self, store_id: bytes32, hash_1: bytes32, hash_2: bytes32) -> Dict[str, Any]:
response = await self.fetch(
"get_kv_diff", {"id": store_id.hex(), "hash_1": hash_1.hex(), "hash_2": hash_2.hex()}
)
return response # type: ignore[no-any-return]
return response
async def get_root_history(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root_history", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def add_mirror(
self, store_id: bytes32, urls: List[str], amount: int, fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("add_mirror", {"id": store_id.hex(), "urls": urls, "amount": amount, "fee": fee})
return response # type: ignore[no-any-return]
return response
async def delete_mirror(self, coin_id: bytes32, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("delete_mirror", {"coin_id": coin_id.hex(), "fee": fee})
return response # type: ignore[no-any-return]
return response
async def get_mirrors(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_mirrors", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def get_subscriptions(self) -> Dict[str, Any]:
response = await self.fetch("subscriptions", {})
return response # type: ignore[no-any-return]
return response
async def get_owned_stores(self) -> Dict[str, Any]:
response = await self.fetch("get_owned_stores", {})
return response # type: ignore[no-any-return]
return response
async def get_sync_status(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_sync_status", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response

View File

@ -216,7 +216,7 @@ class FarmerRpcApi:
return payloads
async def get_signage_point(self, request: Dict) -> EndpointResult:
async def get_signage_point(self, request: Dict[str, Any]) -> EndpointResult:
sp_hash = hexstr_to_bytes(request["sp_hash"])
for _, sps in self.service.sps.items():
for sp in sps:
@ -235,7 +235,7 @@ class FarmerRpcApi:
}
raise ValueError(f"Signage point {sp_hash.hex()} not found")
async def get_signage_points(self, _: Dict) -> EndpointResult:
async def get_signage_points(self, _: Dict[str, Any]) -> EndpointResult:
result: List[Dict[str, Any]] = []
for sps in self.service.sps.values():
for sp in sps:
@ -255,12 +255,12 @@ class FarmerRpcApi:
)
return {"signage_points": result}
async def get_reward_targets(self, request: Dict) -> EndpointResult:
async def get_reward_targets(self, request: Dict[str, Any]) -> EndpointResult:
search_for_private_key = request["search_for_private_key"]
max_ph_to_search = request.get("max_ph_to_search", 500)
return await self.service.get_reward_targets(search_for_private_key, max_ph_to_search)
async def set_reward_targets(self, request: Dict) -> EndpointResult:
async def set_reward_targets(self, request: Dict[str, Any]) -> EndpointResult:
farmer_target, pool_target = None, None
if "farmer_target" in request:
farmer_target = request["farmer_target"]
@ -278,7 +278,7 @@ class FarmerRpcApi:
)
return plot_count
async def get_pool_state(self, _: Dict) -> EndpointResult:
async def get_pool_state(self, request: Dict[str, Any]) -> EndpointResult:
pools_list = []
for p2_singleton_puzzle_hash, pool_dict in self.service.pool_state.items():
pool_state = pool_dict.copy()
@ -287,12 +287,12 @@ class FarmerRpcApi:
pools_list.append(pool_state)
return {"pool_state": pools_list}
async def set_payout_instructions(self, request: Dict) -> EndpointResult:
async def set_payout_instructions(self, request: Dict[str, Any]) -> EndpointResult:
launcher_id: bytes32 = bytes32.from_hexstr(request["launcher_id"])
await self.service.set_payout_instructions(launcher_id, request["payout_instructions"])
return {}
async def get_harvesters(self, _: Dict) -> EndpointResult:
async def get_harvesters(self, request: Dict[str, Any]) -> EndpointResult:
return await self.service.get_harvesters(False)
async def get_harvesters_summary(self, _: Dict[str, object]) -> EndpointResult:
@ -335,7 +335,7 @@ class FarmerRpcApi:
async def get_harvester_plots_duplicates(self, request_dict: Dict[str, object]) -> EndpointResult:
return self.paginated_plot_path_request(Receiver.duplicates, request_dict)
async def get_pool_login_link(self, request: Dict) -> EndpointResult:
async def get_pool_login_link(self, request: Dict[str, Any]) -> EndpointResult:
launcher_id: bytes32 = bytes32(hexstr_to_bytes(request["launcher_id"]))
login_link: Optional[str] = await self.service.generate_login_link(launcher_id)
if login_link is None:

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, cast
from chia.rpc.farmer_rpc_api import PlotInfoRequestData, PlotPathRequestData
from chia.rpc.rpc_client import RpcClient
@ -17,16 +17,16 @@ class FarmerRpcClient(RpcClient):
to the full node.
"""
async def get_signage_point(self, sp_hash: bytes32) -> Optional[Dict]:
async def get_signage_point(self, sp_hash: bytes32) -> Optional[Dict[str, Any]]:
try:
return await self.fetch("get_signage_point", {"sp_hash": sp_hash.hex()})
except ValueError:
return None
async def get_signage_points(self) -> List[Dict]:
return (await self.fetch("get_signage_points", {}))["signage_points"]
async def get_signage_points(self) -> List[Dict[str, Any]]:
return cast(List[Dict[str, Any]], (await self.fetch("get_signage_points", {}))["signage_points"])
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict:
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict[str, Any]:
response = await self.fetch(
"get_reward_targets",
{"search_for_private_key": search_for_private_key, "max_ph_to_search": max_ph_to_search},
@ -41,7 +41,11 @@ class FarmerRpcClient(RpcClient):
return_dict["have_farmer_sk"] = response["have_farmer_sk"]
return return_dict
async def set_reward_targets(self, farmer_target: Optional[str] = None, pool_target: Optional[str] = None) -> Dict:
async def set_reward_targets(
self,
farmer_target: Optional[str] = None,
pool_target: Optional[str] = None,
) -> Dict[str, Any]:
request = {}
if farmer_target is not None:
request["farmer_target"] = farmer_target
@ -49,10 +53,10 @@ class FarmerRpcClient(RpcClient):
request["pool_target"] = pool_target
return await self.fetch("set_reward_targets", request)
async def get_pool_state(self) -> Dict:
async def get_pool_state(self) -> Dict[str, Any]:
return await self.fetch("get_pool_state", {})
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str) -> Dict:
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str) -> Dict[str, Any]:
request = {"launcher_id": launcher_id.hex(), "payout_instructions": payout_instructions}
return await self.fetch("set_payout_instructions", request)
@ -76,6 +80,7 @@ class FarmerRpcClient(RpcClient):
async def get_pool_login_link(self, launcher_id: bytes32) -> Optional[str]:
try:
return (await self.fetch("get_pool_login_link", {"launcher_id": launcher_id.hex()}))["login_link"]
result = await self.fetch("get_pool_login_link", {"launcher_id": launcher_id.hex()})
return cast(Optional[str], result["login_link"])
except ValueError:
return None

View File

@ -26,16 +26,16 @@ from chia.util.math import make_monotonically_decreasing
from chia.util.ws_message import WsRpcMessage, create_payload_dict
def coin_record_dict_backwards_compat(coin_record: Dict[str, Any]):
def coin_record_dict_backwards_compat(coin_record: Dict[str, Any]) -> Dict[str, bool]:
coin_record["spent"] = coin_record["spent_block_index"] > 0
return coin_record
class FullNodeRpcApi:
def __init__(self, service: FullNode):
def __init__(self, service: FullNode) -> None:
self.service = service
self.service_name = "chia_full_node"
self.cached_blockchain_state: Optional[Dict] = None
self.cached_blockchain_state: Optional[Dict[str, Any]] = None
def get_routes(self) -> Dict[str, Endpoint]:
return {
@ -73,7 +73,7 @@ class FullNodeRpcApi:
"/get_fee_estimate": self.get_fee_estimate,
}
async def _state_changed(self, change: str, change_data: Dict[str, Any] = None) -> List[WsRpcMessage]:
async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> List[WsRpcMessage]:
if change_data is None:
change_data = {}
@ -105,17 +105,17 @@ class FullNodeRpcApi:
# this function is just here for backwards-compatibility. It will probably
# be removed in the future
async def get_initial_freeze_period(self, _: Dict) -> EndpointResult:
async def get_initial_freeze_period(self, _: Dict[str, Any]) -> EndpointResult:
# Mon May 03 2021 17:00:00 GMT+0000
return {"INITIAL_FREEZE_END_TIMESTAMP": 1620061200}
async def get_blockchain_state(self, _request: Dict) -> EndpointResult:
async def get_blockchain_state(self, _: Dict[str, Any]) -> EndpointResult:
"""
Returns a summary of the node's view of the blockchain.
"""
node_id = self.service.server.node_id.hex()
if self.service.initialized is False:
res: Dict = {
res = {
"blockchain_state": {
"peak": None,
"genesis_challenge_initialized": self.service.initialized,
@ -179,9 +179,9 @@ class FullNodeRpcApi:
space = {"space": uint128(0)}
if self.service.mempool_manager is not None:
mempool_size = len(self.service.mempool_manager.mempool.spends)
mempool_cost = self.service.mempool_manager.mempool.total_mempool_cost
mempool_fees = self.service.mempool_manager.mempool.total_mempool_fees
mempool_size = self.service.mempool_manager.mempool.size()
mempool_cost = self.service.mempool_manager.mempool.total_mempool_cost()
mempool_fees = self.service.mempool_manager.mempool.total_mempool_fees()
mempool_min_fee_5m = self.service.mempool_manager.mempool.get_min_fee_rate(5000000)
mempool_max_total_cost = self.service.mempool_manager.mempool_max_total_cost
else:
@ -199,7 +199,7 @@ class FullNodeRpcApi:
synced = await self.service.synced() and is_connected
assert space is not None
response: Dict = {
response = {
"blockchain_state": {
"peak": peak,
"genesis_challenge_initialized": self.service.initialized,
@ -228,12 +228,12 @@ class FullNodeRpcApi:
self.cached_blockchain_state = dict(response["blockchain_state"])
return response
async def get_network_info(self, request: Dict) -> EndpointResult:
async def get_network_info(self, _: Dict[str, Any]) -> EndpointResult:
network_name = self.service.config["selected_network"]
address_prefix = self.service.config["network_overrides"]["config"][network_name]["address_prefix"]
return {"network_name": network_name, "network_prefix": address_prefix}
async def get_recent_signage_point_or_eos(self, request: Dict) -> EndpointResult:
async def get_recent_signage_point_or_eos(self, request: Dict[str, Any]) -> EndpointResult:
if "sp_hash" not in request:
challenge_hash: bytes32 = bytes32.from_hexstr(request["challenge_hash"])
# This is the case of getting an end of slot
@ -324,7 +324,7 @@ class FullNodeRpcApi:
return {"signage_point": sp, "time_received": time_received, "reverted": True}
async def get_block(self, request: Dict) -> EndpointResult:
async def get_block(self, request: Dict[str, Any]) -> EndpointResult:
if "header_hash" not in request:
raise ValueError("No header_hash in request")
header_hash = bytes32.from_hexstr(request["header_hash"])
@ -335,7 +335,7 @@ class FullNodeRpcApi:
return {"block": block}
async def get_blocks(self, request: Dict) -> EndpointResult:
async def get_blocks(self, request: Dict[str, Any]) -> EndpointResult:
if "start" not in request:
raise ValueError("No start in request")
if "end" not in request:
@ -365,7 +365,7 @@ class FullNodeRpcApi:
json_blocks.append(json)
return {"blocks": json_blocks}
async def get_block_count_metrics(self, request: Dict) -> EndpointResult:
async def get_block_count_metrics(self, _: Dict[str, Any]) -> EndpointResult:
compact_blocks = 0
uncompact_blocks = 0
with log_exceptions(self.service.log, consume=True):
@ -385,7 +385,7 @@ class FullNodeRpcApi:
}
}
async def get_block_records(self, request: Dict) -> EndpointResult:
async def get_block_records(self, request: Dict[str, Any]) -> EndpointResult:
if "start" not in request:
raise ValueError("No start in request")
if "end" not in request:
@ -415,7 +415,7 @@ class FullNodeRpcApi:
records.append(record)
return {"block_records": records}
async def get_block_spends(self, request: Dict) -> EndpointResult:
async def get_block_spends(self, request: Dict[str, Any]) -> EndpointResult:
if "header_hash" not in request:
raise ValueError("No header_hash in request")
header_hash = bytes32.from_hexstr(request["header_hash"])
@ -432,7 +432,7 @@ class FullNodeRpcApi:
return {"block_spends": spends}
async def get_block_record_by_height(self, request: Dict) -> EndpointResult:
async def get_block_record_by_height(self, request: Dict[str, Any]) -> EndpointResult:
if "height" not in request:
raise ValueError("No height in request")
height = request["height"]
@ -451,7 +451,7 @@ class FullNodeRpcApi:
raise ValueError(f"Block {header_hash} does not exist")
return {"block_record": record}
async def get_block_record(self, request: Dict) -> EndpointResult:
async def get_block_record(self, request: Dict[str, Any]) -> EndpointResult:
if "header_hash" not in request:
raise ValueError("header_hash not in request")
header_hash_str = request["header_hash"]
@ -465,8 +465,7 @@ class FullNodeRpcApi:
return {"block_record": record}
async def get_unfinished_block_headers(self, request: Dict) -> EndpointResult:
async def get_unfinished_block_headers(self, _request: Dict[str, Any]) -> EndpointResult:
peak: Optional[BlockRecord] = self.service.blockchain.get_peak()
if peak is None:
return {"headers": []}
@ -486,7 +485,7 @@ class FullNodeRpcApi:
response_headers.append(unfinished_header_block)
return {"headers": response_headers}
async def get_network_space(self, request: Dict) -> EndpointResult:
async def get_network_space(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves an estimate of total space validating the chain
between two block header hashes.
@ -526,7 +525,7 @@ class FullNodeRpcApi:
)
return {"space": uint128(int(network_space_bytes_estimate))}
async def get_coin_records_by_puzzle_hash(self, request: Dict) -> EndpointResult:
async def get_coin_records_by_puzzle_hash(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves the coins for a given puzzlehash, by default returns unspent coins.
"""
@ -545,7 +544,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_puzzle_hashes(self, request: Dict) -> EndpointResult:
async def get_coin_records_by_puzzle_hashes(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves the coins for a given puzzlehash, by default returns unspent coins.
"""
@ -567,7 +566,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_record_by_name(self, request: Dict) -> EndpointResult:
async def get_coin_record_by_name(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves a coin record by it's name.
"""
@ -581,7 +580,7 @@ class FullNodeRpcApi:
return {"coin_record": coin_record_dict_backwards_compat(coin_record.to_json_dict())}
async def get_coin_records_by_names(self, request: Dict) -> EndpointResult:
async def get_coin_records_by_names(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves the coins for given coin IDs, by default returns unspent coins.
"""
@ -603,7 +602,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_parent_ids(self, request: Dict) -> EndpointResult:
async def get_coin_records_by_parent_ids(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves the coins for given parent coin IDs, by default returns unspent coins.
"""
@ -625,7 +624,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_hint(self, request: Dict) -> EndpointResult:
async def get_coin_records_by_hint(self, request: Dict[str, Any]) -> EndpointResult:
"""
Retrieves coins by hint, by default returns unspent coins.
"""
@ -654,7 +653,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def push_tx(self, request: Dict) -> EndpointResult:
async def push_tx(self, request: Dict[str, Any]) -> EndpointResult:
if "spend_bundle" not in request:
raise ValueError("Spend bundle not in request")
@ -665,7 +664,7 @@ class FullNodeRpcApi:
status = MempoolInclusionStatus.SUCCESS
error = None
else:
status, error = await self.service.respond_transaction(spend_bundle, spend_name)
status, error = await self.service.add_transaction(spend_bundle, spend_name)
if status != MempoolInclusionStatus.SUCCESS:
if self.service.mempool_manager.get_spendbundle(spend_name) is not None:
# Already in mempool
@ -679,7 +678,7 @@ class FullNodeRpcApi:
"status": status.name,
}
async def get_puzzle_and_solution(self, request: Dict) -> EndpointResult:
async def get_puzzle_and_solution(self, request: Dict[str, Any]) -> EndpointResult:
coin_name: bytes32 = bytes32.from_hexstr(request["coin_id"])
height = request["height"]
coin_record = await self.service.coin_store.get_coin_record(coin_name)
@ -704,7 +703,7 @@ class FullNodeRpcApi:
return {"coin_solution": CoinSpend(coin_record.coin, puzzle, solution)}
async def get_additions_and_removals(self, request: Dict) -> EndpointResult:
async def get_additions_and_removals(self, request: Dict[str, Any]) -> EndpointResult:
if "header_hash" not in request:
raise ValueError("No header_hash in request")
header_hash = bytes32.from_hexstr(request["header_hash"])
@ -724,17 +723,17 @@ class FullNodeRpcApi:
"removals": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in removals],
}
async def get_all_mempool_tx_ids(self, request: Dict) -> EndpointResult:
ids = list(self.service.mempool_manager.mempool.spends.keys())
async def get_all_mempool_tx_ids(self, _: Dict[str, Any]) -> EndpointResult:
ids = list(self.service.mempool_manager.mempool.all_spend_ids())
return {"tx_ids": ids}
async def get_all_mempool_items(self, request: Dict) -> EndpointResult:
async def get_all_mempool_items(self, _: Dict[str, Any]) -> EndpointResult:
spends = {}
for tx_id, item in self.service.mempool_manager.mempool.spends.items():
spends[tx_id.hex()] = item
for item in self.service.mempool_manager.mempool.all_spends():
spends[item.name.hex()] = item.to_json_dict()
return {"mempool_items": spends}
async def get_mempool_item_by_tx_id(self, request: Dict) -> EndpointResult:
async def get_mempool_item_by_tx_id(self, request: Dict[str, Any]) -> EndpointResult:
if "tx_id" not in request:
raise ValueError("No tx_id in request")
include_pending: bool = request.get("include_pending", False)
@ -744,7 +743,7 @@ class FullNodeRpcApi:
if item is None:
raise ValueError(f"Tx id 0x{tx_id.hex()} not in the mempool")
return {"mempool_item": item}
return {"mempool_item": item.to_json_dict()}
def _get_spendbundle_type_cost(self, name: str) -> uint64:
"""
@ -765,7 +764,7 @@ class FullNodeRpcApi:
}
return uint64(tx_cost_estimates[name])
async def _validate_fee_estimate_cost(self, request: Dict) -> uint64:
async def _validate_fee_estimate_cost(self, request: Dict[str, Any]) -> uint64:
c = 0
ns = ["spend_bundle", "cost", "spend_type"]
for n in ns:
@ -790,13 +789,13 @@ class FullNodeRpcApi:
cost *= request.get("spend_count", 1)
return uint64(cost)
def _validate_target_times(self, request: Dict) -> None:
def _validate_target_times(self, request: Dict[str, Any]) -> None:
if "target_times" not in request:
raise ValueError("Request must contain 'target_times' array")
if any(t < 0 for t in request["target_times"]):
raise ValueError("'target_times' array members must be non-negative")
async def get_fee_estimate(self, request: Dict) -> Dict[str, Any]:
async def get_fee_estimate(self, request: Dict[str, Any]) -> Dict[str, Any]:
self._validate_target_times(request)
spend_cost = await self._validate_fee_estimate_cost(request)
@ -812,9 +811,9 @@ class FullNodeRpcApi:
# such as estimating a higher fee for a longer transaction time.
estimates = make_monotonically_decreasing(estimates)
current_fee_rate = estimator.estimate_fee_rate(time_offset_seconds=1)
mempool_size = self.service.mempool_manager.mempool.total_mempool_cost
mempool_fees = self.service.mempool_manager.mempool.total_mempool_fees
num_mempool_spends = len(self.service.mempool_manager.mempool.spends)
mempool_size = self.service.mempool_manager.mempool.total_mempool_cost()
mempool_fees = self.service.mempool_manager.mempool.total_mempool_fees()
num_mempool_spends = self.service.mempool_manager.mempool.size()
mempool_max_size = estimator.mempool_max_size()
blockchain_state = await self.get_blockchain_state({})
synced = blockchain_state["blockchain_state"]["sync"]["synced"]

View File

@ -21,7 +21,7 @@ class HarvesterRpcClient(RpcClient):
await self.fetch("refresh_plots", {})
async def delete_plot(self, filename: str) -> bool:
return await self.fetch("delete_plot", {"filename": filename})
return (await self.fetch("delete_plot", {"filename": filename}))["success"]
async def add_plot_directory(self, dirname: str) -> bool:
return (await self.fetch("add_plot_directory", {"dirname": dirname}))["success"]

View File

@ -53,7 +53,7 @@ class RpcClient:
self.closing_task = None
return self
async def fetch(self, path, request_json) -> Any:
async def fetch(self, path, request_json) -> Dict[str, Any]:
async with self.session.post(self.url + path, json=request_json, ssl_context=self.ssl_context) as response:
response.raise_for_status()
res_json = await response.json()

View File

@ -207,7 +207,6 @@ class RpcServer:
if change == "add_connection" or change == "close_connection" or change == "peer_changed_peak":
data = await self.get_connections({})
if data is not None:
payload = create_payload_dict(
"get_connections",
data,

View File

@ -46,7 +46,12 @@ from chia.wallet.derive_keys import (
from chia.wallet.did_wallet import did_wallet_puzzles
from chia.wallet.did_wallet.did_info import DIDInfo
from chia.wallet.did_wallet.did_wallet import DIDWallet
from chia.wallet.did_wallet.did_wallet_puzzles import DID_INNERPUZ_MOD, match_did_puzzle, program_to_metadata
from chia.wallet.did_wallet.did_wallet_puzzles import (
DID_INNERPUZ_MOD,
match_did_puzzle,
metadata_to_program,
program_to_metadata,
)
from chia.wallet.nft_wallet import nft_puzzles
from chia.wallet.nft_wallet.nft_info import NFTCoinInfo, NFTInfo
from chia.wallet.nft_wallet.nft_puzzles import get_metadata_and_phs
@ -1334,8 +1339,11 @@ class WalletRpcApi:
:return:
"""
puzzle_hash: bytes32 = decode_puzzle_hash(request["address"])
is_hex = request.get("is_hex", False)
if isinstance(is_hex, str):
is_hex = bool(is_hex)
pubkey, signature = await self.service.wallet_state_manager.main_wallet.sign_message(
request["message"], puzzle_hash
request["message"], puzzle_hash, is_hex
)
return {
"success": True,
@ -1353,6 +1361,9 @@ class WalletRpcApi:
entity_id: bytes32 = decode_puzzle_hash(request["id"])
selected_wallet: Optional[WalletProtocol] = None
is_hex = request.get("is_hex", False)
if isinstance(is_hex, str):
is_hex = bool(is_hex)
if is_valid_address(request["id"], {AddressType.DID}, self.service.config):
for wallet in self.service.wallet_state_manager.wallets.values():
if wallet.type() == WalletType.DECENTRALIZED_ID.value:
@ -1364,7 +1375,7 @@ class WalletRpcApi:
if selected_wallet is None:
return {"success": False, "error": f"DID for {entity_id.hex()} doesn't exist."}
assert isinstance(selected_wallet, DIDWallet)
pubkey, signature = await selected_wallet.sign_message(request["message"])
pubkey, signature = await selected_wallet.sign_message(request["message"], is_hex)
latest_coin: Set[Coin] = await selected_wallet.select_coins(uint64(1))
latest_coin_id = None
if len(latest_coin) > 0:
@ -1383,7 +1394,7 @@ class WalletRpcApi:
return {"success": False, "error": f"NFT for {entity_id.hex()} doesn't exist."}
assert isinstance(selected_wallet, NFTWallet)
pubkey, signature = await selected_wallet.sign_message(request["message"], target_nft)
pubkey, signature = await selected_wallet.sign_message(request["message"], target_nft, is_hex)
latest_coin_id = target_nft.coin.name()
else:
return {"success": False, "error": f'Unknown ID type, {request["id"]}'}
@ -1853,6 +1864,7 @@ class WalletRpcApi:
"metadata": program_to_metadata(metadata),
"launcher_id": singleton_struct.rest().first().as_python().hex(),
"full_puzzle": full_puzzle,
"solution": Program.from_bytes(bytes(coin_spend.solution)).as_python(),
"hints": hints,
}
@ -1883,9 +1895,6 @@ class WalletRpcApi:
p2_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata = curried_args
hint_list = compute_coin_hints(coin_spend)
old_inner_puzhash = DID_INNERPUZ_MOD.curry(
p2_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata
).get_tree_hash()
derivation_record = None
# Hint is required, if it doesn't have any hint then it should be invalid
is_invalid = len(hint_list) == 0
@ -1896,11 +1905,9 @@ class WalletRpcApi:
)
)
if derivation_record is not None:
is_invalid = False
break
# Check if the mismatch is because of the memo bug
if hint == old_inner_puzhash:
is_invalid = True
break
is_invalid = True
if is_invalid:
# This is an invalid DID, check if we are owner
derivation_record = (
@ -1922,15 +1929,6 @@ class WalletRpcApi:
did_puzzle_empty_recovery = DID_INNERPUZ_MOD.curry(
our_inner_puzzle, Program.to([]).get_tree_hash(), uint64(0), singleton_struct, metadata
)
full_puzzle_empty_recovery = create_fullpuz(did_puzzle_empty_recovery, launcher_id)
if full_puzzle.get_tree_hash() != coin_state.coin.puzzle_hash:
if full_puzzle_empty_recovery.get_tree_hash() == coin_state.coin.puzzle_hash:
did_puzzle = did_puzzle_empty_recovery
else:
return {
"success": False,
"error": f"Cannot recover DID {launcher_id.hex()} because the last spend is metadata update.",
}
# Check if we have the DID wallet
did_wallet: Optional[DIDWallet] = None
for wallet in self.service.wallet_state_manager.wallets.values():
@ -1940,6 +1938,62 @@ class WalletRpcApi:
did_wallet = wallet
break
full_puzzle_empty_recovery = create_fullpuz(did_puzzle_empty_recovery, launcher_id)
if full_puzzle.get_tree_hash() != coin_state.coin.puzzle_hash:
if full_puzzle_empty_recovery.get_tree_hash() == coin_state.coin.puzzle_hash:
did_puzzle = did_puzzle_empty_recovery
elif (
did_wallet is not None
and did_wallet.did_info.current_inner is not None
and create_fullpuz(did_wallet.did_info.current_inner, launcher_id).get_tree_hash()
== coin_state.coin.puzzle_hash
):
# Check if the old wallet has the inner puzzle
did_puzzle = did_wallet.did_info.current_inner
else:
# Try override
if "recovery_list_hash" in request:
recovery_list_hash = Program.from_bytes(bytes.fromhex(request["recovery_list_hash"]))
num_verification = request.get("num_verification", num_verification)
if "metadata" in request:
metadata = metadata_to_program(request["metadata"])
did_puzzle = DID_INNERPUZ_MOD.curry(
our_inner_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata
)
full_puzzle = create_fullpuz(did_puzzle, launcher_id)
matched = True
if full_puzzle.get_tree_hash() != coin_state.coin.puzzle_hash:
matched = False
# Brute force addresses
index = 0
derivation_record = await self.service.wallet_state_manager.puzzle_store.get_derivation_record(
uint32(index), uint32(1), False
)
while derivation_record is not None:
our_inner_puzzle = self.service.wallet_state_manager.main_wallet.puzzle_for_pk(
derivation_record.pubkey
)
did_puzzle = DID_INNERPUZ_MOD.curry(
our_inner_puzzle, recovery_list_hash, num_verification, singleton_struct, metadata
)
full_puzzle = create_fullpuz(did_puzzle, launcher_id)
if full_puzzle.get_tree_hash() == coin_state.coin.puzzle_hash:
matched = True
break
index += 1
derivation_record = (
await self.service.wallet_state_manager.puzzle_store.get_derivation_record(
uint32(index), uint32(1), False
)
)
if not matched:
return {
"success": False,
"error": f"Cannot recover DID {launcher_id.hex()}"
f" because the last spend updated recovery_list_hash/num_verification/metadata.",
}
if did_wallet is None:
# Create DID wallet
response: List[CoinState] = await self.service.get_coin_state([launcher_id], peer=peer)

View File

@ -45,7 +45,7 @@ class WalletRpcClient(RpcClient):
except ValueError as e:
return e.args[0]
async def set_wallet_resync_on_startup(self, enable: bool = True) -> None:
async def set_wallet_resync_on_startup(self, enable: bool = True) -> Dict[str, Any]:
return await self.fetch(
path="set_wallet_resync_on_startup",
request_json={"enable": enable},
@ -247,7 +247,6 @@ class WalletRpcClient(RpcClient):
exclude_amounts: Optional[List[uint64]] = None,
wallet_id: Optional[int] = None,
) -> List[TransactionRecord]:
# Converts bytes to hex for puzzle hashes
additions_hex = []
for ad in additions:
@ -308,7 +307,6 @@ class WalletRpcClient(RpcClient):
exclude_coins: Optional[List[Coin]] = None,
wallet_id: Optional[int] = None,
) -> TransactionRecord:
txs: List[TransactionRecord] = await self.create_signed_transactions(
additions=additions,
coins=coins,
@ -521,7 +519,6 @@ class WalletRpcClient(RpcClient):
p2_singleton_delay_time: Optional[uint64] = None,
p2_singleton_delayed_ph: Optional[bytes32] = None,
) -> TransactionRecord:
request: Dict[str, Any] = {
"wallet_type": "pool_wallet",
"mode": mode,

View File

@ -100,7 +100,6 @@ class Crawler:
async def connect_task(self, peer):
async def peer_action(peer: WSChiaConnection):
peer_info = peer.get_peer_info()
version = peer.get_version()
if peer_info is not None and version is not None:

View File

@ -300,7 +300,7 @@ class AddressManager:
def mark_good_(self, addr: PeerInfo, test_before_evict: bool, timestamp: int) -> None:
self.last_good = timestamp
(info, node_id) = self.find_(addr)
if not addr.is_valid(self.allow_private_subnets):
if addr.ip.is_private and not self.allow_private_subnets:
return None
if info is None:
return None
@ -365,7 +365,7 @@ class AddressManager:
addr.host,
addr.port,
)
if not peer_info.is_valid(self.allow_private_subnets):
if peer_info.ip.is_private and not self.allow_private_subnets:
return False
(info, node_id) = self.find_(peer_info)
if info is not None and info.peer_info == peer_info:
@ -554,7 +554,7 @@ class AddressManager:
rand_pos = randrange(len(self.random_pos) - n) + n
self.swap_random_(n, rand_pos)
info = self.map_info[self.random_pos[n]]
if not info.peer_info.is_valid(self.allow_private_subnets):
if info.peer_info.ip.is_private and not self.allow_private_subnets:
continue
if not info.is_terrible():
cur_peer_info = TimestampedPeerInfo(

View File

@ -25,6 +25,7 @@ from chia.server.ws_connection import WSChiaConnection
from chia.types.peer_info import PeerInfo, TimestampedPeerInfo
from chia.util.hash import std_hash
from chia.util.ints import uint16, uint64
from chia.util.network import IPAddress, get_host_addr
MAX_PEERS_RECEIVED_PER_REQUEST = 1000
MAX_TOTAL_PEERS_RECEIVED = 3000
@ -61,8 +62,10 @@ class FullNodeDiscovery:
self.dns_servers = dns_servers
random.shuffle(dns_servers) # Don't always start with the same DNS server
if introducer_info is not None:
# get_host_addr is blocking but this only gets called on startup or in the wallet after disconnecting from
# all trusted peers.
self.introducer_info: Optional[PeerInfo] = PeerInfo(
introducer_info["host"],
str(get_host_addr(introducer_info["host"], prefer_ipv6=False)),
introducer_info["port"],
)
else:
@ -390,9 +393,6 @@ class FullNodeDiscovery:
addr = info.peer_info
if has_collision:
break
if addr is not None and not addr.is_valid():
addr = None
continue
if not is_feeler and addr.get_group() in groups:
addr = None
continue
@ -600,7 +600,6 @@ class FullNodePeers(FullNodeDiscovery):
async def request_peers(self, peer_info: PeerInfo) -> Optional[Message]:
try:
# Prevent a fingerprint attack: do not send peers to inbound connections.
# This asymmetric behavior for inbound and outbound connections was introduced
# to prevent a fingerprinting attack: an attacker can send specific fake addresses
@ -647,8 +646,9 @@ class FullNodePeers(FullNodeDiscovery):
relay_peer, num_peers = await self.relay_queue.get()
except asyncio.CancelledError:
return None
relay_peer_info = PeerInfo(relay_peer.host, relay_peer.port)
if not relay_peer_info.is_valid():
try:
IPAddress.create(relay_peer.host)
except ValueError:
continue
# https://en.bitcoin.it/wiki/Satoshi_Client_Node_Discovery#Address_Relay
connections = self.server.get_connections(NodeType.FULL_NODE)

View File

@ -72,7 +72,6 @@ class RateLimiter:
rate_limits = get_rate_limits_to_use(our_capabilities, peer_capabilities)
try:
limits: RLSettings = rate_limits["default_settings"]
if message_type in rate_limits["rate_limits_tx"]:
limits = rate_limits["rate_limits_tx"][message_type]

View File

@ -6,7 +6,7 @@ import ssl
import time
import traceback
from dataclasses import dataclass, field
from ipaddress import IPv4Network, IPv6Address, IPv6Network, ip_address, ip_network
from ipaddress import IPv4Network, IPv6Network, ip_network
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union, cast
@ -160,7 +160,6 @@ class ChiaServer:
chia_ca_crt_key: Tuple[Path, Path],
name: str = __name__,
) -> ChiaServer:
log = logging.getLogger(name)
log.info("Service capabilities: %s", capabilities)
@ -419,14 +418,8 @@ class ChiaServer:
timeout_value = float(self.config.get("peer_connect_timeout", 30))
timeout = ClientTimeout(total=timeout_value)
session = ClientSession(timeout=timeout)
try:
if type(ip_address(target_node.host)) is IPv6Address:
target_node = PeerInfo(f"[{target_node.host}]", target_node.port)
except ValueError:
pass
url = f"wss://{target_node.host}:{target_node.port}/ws"
ip = f"[{target_node.ip}]" if target_node.ip.is_v6 else f"{target_node.ip}"
url = f"wss://{ip}:{target_node.port}/ws"
self.log.debug(f"Connecting: {url}, Peer info: {target_node}")
try:
ws = await session.ws_connect(
@ -656,10 +649,10 @@ class ChiaServer:
ip = None
if ip is None:
return None
peer = PeerInfo(ip, uint16(port))
if not peer.is_valid():
try:
return PeerInfo(ip, uint16(port))
except ValueError:
return None
return peer
def get_port(self) -> uint16:
return uint16(self._port)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import asyncio
import contextlib
import logging
import math
import time
import traceback
from dataclasses import dataclass, field
@ -28,7 +29,6 @@ from chia.util.api_decorators import get_metadata
from chia.util.errors import Err, ProtocolError
from chia.util.ints import uint8, uint16
from chia.util.log_exceptions import log_exceptions
from chia.util.logging import TimedDuplicateFilter
# Each message is prepended with LENGTH_BYTES bytes specifying the length
from chia.util.network import class_for_type, is_localhost
@ -41,6 +41,10 @@ WebSocket = Union[WebSocketResponse, ClientWebSocketResponse]
ConnectionCallback = Callable[["WSChiaConnection"], Awaitable[None]]
def create_default_last_message_time_dict() -> Dict[ProtocolMessageTypes, float]:
return {message_type: -math.inf for message_type in ProtocolMessageTypes}
class ConnectionClosedCallbackProtocol(Protocol):
def __call__(
self,
@ -110,6 +114,11 @@ class WSChiaConnection:
version: str = field(default_factory=str)
protocol_version: str = field(default_factory=str)
log_rate_limit_last_time: Dict[ProtocolMessageTypes, float] = field(
default_factory=create_default_last_message_time_dict,
repr=False,
)
@classmethod
def create(
cls,
@ -128,7 +137,6 @@ class WSChiaConnection:
local_capabilities_for_handshake: List[Tuple[uint16, str]],
session: Optional[ClientSession] = None,
) -> WSChiaConnection:
assert ws._writer is not None
peername = ws._writer.transport.get_extra_info("peername")
@ -163,8 +171,18 @@ class WSChiaConnection:
)
def _get_extra_info(self, name: str) -> Optional[Any]:
assert self.ws._writer is not None, "websocket's ._writer is None, was .prepare() called?"
return self.ws._writer.transport.get_extra_info(name)
writer = self.ws._writer
assert writer is not None, "websocket's ._writer is None, was .prepare() called?"
transport = writer.transport
if transport is None:
return None
try:
return transport.get_extra_info(name)
except AttributeError:
# "/usr/lib/python3.11/asyncio/sslproto.py", line 91, in get_extra_info
# return self._ssl_protocol._get_extra_info(name, default)
# AttributeError: 'NoneType' object has no attribute '_get_extra_info'
return None
async def perform_handshake(
self,
@ -340,7 +358,7 @@ class WSChiaConnection:
if self.received_message_callback is not None:
await self.received_message_callback(self)
self.log.debug(
f"<- {ProtocolMessageTypes(full_message.type).name} from peer " f"{self.peer_node_id} {self.peer_host}"
f"<- {ProtocolMessageTypes(full_message.type).name} from peer {self.peer_node_id} {self.peer_host}"
)
message_type = ProtocolMessageTypes(full_message.type).name
@ -550,10 +568,13 @@ class WSChiaConnection:
message, self.local_capabilities, self.peer_capabilities
):
if not is_localhost(self.peer_host):
msg = f"Rate limiting ourselves. message type: {ProtocolMessageTypes(message.type).name}, "
f"peer: {self.peer_host}"
self.log.debug(msg)
self.log.addFilter(TimedDuplicateFilter(msg, 60))
message_type = ProtocolMessageTypes(message.type)
last_time = self.log_rate_limit_last_time[message_type]
now = time.monotonic()
self.log_rate_limit_last_time[message_type] = now
if now - last_time >= 60:
msg = f"Rate limiting ourselves. message type: {message_type.name}, peer: {self.peer_host}"
self.log.debug(msg)
# TODO: fix this special case. This function has rate limits which are too low.
if ProtocolMessageTypes(message.type) != ProtocolMessageTypes.respond_peers:

View File

@ -21,7 +21,6 @@ from chia_rs import compute_merkle_set_root
from chiabip158 import PyBIP158
from clvm.casts import int_from_bytes
from chia.cmds.init_funcs import create_default_chia_config
from chia.consensus.block_creation import unfinished_block_to_full_block
from chia.consensus.block_record import BlockRecord
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
@ -101,7 +100,14 @@ from chia.types.spend_bundle import SpendBundle
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.bech32m import encode_puzzle_hash
from chia.util.block_cache import BlockCache
from chia.util.config import config_path_for_filename, load_config, lock_config, override_config, save_config
from chia.util.config import (
config_path_for_filename,
create_default_chia_config,
load_config,
lock_config,
override_config,
save_config,
)
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.errors import Err
from chia.util.hash import std_hash
@ -147,7 +153,6 @@ test_constants = DEFAULT_CONSTANTS.replace(
def compute_additions_unchecked(sb: SpendBundle) -> List[Coin]:
ret: List[Coin] = []
for cs in sb.coin_spends:
parent_id = cs.coin.name()
_, r = cs.puzzle_reveal.run_with_cost(INFINITE_COST, cs.solution)
for cond in Program.to(r).as_iter():
@ -182,7 +187,6 @@ class BlockTools:
plot_dir: str = "test-plots",
log: logging.Logger = logging.getLogger(__name__),
):
self._block_cache_header = bytes32([0] * 32)
self._tempdir = None
@ -281,52 +285,58 @@ class BlockTools:
)
async def setup_keys(self, fingerprint: Optional[int] = None, reward_ph: Optional[bytes32] = None):
if self.local_keychain:
keychain_proxy: Optional[KeychainProxy] = wrap_local_keychain(self.local_keychain, log=self.log)
elif not self.automated_testing and fingerprint is not None:
keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
else: # if we are automated testing or if we don't have a fingerprint.
keychain_proxy = await connect_to_keychain_and_validate(
self.root_path, self.log, user="testing-1.8.0", service="chia-testing-1.8.0"
)
assert keychain_proxy is not None
if fingerprint is None: # if we are not specifying an existing key
await keychain_proxy.delete_all_keys()
self.farmer_master_sk_entropy = std_hash(b"block_tools farmer key") # both entropies are only used here
self.pool_master_sk_entropy = std_hash(b"block_tools pool key")
self.farmer_master_sk = await keychain_proxy.add_private_key(
bytes_to_mnemonic(self.farmer_master_sk_entropy)
)
self.pool_master_sk = await keychain_proxy.add_private_key(bytes_to_mnemonic(self.pool_master_sk_entropy))
else:
self.farmer_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
self.pool_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
keychain_proxy: Optional[KeychainProxy]
try:
if self.local_keychain:
keychain_proxy = wrap_local_keychain(self.local_keychain, log=self.log)
elif not self.automated_testing and fingerprint is not None:
keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
else: # if we are automated testing or if we don't have a fingerprint.
keychain_proxy = await connect_to_keychain_and_validate(
self.root_path, self.log, user="testing-1.8.0", service="chia-testing-1.8.0"
)
assert keychain_proxy is not None
if fingerprint is None: # if we are not specifying an existing key
await keychain_proxy.delete_all_keys()
self.farmer_master_sk_entropy = std_hash(b"block_tools farmer key") # both entropies are only used here
self.pool_master_sk_entropy = std_hash(b"block_tools pool key")
self.farmer_master_sk = await keychain_proxy.add_private_key(
bytes_to_mnemonic(self.farmer_master_sk_entropy)
)
self.pool_master_sk = await keychain_proxy.add_private_key(
bytes_to_mnemonic(self.pool_master_sk_entropy),
)
else:
self.farmer_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
self.pool_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
self.farmer_pk = master_sk_to_farmer_sk(self.farmer_master_sk).get_g1()
self.pool_pk = master_sk_to_pool_sk(self.pool_master_sk).get_g1()
self.farmer_pk = master_sk_to_farmer_sk(self.farmer_master_sk).get_g1()
self.pool_pk = master_sk_to_pool_sk(self.pool_master_sk).get_g1()
if reward_ph is None:
self.farmer_ph: bytes32 = create_puzzlehash_for_pk(
master_sk_to_wallet_sk(self.farmer_master_sk, uint32(0)).get_g1()
)
self.pool_ph: bytes32 = create_puzzlehash_for_pk(
master_sk_to_wallet_sk(self.pool_master_sk, uint32(0)).get_g1()
)
else:
self.farmer_ph = reward_ph
self.pool_ph = reward_ph
if self.automated_testing:
self.all_sks: List[PrivateKey] = [sk for sk, _ in await keychain_proxy.get_all_private_keys()]
else:
self.all_sks = [self.farmer_master_sk] # we only want to include plots under the same fingerprint
self.pool_pubkeys: List[G1Element] = [master_sk_to_pool_sk(sk).get_g1() for sk in self.all_sks]
if reward_ph is None:
self.farmer_ph: bytes32 = create_puzzlehash_for_pk(
master_sk_to_wallet_sk(self.farmer_master_sk, uint32(0)).get_g1()
)
self.pool_ph: bytes32 = create_puzzlehash_for_pk(
master_sk_to_wallet_sk(self.pool_master_sk, uint32(0)).get_g1()
)
else:
self.farmer_ph = reward_ph
self.pool_ph = reward_ph
if self.automated_testing:
self.all_sks: List[PrivateKey] = [sk for sk, _ in await keychain_proxy.get_all_private_keys()]
else:
self.all_sks = [self.farmer_master_sk] # we only want to include plots under the same fingerprint
self.pool_pubkeys: List[G1Element] = [master_sk_to_pool_sk(sk).get_g1() for sk in self.all_sks]
self.farmer_pubkeys: List[G1Element] = [master_sk_to_farmer_sk(sk).get_g1() for sk in self.all_sks]
if len(self.pool_pubkeys) == 0 or len(self.farmer_pubkeys) == 0:
raise RuntimeError("Keys not generated. Run `chia keys generate`")
self.farmer_pubkeys: List[G1Element] = [master_sk_to_farmer_sk(sk).get_g1() for sk in self.all_sks]
if len(self.pool_pubkeys) == 0 or len(self.farmer_pubkeys) == 0:
raise RuntimeError("Keys not generated. Run `chia keys generate`")
self.plot_manager.set_public_keys(self.farmer_pubkeys, self.pool_pubkeys)
await keychain_proxy.close() # close the keychain proxy
self.plot_manager.set_public_keys(self.farmer_pubkeys, self.pool_pubkeys)
finally:
if keychain_proxy is not None:
await keychain_proxy.close() # close the keychain proxy
def change_config(self, new_config: Dict):
self._config = new_config
@ -770,7 +780,7 @@ class BlockTools:
blocks[full_block.header_hash] = block_record
self.log.info(
f"Created block {block_record.height} ove=False, iters " f"{block_record.total_iters}"
f"Created block {block_record.height} ove=False, iters {block_record.total_iters}"
)
height_to_hash[uint32(full_block.height)] = full_block.header_hash
latest_block = blocks[full_block.header_hash]
@ -1050,9 +1060,7 @@ class BlockTools:
previous_generator = compressor_arg
blocks_added_this_sub_slot += 1
self.log.info(
f"Created block {block_record.height } ov=True, iters " f"{block_record.total_iters}"
)
self.log.info(f"Created block {block_record.height } ov=True, iters {block_record.total_iters}")
num_blocks -= 1
blocks[full_block.header_hash] = block_record
@ -1281,7 +1289,6 @@ class BlockTools:
qualities = plot_info.prover.get_qualities_for_challenge(new_challenge)
for proof_index, quality_str in enumerate(qualities):
required_iters = calculate_iterations_quality(
constants.DIFFICULTY_CONSTANT_FACTOR,
quality_str,
@ -2085,7 +2092,7 @@ def create_test_unfinished_block(
additions = []
if removals is None:
removals = []
(foliage, foliage_transaction_block, transactions_info,) = create_test_foliage(
(foliage, foliage_transaction_block, transactions_info) = create_test_foliage(
constants,
rc_block,
block_generator,

View File

@ -12,7 +12,6 @@ from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate
from chia.consensus.multiprocess_validation import PreValidationResult
from chia.full_node.full_node import FullNode
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.full_node_protocol import RespondBlock
from chia.rpc.rpc_server import default_get_connections
from chia.server.outbound_message import NodeType
from chia.simulator.block_tools import BlockTools
@ -131,7 +130,7 @@ class FullNodeSimulator(FullNodeAPI):
config["simulator"]["auto_farm"] = self.auto_farm
save_config(self.bt.root_path, "config.yaml", config)
self.config = config
if self.auto_farm is True and self.full_node.mempool_manager.mempool.total_mempool_cost > 0:
if self.auto_farm is True and self.full_node.mempool_manager.mempool.total_mempool_cost() > 0:
# if mempool is not empty and auto farm was just enabled, farm a block
await self.farm_new_transaction_block(FarmNewBlockProtocol(self.bt.farmer_ph))
return self.auto_farm
@ -225,8 +224,7 @@ class FullNodeSimulator(FullNodeAPI):
current_time=current_time,
previous_generator=self.full_node.full_node_store.previous_generator,
)
rr = RespondBlock(more[-1])
await self.full_node.respond_block(rr)
await self.full_node.add_block(more[-1])
return more[-1]
async def farm_new_block(self, request: FarmNewBlockProtocol, force_wait_for_timestamp: bool = False):
@ -272,8 +270,7 @@ class FullNodeSimulator(FullNodeAPI):
current_time=current_time,
time_per_block=time_per_block,
)
rr: RespondBlock = RespondBlock(more[-1])
await self.full_node.respond_block(rr)
await self.full_node.add_block(more[-1])
async def reorg_from_index_to_new_index(self, request: ReorgProtocol):
new_index = request.new_index
@ -297,7 +294,7 @@ class FullNodeSimulator(FullNodeAPI):
)
for block in more_blocks:
await self.full_node.respond_block(RespondBlock(block))
await self.full_node.add_block(block)
async def farm_blocks_to_puzzlehash(
self,

View File

@ -134,45 +134,6 @@ async def setup_n_nodes(
keyring.cleanup()
async def setup_node_and_wallet(
consensus_constants: ConsensusConstants,
self_hostname: str,
key_seed: Optional[bytes32] = None,
db_version: int = 1,
disable_capabilities: Optional[List[Capability]] = None,
) -> AsyncGenerator[Tuple[FullNodeAPI, WalletNode, ChiaServer, ChiaServer, BlockTools], None]:
with TempKeyring(populate=True) as keychain:
btools = await create_block_tools_async(constants=test_constants, keychain=keychain)
full_node_iter = setup_full_node(
consensus_constants,
"blockchain_test.db",
self_hostname,
btools,
simulator=False,
db_version=db_version,
disable_capabilities=disable_capabilities,
)
wallet_node_iter = setup_wallet_node(
btools.config["self_hostname"],
consensus_constants,
btools,
None,
key_seed=key_seed,
)
full_node_service = await full_node_iter.__anext__()
full_node_api = full_node_service._api
wallet_node_service = await wallet_node_iter.__anext__()
wallet = wallet_node_service._node
s2 = wallet_node_service._node.server
yield full_node_api, wallet, full_node_api.full_node.server, s2, btools
await _teardown_nodes([full_node_iter])
await _teardown_nodes([wallet_node_iter])
async def setup_simulators_and_wallets(
simulator_count: int,
wallet_count: int,
@ -329,7 +290,6 @@ async def setup_farmer_multi_harvester(
*,
start_services: bool,
) -> AsyncIterator[Tuple[List[Service[Harvester]], Service[Farmer], BlockTools]]:
farmer_node_iterators = [
setup_farmer(
block_tools,

View File

@ -81,13 +81,9 @@ async def setup_daemon(btools: BlockTools) -> AsyncGenerator[WebSocketServer, No
ca_crt_path = root_path / config["private_ssl_ca"]["crt"]
ca_key_path = root_path / config["private_ssl_ca"]["key"]
with Lockfile.create(daemon_launch_lock_path(root_path)):
shutdown_event = asyncio.Event()
ws_server = WebSocketServer(root_path, ca_crt_path, ca_key_path, crt_path, key_path, shutdown_event)
await ws_server.start()
yield ws_server
await ws_server.stop()
ws_server = WebSocketServer(root_path, ca_crt_path, ca_key_path, crt_path, key_path)
async with ws_server.run():
yield ws_server
async def setup_full_node(
@ -105,6 +101,8 @@ async def setup_full_node(
) -> AsyncGenerator[Service[FullNode], None]:
db_path = local_bt.root_path / f"{db_name}"
if db_path.exists():
# TODO: remove (maybe) when fixed https://github.com/python/cpython/issues/97641
gc.collect()
db_path.unlink()
if db_version > 1:
@ -200,6 +198,8 @@ async def setup_wallet_node(
db_path = local_bt.root_path / db_path_replaced
if db_path.exists():
# TODO: remove (maybe) when fixed https://github.com/python/cpython/issues/97641
gc.collect()
db_path.unlink()
service_config["database_path"] = str(db_name)
service_config["testing"] = True
@ -233,6 +233,8 @@ async def setup_wallet_node(
service.stop()
await service.wait_closed()
if db_path.exists():
# TODO: remove (maybe) when fixed https://github.com/python/cpython/issues/97641
gc.collect()
db_path.unlink()
keychain.delete_all_keys()

View File

@ -1,6 +1,5 @@
from __future__ import annotations
import asyncio
import sys
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, Optional, Tuple
@ -136,7 +135,6 @@ async def get_full_chia_simulator(
keychain = Keychain()
with Lockfile.create(daemon_launch_lock_path(chia_root)):
mnemonic, fingerprint = mnemonic_fingerprint(keychain)
ssl_ca_cert_and_key_wrapper: SSLTestCollateralWrapper[
@ -158,13 +156,8 @@ async def get_full_chia_simulator(
ca_crt_path = chia_root / config["private_ssl_ca"]["crt"]
ca_key_path = chia_root / config["private_ssl_ca"]["key"]
shutdown_event = asyncio.Event()
ws_server = WebSocketServer(chia_root, ca_crt_path, ca_key_path, crt_path, key_path, shutdown_event)
ws_server = WebSocketServer(chia_root, ca_crt_path, ca_key_path, crt_path, key_path)
await ws_server.setup_process_global_state()
await ws_server.start()
async for simulator in start_simulator(chia_root, automated_testing):
yield simulator, chia_root, config, mnemonic, fingerprint, keychain
await ws_server.stop()
await shutdown_event.wait() # wait till shutdown is complete
async with ws_server.run():
async for simulator in start_simulator(chia_root, automated_testing):
yield simulator, chia_root, config, mnemonic, fingerprint, keychain

View File

@ -476,7 +476,7 @@ class Timelord:
rc_challenge = self.last_state.get_challenge(Chain.REWARD_CHAIN)
if rc_info.challenge != rc_challenge:
assert rc_challenge is not None
log.warning(f"SP: Do not have correct challenge {rc_challenge.hex()}" f" has {rc_info.challenge}")
log.warning(f"SP: Do not have correct challenge {rc_challenge.hex()} has {rc_info.challenge}")
# This proof is on an outdated challenge, so don't use it
continue
iters_from_sub_slot_start = cc_info.number_of_iterations + self.last_state.get_last_ip()
@ -745,7 +745,7 @@ class Timelord:
rc_challenge = self.last_state.get_challenge(Chain.REWARD_CHAIN)
if rc_vdf.challenge != rc_challenge:
assert rc_challenge is not None
log.warning(f"Do not have correct challenge {rc_challenge.hex()} has" f" {rc_vdf.challenge}")
log.warning(f"Do not have correct challenge {rc_challenge.hex()} has {rc_vdf.challenge}")
# This proof is on an outdated challenge, so don't use it
return None
log.debug("Collected end of subslot vdfs.")

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import io
from typing import Any, Callable, Dict, List, Set, Tuple
from typing import Any, Callable, Dict, Set, Tuple
from chia_rs import run_chia_program, tree_hash
from clvm import SExp
@ -174,28 +174,6 @@ class Program(SExp):
def as_int(self) -> int:
return int_from_bytes(self.as_atom())
def as_atom_list(self) -> List[bytes]:
"""
Pretend `self` is a list of atoms. Return the corresponding
python list of atoms.
At each step, we always assume a node to be an atom or a pair.
If the assumption is wrong, we exit early. This way we never fail
and always return SOMETHING.
"""
items = []
obj = self
while True:
pair = obj.pair
if pair is None:
break
atom = pair[0].atom
if atom is None:
break
items.append(atom)
obj = pair[1]
return items
def __deepcopy__(self, memo):
return type(self).from_bytes(bytes(self))

View File

@ -92,7 +92,6 @@ class SerializedProgram:
def run_as_generator(
self, max_cost: int, flags: int, *args: Union[Program, SerializedProgram]
) -> Tuple[Optional[int], Optional[SpendBundleConditions]]:
serialized_args = bytearray()
if len(args) > 1:
# when we have more than one argument, serialize them into a list

View File

@ -35,7 +35,6 @@ def verify_vdf(
discriminant_size: int,
witness_type: uint8,
):
return verify_n_wesolowski(
str(disc),
input_el,

View File

@ -1,7 +1,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List
from typing import List, Tuple
from clvm.casts import int_from_bytes
@ -29,7 +29,11 @@ class CoinSpend(Streamable):
solution: SerializedProgram
def compute_additions(cs: CoinSpend, *, max_cost: int = DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM) -> List[Coin]:
def compute_additions_with_cost(
cs: CoinSpend,
*,
max_cost: int = DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM,
) -> Tuple[List[Coin], int]:
"""
Run the puzzle in the specified CoinSpend and return the cost and list of
coins created by the puzzle, i.e. additions. If the cost (CLVM- and
@ -55,4 +59,9 @@ def compute_additions(cs: CoinSpend, *, max_cost: int = DEFAULT_CONSTANTS.MAX_BL
puzzle_hash = next(atoms).atom
amount = int_from_bytes(next(atoms).atom)
ret.append(Coin(parent_id, puzzle_hash, amount))
return ret
return ret, cost
def compute_additions(cs: CoinSpend, *, max_cost: int = DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM) -> List[Coin]:
return compute_additions_with_cost(cs, max_cost=max_cost)[0]

View File

@ -1,14 +1,15 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from typing import Any, Dict, List, Optional
from chia.consensus.cost_calculator import NPCResult
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.spend_bundle import SpendBundle
from chia.util.generator_tools import additions_for_npc
from chia.util.ints import uint32, uint64
from chia.util.streamable import Streamable, streamable
from chia.util.streamable import Streamable, recurse_jsonify, streamable
@streamable
@ -17,9 +18,7 @@ class MempoolItem(Streamable):
spend_bundle: SpendBundle
fee: uint64
npc_result: NPCResult
cost: uint64
spend_bundle_name: bytes32
additions: List[Coin]
height_added_to_mempool: uint32
# If present, this SpendBundle is not valid at or before this height
@ -36,6 +35,26 @@ class MempoolItem(Streamable):
def name(self) -> bytes32:
return self.spend_bundle_name
@property
def cost(self) -> uint64:
assert self.npc_result.conds is not None
return uint64(self.npc_result.conds.cost)
@property
def additions(self) -> List[Coin]:
return additions_for_npc(self.npc_result)
@property
def removals(self) -> List[Coin]:
return self.spend_bundle.removals()
def to_json_dict(self) -> Dict[str, Any]:
return {
"spend_bundle": recurse_jsonify(self.spend_bundle),
"fee": recurse_jsonify(self.fee),
"npc_result": recurse_jsonify(self.npc_result),
"cost": recurse_jsonify(self.cost),
"spend_bundle_name": recurse_jsonify(self.spend_bundle_name),
"additions": recurse_jsonify(self.additions),
"removals": recurse_jsonify(self.removals),
}

View File

@ -2,64 +2,52 @@ from __future__ import annotations
import ipaddress
from dataclasses import dataclass
from typing import Optional, Union
from typing import Union
from chia.util.ints import uint16, uint64
from chia.util.network import IPAddress
from chia.util.streamable import Streamable, streamable
@dataclass(frozen=True)
# TODO, Replace unsafe_hash with frozen and drop the __init__ as soon as all PeerInfo call sites pass in an IPAddress.
@dataclass(unsafe_hash=True)
class PeerInfo:
host: str
port: uint16
_ip: IPAddress
_port: uint16
def is_valid(self, allow_private_subnets: bool = False) -> bool:
ip: Optional[Union[ipaddress.IPv6Address, ipaddress.IPv4Address]] = None
try:
ip = ipaddress.IPv6Address(self.host)
except ValueError:
ip = None
if ip is not None:
if ip.is_private and not allow_private_subnets:
return False
return True
# TODO, Drop this as soon as all call PeerInfo calls pass in an IPAddress
def __init__(self, host: Union[IPAddress, str], port: int):
self._ip = host if isinstance(host, IPAddress) else IPAddress.create(host)
self._port = uint16(port)
try:
ip = ipaddress.IPv4Address(self.host)
except ValueError:
ip = None
if ip is not None:
if ip.is_private and not allow_private_subnets:
return False
return True
return False
# Kept here for compatibility until all places where its used transitioned to IPAddress instead of str.
@property
def host(self) -> str:
return str(self._ip)
@property
def ip(self) -> IPAddress:
return self._ip
@property
def port(self) -> uint16:
return self._port
# Functions related to peer bucketing in new/tried tables.
def get_key(self) -> bytes:
try:
ip = ipaddress.IPv6Address(self.host)
except ValueError:
ip_v4 = ipaddress.IPv4Address(self.host)
ip = ipaddress.IPv6Address(int(ipaddress.IPv6Address("2002::")) | (int(ip_v4) << 80))
key = ip.packed
if self.ip.is_v4:
key = ipaddress.IPv6Address(int(ipaddress.IPv6Address("2002::")) | (int(self.ip) << 80)).packed
else:
key = self.ip.packed
key += bytes([self.port // 0x100, self.port & 0x0FF])
return key
def get_group(self) -> bytes:
# TODO: Port everything from Bitcoin.
ip_v4: Optional[ipaddress.IPv4Address] = None
ip_v6: Optional[ipaddress.IPv6Address] = None
try:
ip_v4 = ipaddress.IPv4Address(self.host)
except ValueError:
ip_v6 = ipaddress.IPv6Address(self.host)
if ip_v4 is not None:
group = bytes([1]) + ip_v4.packed[:2]
elif ip_v6 is not None:
group = bytes([0]) + ip_v6.packed[:4]
if self.ip.is_v4:
return bytes([1]) + self.ip.packed[:2]
else:
raise ValueError("PeerInfo.host is not an ip address")
return group
return bytes([0]) + self.ip.packed[:4]
@streamable

View File

@ -9,10 +9,11 @@ from blspy import AugSchemeMPL, G2Element
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.errors import Err, ValidationError
from chia.util.streamable import Streamable, recurse_jsonify, streamable, streamable_from_dict
from chia.wallet.util.debug_spend_bundle import debug_spend_bundle
from .coin_spend import CoinSpend, compute_additions
from .coin_spend import CoinSpend, compute_additions_with_cost
@streamable
@ -43,10 +44,14 @@ class SpendBundle(Streamable):
return cls(coin_spends, aggregated_signature)
# TODO: this should be removed
def additions(self) -> List[Coin]:
def additions(self, *, max_cost: int = DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM) -> List[Coin]:
items: List[Coin] = []
for cs in self.coin_spends:
items.extend(compute_additions(cs))
coins, cost = compute_additions_with_cost(cs, max_cost=max_cost)
max_cost -= cost
if max_cost < 0:
raise ValidationError(Err.BLOCK_COST_EXCEEDS_MAX, "additions() for SpendBundle")
items.extend(coins)
return items
def removals(self) -> List[Coin]:

View File

@ -20,13 +20,15 @@
from __future__ import annotations
# Based on this specification from Pieter Wuille:
# https://github.com/sipa/bips/blob/bip-bech32m/bip-bech32m.mediawiki
"""Reference implementation for Bech32m and segwit addresses."""
from typing import Iterable, List, Optional, Tuple
from chia.types.blockchain_format.sized_bytes import bytes32
# Based on this specification from Pieter Wuille:
# https://github.com/sipa/bips/blob/bip-bech32m/bip-bech32m.mediawiki
"""Reference implementation for Bech32m and segwit addresses."""
CHARSET = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"

Some files were not shown because too many files have changed in this diff Show More