mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-21 00:24:37 +03:00
correct balance lookup, streamable fix
This commit is contained in:
parent
c8d32a5958
commit
d9b2685411
@ -181,7 +181,7 @@ class Streamable:
|
||||
f.write(uint32(len(item)).to_bytes(4, "big"))
|
||||
f.write(item.encode("utf-8"))
|
||||
elif f_type is bool:
|
||||
f.write(bytes(item))
|
||||
f.write(int(item).to_bytes(4, "big"))
|
||||
else:
|
||||
raise NotImplementedError(f"can't stream {item}, {f_type}")
|
||||
|
||||
|
@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional, List
|
||||
|
||||
from src.types.hashable.Coin import Coin
|
||||
from src.types.hashable.SpendBundle import SpendBundle
|
||||
@ -21,10 +21,9 @@ class TransactionRecord(Streamable):
|
||||
sent: bool
|
||||
created_at_time: uint64
|
||||
spend_bundle: Optional[SpendBundle]
|
||||
additions: Dict[bytes32, Coin]
|
||||
removals: Dict[bytes32, Coin]
|
||||
additions: List[Coin]
|
||||
removals: List[Coin]
|
||||
|
||||
@property
|
||||
def name(self) -> bytes32:
|
||||
if self.spend_bundle:
|
||||
return self.spend_bundle.name()
|
||||
|
@ -67,12 +67,15 @@ class Wallet:
|
||||
self.wallet_state_manager = wallet_state_manager
|
||||
self.pubkey_num_lookup = {}
|
||||
|
||||
self.server = None
|
||||
|
||||
return self
|
||||
|
||||
def get_next_public_key(self) -> PublicKey:
|
||||
pubkey = self.private_key.public_child(self.next_address).get_public_key()
|
||||
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
|
||||
self.next_address = self.next_address + 1
|
||||
self.wallet_state_manager.next_address = self.next_address
|
||||
return pubkey
|
||||
|
||||
async def get_confirmed_balance(self) -> uint64:
|
||||
@ -103,6 +106,7 @@ class Wallet:
|
||||
def get_new_puzzlehash(self) -> bytes32:
|
||||
puzzle: Program = self.get_new_puzzle()
|
||||
puzzlehash: bytes32 = puzzle.get_hash()
|
||||
self.wallet_state_manager.puzzlehash_set.add(puzzlehash)
|
||||
return puzzlehash
|
||||
|
||||
def set_server(self, server: ChiaServer):
|
||||
|
@ -54,7 +54,7 @@ class WalletNode:
|
||||
self.tx_store = await WalletTransactionStore.create(path)
|
||||
|
||||
self.wallet_state_manager = await WalletStateManager.create(
|
||||
config, self.wallet_store, self.tx_store
|
||||
config, key_config, self.wallet_store, self.tx_store
|
||||
)
|
||||
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)
|
||||
|
||||
|
@ -18,15 +18,18 @@ class WalletStateManager:
|
||||
tx_store: WalletTransactionStore
|
||||
header_hash: List[bytes32]
|
||||
start_index: int
|
||||
next_address: int
|
||||
|
||||
log: logging.Logger
|
||||
|
||||
# TODO Don't allow user to send tx until wallet is synced
|
||||
synced: bool
|
||||
puzzlehash_set: set
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
wallet_store: WalletStore,
|
||||
tx_store: WalletTransactionStore,
|
||||
name: str = None,
|
||||
@ -44,7 +47,9 @@ class WalletStateManager:
|
||||
self.wallet_store = wallet_store
|
||||
self.tx_store = tx_store
|
||||
self.synced = False
|
||||
self.next_address = 0
|
||||
|
||||
self.puzzlehash_set = set()
|
||||
return self
|
||||
|
||||
async def get_confirmed_balance(self) -> uint64:
|
||||
@ -65,9 +70,10 @@ class WalletStateManager:
|
||||
removal_amount = 0
|
||||
|
||||
for record in unconfirmed_tx:
|
||||
for name, coin in record.additions.items():
|
||||
addition_amount += coin.amount
|
||||
for name, coin in record.removals.items():
|
||||
for coin in record.additions:
|
||||
if coin.puzzle_hash in self.puzzlehash_set:
|
||||
addition_amount += coin.amount
|
||||
for coin in record.removals:
|
||||
removal_amount += coin.amount
|
||||
|
||||
result = confirmed - removal_amount + addition_amount
|
||||
@ -77,16 +83,16 @@ class WalletStateManager:
|
||||
additions: Dict[bytes32, Coin] = {}
|
||||
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
||||
for record in unconfirmed_tx:
|
||||
for name, coin in record.additions.items():
|
||||
additions[name] = coin
|
||||
for coin in record.additions:
|
||||
additions[coin.name()] = coin
|
||||
return additions
|
||||
|
||||
async def unconfirmed_removals(self) -> Dict[bytes32, Coin]:
|
||||
removals: Dict[bytes32, Coin] = {}
|
||||
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
||||
for record in unconfirmed_tx:
|
||||
for name, coin in record.removals.items():
|
||||
removals[name] = coin
|
||||
for coin in record.removals:
|
||||
removals[coin.name()] = coin
|
||||
return removals
|
||||
|
||||
async def select_coins(self, amount) -> Optional[Set[Coin]]:
|
||||
@ -149,16 +155,16 @@ class WalletStateManager:
|
||||
Called from wallet_node before new transaction is sent to the full_node
|
||||
"""
|
||||
now = uint64(int(time.time()))
|
||||
add_dict: Dict[bytes32, Coin] = {}
|
||||
rem_dict: Dict[bytes32, Coin] = {}
|
||||
add_list: List[Coin] = []
|
||||
rem_list: List[Coin] = []
|
||||
for add in spend_bundle.additions():
|
||||
add_dict[add.name()] = add
|
||||
add_list.append(add)
|
||||
for rem in spend_bundle.removals():
|
||||
rem_dict[rem.name()] = rem
|
||||
rem_list.append(rem)
|
||||
|
||||
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
|
||||
tx_record = TransactionRecord(
|
||||
uint32(0), uint32(0), False, False, now, spend_bundle, add_dict, rem_dict
|
||||
uint32(0), uint32(0), False, False, now, spend_bundle, add_list, rem_list
|
||||
)
|
||||
await self.tx_store.add_transaction_record(tx_record)
|
||||
|
||||
|
@ -5,6 +5,7 @@ from blspy import ExtendedPrivateKey
|
||||
|
||||
from src.protocols.wallet_protocol import RespondBody
|
||||
from src.wallet.wallet import Wallet
|
||||
from src.wallet.wallet_node import WalletNode
|
||||
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||
|
||||
|
||||
@ -25,8 +26,10 @@ class TestWallet:
|
||||
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
||||
key_config = {"wallet_sk": sk}
|
||||
|
||||
wallet = await Wallet.create({}, key_config)
|
||||
await wallet.wallet_store._clear_database()
|
||||
wallet_node = await WalletNode.create({}, key_config)
|
||||
wallet = wallet_node.wallet
|
||||
await wallet_node.wallet_store._clear_database()
|
||||
await wallet_node.tx_store._clear_database()
|
||||
|
||||
num_blocks = 10
|
||||
blocks = bt.get_consecutive_blocks(
|
||||
@ -39,11 +42,12 @@ class TestWallet:
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
a = RespondBody(blocks[i].body, blocks[i].height)
|
||||
await wallet.received_body(a)
|
||||
await wallet_node.received_body(a)
|
||||
|
||||
assert await wallet.get_confirmed_balance() == 144000000000000
|
||||
|
||||
await wallet.wallet_store.close()
|
||||
await wallet_node.wallet_store.close()
|
||||
await wallet_node.tx_store.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_make_transaction(self, two_nodes):
|
||||
@ -52,10 +56,15 @@ class TestWallet:
|
||||
key_config = {"wallet_sk": sk}
|
||||
key_config_b = {"wallet_sk": sk_b}
|
||||
|
||||
wallet = await Wallet.create({}, key_config)
|
||||
await wallet.wallet_store._clear_database()
|
||||
wallet_b = await Wallet.create({}, key_config_b)
|
||||
await wallet_b.wallet_store._clear_database()
|
||||
wallet_node = await WalletNode.create({}, key_config)
|
||||
wallet = wallet_node.wallet
|
||||
await wallet_node.wallet_store._clear_database()
|
||||
await wallet_node.tx_store._clear_database()
|
||||
|
||||
wallet_node_b = await WalletNode.create({}, key_config_b)
|
||||
wallet_b = wallet_node_b.wallet
|
||||
await wallet_node_b.wallet_store._clear_database()
|
||||
await wallet_node_b.tx_store._clear_database()
|
||||
|
||||
num_blocks = 10
|
||||
blocks = bt.get_consecutive_blocks(
|
||||
@ -68,7 +77,7 @@ class TestWallet:
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
a = RespondBody(blocks[i].body, blocks[i].height)
|
||||
await wallet.received_body(a)
|
||||
await wallet_node.received_body(a)
|
||||
|
||||
assert await wallet.get_confirmed_balance() == 144000000000000
|
||||
|
||||
@ -83,5 +92,8 @@ class TestWallet:
|
||||
assert confirmed_balance == 144000000000000
|
||||
assert unconfirmed_balance == confirmed_balance - 10
|
||||
|
||||
await wallet.wallet_store.close()
|
||||
await wallet_b.wallet_store.close()
|
||||
await wallet_node.wallet_store.close()
|
||||
await wallet_node.tx_store.close()
|
||||
|
||||
await wallet_node_b.wallet_store.close()
|
||||
await wallet_node_b.tx_store.close()
|
||||
|
@ -1,59 +0,0 @@
|
||||
import asyncio
|
||||
import signal
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from blspy import ExtendedPrivateKey
|
||||
|
||||
from src.protocols import full_node_protocol
|
||||
from src.server.outbound_message import NodeType
|
||||
from src.server.server import ChiaServer
|
||||
from src.types.peer_info import PeerInfo
|
||||
from src.wallet.wallet import Wallet
|
||||
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
|
||||
|
||||
class TestWalletProtocol:
|
||||
@pytest.fixture(scope="function")
|
||||
async def two_nodes(self):
|
||||
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
|
||||
yield _
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_connect(self, two_nodes):
|
||||
num_blocks = 10
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(blocks[i])
|
||||
):
|
||||
pass
|
||||
|
||||
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
||||
key_config = {"wallet_sk": sk}
|
||||
|
||||
wallet = await Wallet.create({}, key_config)
|
||||
server = ChiaServer(8223, wallet, NodeType.WALLET)
|
||||
|
||||
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
|
||||
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
|
||||
|
||||
_ = await server.start_server("127.0.0.1", wallet._on_connect)
|
||||
await asyncio.sleep(2)
|
||||
full_node_peer = PeerInfo(server_1._host, server_1._port)
|
||||
_ = await server.start_client(full_node_peer, None)
|
||||
|
||||
start_unf = time.time()
|
||||
while time.time() - start_unf < 3:
|
||||
# TODO check if we've synced proof hashes and verified number of proofs
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
await wallet.wallet_store.close()
|
Loading…
Reference in New Issue
Block a user