Fix race condition in select coins, and order coins by age

This commit is contained in:
Mariano Sorgente 2020-04-08 15:29:34 +09:00
parent 451b3711cc
commit 02aa39a667
No known key found for this signature in database
GPG Key ID: 0F866338C369278C
14 changed files with 125 additions and 110 deletions

View File

@ -84,20 +84,19 @@ def main():
)
filename: str = f"plot-{i}-{args.size}-{plot_seed}.dat"
full_path: Path = args.final_dir / filename
if full_path.exists():
if not full_path.exists():
# Creates the plot. This will take a long time for larger plots.
plotter: DiskPlotter = DiskPlotter()
plotter.create_plot_disk(
str(tmp_dir),
str(args.final_dir),
filename,
args.size,
bytes([]),
plot_seed,
)
else:
print(f"Plot {filename} already exists")
continue
# Creates the plot. This will take a long time for larger plots.
plotter: DiskPlotter = DiskPlotter()
plotter.create_plot_disk(
str(tmp_dir),
str(args.final_dir),
filename,
args.size,
bytes([]),
plot_seed,
)
# Updates the config if necessary.
plot_config = load_config(plot_config_filename)

View File

@ -1124,7 +1124,9 @@ class Blockchain:
if rem in additions_dic:
# Ephemeral coin
rem_coin: Coin = additions_dic[rem]
new_unspent: CoinRecord = CoinRecord(rem_coin, block.height, uint32(0), False, False)
new_unspent: CoinRecord = CoinRecord(
rem_coin, block.height, uint32(0), False, False
)
removal_coin_records[new_unspent.name] = new_unspent
else:
assert prev_header is not None

View File

@ -1083,9 +1083,10 @@ class FullNode:
assert prev_full_block is not None
async with self.store.lock:
error_code, iterations_needed = await self.blockchain.validate_unfinished_block(
block, prev_full_block
)
(
error_code,
iterations_needed,
) = await self.blockchain.validate_unfinished_block(block, prev_full_block)
if error_code is not None:
raise ConsensusError(error_code)

View File

@ -66,10 +66,7 @@ class MempoolManager:
spend_bundles: List[SpendBundle] = []
for dic in mempool.sorted_spends.values():
for item in dic.values():
if (
item.cost + cost_sum
<= self.constants["MAX_BLOCK_COST_CLVM"]
):
if item.cost + cost_sum <= self.constants["MAX_BLOCK_COST_CLVM"]:
spend_bundles.append(item.spend_bundle)
cost_sum += item.cost
else:
@ -181,13 +178,21 @@ class MempoolManager:
unknown_unspent_error: bool = False
removal_amount = uint64(0)
for name in removal_names:
removal_record = await self.unspent_store.get_coin_record(name, pool.header)
removal_record = await self.unspent_store.get_coin_record(
name, pool.header
)
if removal_record is None and name not in additions_dict:
unknown_unspent_error = True
break
elif name in additions_dict:
removal_coin = additions_dict[name]
removal_record = CoinRecord(removal_coin, uint32(pool.header.height + 1), uint32(0), False, False)
removal_record = CoinRecord(
removal_coin,
uint32(pool.header.height + 1),
uint32(0),
False,
False,
)
assert removal_record is not None
removal_amount = uint64(removal_amount + removal_record.coin.amount)
@ -385,9 +390,9 @@ class MempoolManager:
# If old spends height is bigger than the new tip height, try adding spends to the pool
for height in self.old_mempools.keys():
old_spend_dict: Dict[
bytes32, MempoolItem
] = self.old_mempools[height]
old_spend_dict: Dict[bytes32, MempoolItem] = self.old_mempools[
height
]
await self.add_old_spends_to_pool(new_pool, old_spend_dict)
await self.initialize_pool_from_current_pools(new_pool)
@ -396,7 +401,9 @@ class MempoolManager:
for pool in self.mempools.values():
if pool.header.header_hash not in new_pools:
await self.add_to_old_mempool_cache(list(pool.spends.values()), pool.header)
await self.add_to_old_mempool_cache(
list(pool.spends.values()), pool.header
)
self.mempools = new_pools

View File

@ -161,7 +161,7 @@ class FullNodeSimulator(FullNode):
10,
reward_puzzlehash=request.puzzle_hash,
transaction_data_at_height=dict_h,
seed=token_bytes()
seed=token_bytes(),
)
new_lca = more_blocks[-1]

View File

@ -120,62 +120,66 @@ class Wallet:
async def select_coins(self, amount) -> Optional[Set[Coin]]:
""" Returns a set of coins that can be used for generating a new transaction. """
spendable_am = await self.wallet_state_manager.get_unconfirmed_spendable_for_wallet(
self.wallet_info.id
)
if amount > spendable_am:
self.log.warning(
f"Can't select amount higher than our spendable balance {amount}, spendable {spendable_am}"
)
return None
self.log.info(f"About to select coins for amount {amount}")
unspent: Set[
WalletCoinRecord
] = await self.wallet_state_manager.get_spendable_coins_for_wallet(
self.wallet_info.id
)
sum = 0
used_coins: Set = set()
# Try to use coins from the store, if there isn't enough of "unused"
# coins use change coins that are not confirmed yet
unconfirmed_removals: Dict[
bytes32, Coin
] = await self.wallet_state_manager.unconfirmed_removals_for_wallet(
self.wallet_info.id
)
for coinrecord in unspent:
if sum >= amount:
break
if coinrecord.coin.name() in unconfirmed_removals:
continue
sum += coinrecord.coin.amount
used_coins.add(coinrecord.coin)
self.log.info(
f"Selected coin: {coinrecord.coin.name()} at height {coinrecord.confirmed_block_index}!"
)
# This happens when we couldn't use one of the coins because it's already used
# but unconfirmed, and we are waiting for the change. (unconfirmed_additions)
unconfirmed_additions = None
if sum < amount:
raise ValueError(
"Can't make this transaction at the moment. Waiting for the change from the previous transaction."
)
unconfirmed_additions = await self.wallet_state_manager.unconfirmed_additions_for_wallet(
async with self.wallet_state_manager.lock:
spendable_am = await self.wallet_state_manager.get_unconfirmed_spendable_for_wallet(
self.wallet_info.id
)
for coin in unconfirmed_additions.values():
if sum > amount:
break
if coin.name() in unconfirmed_removals:
continue
sum += coin.amount
used_coins.add(coin)
self.log.info(f"Selected used coin: {coin.name()}")
if amount > spendable_am:
self.log.warning(
f"Can't select amount higher than our spendable balance {amount}, spendable {spendable_am}"
)
return None
self.log.info(f"About to select coins for amount {amount}")
unspent: List[WalletCoinRecord] = list(
await self.wallet_state_manager.get_spendable_coins_for_wallet(
self.wallet_info.id
)
)
sum = 0
used_coins: Set = set()
# Use older coins first
unspent.sort(key=lambda r: r.confirmed_block_index)
# Try to use coins from the store, if there isn't enough of "unused"
# coins use change coins that are not confirmed yet
unconfirmed_removals: Dict[
bytes32, Coin
] = await self.wallet_state_manager.unconfirmed_removals_for_wallet(
self.wallet_info.id
)
for coinrecord in unspent:
if sum >= amount:
break
if coinrecord.coin.name() in unconfirmed_removals:
continue
sum += coinrecord.coin.amount
used_coins.add(coinrecord.coin)
self.log.info(
f"Selected coin: {coinrecord.coin.name()} at height {coinrecord.confirmed_block_index}!"
)
# This happens when we couldn't use one of the coins because it's already used
# but unconfirmed, and we are waiting for the change. (unconfirmed_additions)
unconfirmed_additions = None
if sum < amount:
raise ValueError(
"Can't make this transaction at the moment. Waiting for the change from the previous transaction."
)
unconfirmed_additions = await self.wallet_state_manager.unconfirmed_additions_for_wallet(
self.wallet_info.id
)
for coin in unconfirmed_additions.values():
if sum > amount:
break
if coin.name() in unconfirmed_removals:
continue
sum += coin.amount
used_coins.add(coin)
self.log.info(f"Selected used coin: {coin.name()}")
if sum >= amount:
self.log.info(f"Successfully selected coins: {used_coins}")

View File

@ -5,7 +5,6 @@ import concurrent
import random
import logging
import traceback
from blspy import ExtendedPrivateKey
from src.full_node.full_node import OutboundMessageGenerator
from src.types.peer_info import PeerInfo
@ -34,7 +33,6 @@ from src.util.path import path_from_root, mkdir
class WalletNode:
private_key: ExtendedPrivateKey
key_config: Dict
config: Dict
constants: Dict
@ -75,8 +73,6 @@ class WalletNode:
self = WalletNode()
self.config = config
self.key_config = key_config
sk_hex = self.key_config["wallet_sk"]
self.private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex))
self.constants = consensus_constants.copy()
for key, value in override_constants.items():
self.constants[key] = value

View File

@ -132,7 +132,7 @@ class WalletStateManager:
self.wallets[wallet_info.id] = wallet
async with self.puzzle_store.lock:
await self.create_more_puzzle_hashes()
await self.create_more_puzzle_hashes(from_zero=True)
if len(self.block_records) > 0:
# Initializes the state based on the DB block records
@ -178,7 +178,7 @@ class WalletStateManager:
return self
async def create_more_puzzle_hashes(self):
async def create_more_puzzle_hashes(self, from_zero: bool = False):
"""
For all wallets in the user store, generates the first few puzzle hashes so
that we can restore the wallet from only the private keys.
@ -201,6 +201,8 @@ class WalletStateManager:
start_index = last + 1
to_generate -= last - unused
if from_zero:
start_index = 0
for index in range(start_index, start_index + to_generate):
pubkey: PublicKey = target_wallet.get_public_key(index)
puzzle: Program = target_wallet.puzzle_for_pk(bytes(pubkey))
@ -219,6 +221,8 @@ class WalletStateManager:
)
await self.puzzle_store.add_derivation_paths(derivation_paths)
if from_zero and unused is not None and unused > 0:
await self.puzzle_store.set_used_up_to(uint32(unused - 1))
async def get_unused_derivation_record(self, wallet_id: uint32) -> DerivationRecord:
"""

View File

@ -1143,9 +1143,13 @@ class TestBlockchainTransactions:
next_block.transactions_generator,
next_block.transactions_filter,
)
result, removed, error_code = await full_node_1.blockchain.receive_block(bad_block)
result, removed, error_code = await full_node_1.blockchain.receive_block(
bad_block
)
assert result == ReceiveBlockResult.INVALID_BLOCK
assert error_code == Err.INVALID_TRANSACTIONS_FILTER_HASH
result, removed, error_code = await full_node_1.blockchain.receive_block(next_block)
result, removed, error_code = await full_node_1.blockchain.receive_block(
next_block
)
assert result == ReceiveBlockResult.ADDED_TO_HEAD

View File

@ -178,7 +178,9 @@ class TestFullNodeProtocol:
spend_bundles.append(spend_bundle)
# Mempool is full
new_transaction = fnp.NewTransaction(token_bytes(32), uint64(1000000), uint64(1))
new_transaction = fnp.NewTransaction(
token_bytes(32), uint64(1000000), uint64(1)
)
msgs = [x async for x in full_node_1.new_transaction(new_transaction)]
assert len(msgs) == 0

View File

@ -7,9 +7,7 @@ from src.protocols import full_node_protocol
from src.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32
from tests.setup_nodes import (
setup_simulators_and_wallets
)
from tests.setup_nodes import setup_simulators_and_wallets
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
@ -27,15 +25,15 @@ class TestTransactions:
@pytest.fixture(scope="function")
async def two_wallet_nodes(self):
async for _ in setup_simulators_and_wallets(1, 2,
{"COINBASE_FREEZE_PERIOD": 0}
async for _ in setup_simulators_and_wallets(
1, 2, {"COINBASE_FREEZE_PERIOD": 0}
):
yield _
@pytest.fixture(scope="function")
async def three_nodes_two_wallets(self):
async for _ in setup_simulators_and_wallets(3, 2,
{"COINBASE_FREEZE_PERIOD": 0}
async for _ in setup_simulators_and_wallets(
3, 2, {"COINBASE_FREEZE_PERIOD": 0}
):
yield _

View File

@ -335,7 +335,9 @@ async def setup_node_and_two_wallets(dic={}):
pass
async def setup_simulators_and_wallets(simulator_count: int, wallet_count: int, dic: Dict):
async def setup_simulators_and_wallets(
simulator_count: int, wallet_count: int, dic: Dict
):
simulators: List[Tuple[FullNode, ChiaServer]] = []
wallets = []
node_iters = []

View File

@ -5,10 +5,7 @@ import pytest
from blspy import ExtendedPrivateKey
from chiabip158 import PyBIP158
from tests.setup_nodes import (
test_constants,
bt,
setup_simulators_and_wallets)
from tests.setup_nodes import test_constants, bt, setup_simulators_and_wallets
from src.util.config import load_config

View File

@ -7,8 +7,7 @@ from src.protocols import full_node_protocol
from src.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32
from tests.setup_nodes import (
setup_simulators_and_wallets)
from tests.setup_nodes import setup_simulators_and_wallets
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
@ -26,22 +25,22 @@ class TestWalletSimulator:
@pytest.fixture(scope="function")
async def two_wallet_nodes(self):
async for _ in setup_simulators_and_wallets(1, 2,
{"COINBASE_FREEZE_PERIOD": 0}
async for _ in setup_simulators_and_wallets(
1, 2, {"COINBASE_FREEZE_PERIOD": 0}
):
yield _
@pytest.fixture(scope="function")
async def two_wallet_nodes_five_freeze(self):
async for _ in setup_simulators_and_wallets(1, 2,
{"COINBASE_FREEZE_PERIOD": 5}
async for _ in setup_simulators_and_wallets(
1, 2, {"COINBASE_FREEZE_PERIOD": 5}
):
yield _
@pytest.fixture(scope="function")
async def three_sim_two_wallets(self):
async for _ in setup_simulators_and_wallets(3, 2,
{"COINBASE_FREEZE_PERIOD": 0}
async for _ in setup_simulators_and_wallets(
3, 2, {"COINBASE_FREEZE_PERIOD": 0}
):
yield _