mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-21 00:24:37 +03:00
correct balance lookup, streamable fix
This commit is contained in:
parent
c8d32a5958
commit
d9b2685411
@ -181,7 +181,7 @@ class Streamable:
|
|||||||
f.write(uint32(len(item)).to_bytes(4, "big"))
|
f.write(uint32(len(item)).to_bytes(4, "big"))
|
||||||
f.write(item.encode("utf-8"))
|
f.write(item.encode("utf-8"))
|
||||||
elif f_type is bool:
|
elif f_type is bool:
|
||||||
f.write(bytes(item))
|
f.write(int(item).to_bytes(4, "big"))
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"can't stream {item}, {f_type}")
|
raise NotImplementedError(f"can't stream {item}, {f_type}")
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, Optional
|
from typing import Optional, List
|
||||||
|
|
||||||
from src.types.hashable.Coin import Coin
|
from src.types.hashable.Coin import Coin
|
||||||
from src.types.hashable.SpendBundle import SpendBundle
|
from src.types.hashable.SpendBundle import SpendBundle
|
||||||
@ -21,10 +21,9 @@ class TransactionRecord(Streamable):
|
|||||||
sent: bool
|
sent: bool
|
||||||
created_at_time: uint64
|
created_at_time: uint64
|
||||||
spend_bundle: Optional[SpendBundle]
|
spend_bundle: Optional[SpendBundle]
|
||||||
additions: Dict[bytes32, Coin]
|
additions: List[Coin]
|
||||||
removals: Dict[bytes32, Coin]
|
removals: List[Coin]
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> bytes32:
|
def name(self) -> bytes32:
|
||||||
if self.spend_bundle:
|
if self.spend_bundle:
|
||||||
return self.spend_bundle.name()
|
return self.spend_bundle.name()
|
||||||
|
@ -67,12 +67,15 @@ class Wallet:
|
|||||||
self.wallet_state_manager = wallet_state_manager
|
self.wallet_state_manager = wallet_state_manager
|
||||||
self.pubkey_num_lookup = {}
|
self.pubkey_num_lookup = {}
|
||||||
|
|
||||||
|
self.server = None
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get_next_public_key(self) -> PublicKey:
|
def get_next_public_key(self) -> PublicKey:
|
||||||
pubkey = self.private_key.public_child(self.next_address).get_public_key()
|
pubkey = self.private_key.public_child(self.next_address).get_public_key()
|
||||||
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
|
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
|
||||||
self.next_address = self.next_address + 1
|
self.next_address = self.next_address + 1
|
||||||
|
self.wallet_state_manager.next_address = self.next_address
|
||||||
return pubkey
|
return pubkey
|
||||||
|
|
||||||
async def get_confirmed_balance(self) -> uint64:
|
async def get_confirmed_balance(self) -> uint64:
|
||||||
@ -103,6 +106,7 @@ class Wallet:
|
|||||||
def get_new_puzzlehash(self) -> bytes32:
|
def get_new_puzzlehash(self) -> bytes32:
|
||||||
puzzle: Program = self.get_new_puzzle()
|
puzzle: Program = self.get_new_puzzle()
|
||||||
puzzlehash: bytes32 = puzzle.get_hash()
|
puzzlehash: bytes32 = puzzle.get_hash()
|
||||||
|
self.wallet_state_manager.puzzlehash_set.add(puzzlehash)
|
||||||
return puzzlehash
|
return puzzlehash
|
||||||
|
|
||||||
def set_server(self, server: ChiaServer):
|
def set_server(self, server: ChiaServer):
|
||||||
|
@ -54,7 +54,7 @@ class WalletNode:
|
|||||||
self.tx_store = await WalletTransactionStore.create(path)
|
self.tx_store = await WalletTransactionStore.create(path)
|
||||||
|
|
||||||
self.wallet_state_manager = await WalletStateManager.create(
|
self.wallet_state_manager = await WalletStateManager.create(
|
||||||
config, self.wallet_store, self.tx_store
|
config, key_config, self.wallet_store, self.tx_store
|
||||||
)
|
)
|
||||||
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)
|
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)
|
||||||
|
|
||||||
|
@ -18,15 +18,18 @@ class WalletStateManager:
|
|||||||
tx_store: WalletTransactionStore
|
tx_store: WalletTransactionStore
|
||||||
header_hash: List[bytes32]
|
header_hash: List[bytes32]
|
||||||
start_index: int
|
start_index: int
|
||||||
|
next_address: int
|
||||||
|
|
||||||
log: logging.Logger
|
log: logging.Logger
|
||||||
|
|
||||||
# TODO Don't allow user to send tx until wallet is synced
|
# TODO Don't allow user to send tx until wallet is synced
|
||||||
synced: bool
|
synced: bool
|
||||||
|
puzzlehash_set: set
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create(
|
async def create(
|
||||||
config: Dict,
|
config: Dict,
|
||||||
|
key_config: Dict,
|
||||||
wallet_store: WalletStore,
|
wallet_store: WalletStore,
|
||||||
tx_store: WalletTransactionStore,
|
tx_store: WalletTransactionStore,
|
||||||
name: str = None,
|
name: str = None,
|
||||||
@ -44,7 +47,9 @@ class WalletStateManager:
|
|||||||
self.wallet_store = wallet_store
|
self.wallet_store = wallet_store
|
||||||
self.tx_store = tx_store
|
self.tx_store = tx_store
|
||||||
self.synced = False
|
self.synced = False
|
||||||
|
self.next_address = 0
|
||||||
|
|
||||||
|
self.puzzlehash_set = set()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def get_confirmed_balance(self) -> uint64:
|
async def get_confirmed_balance(self) -> uint64:
|
||||||
@ -65,9 +70,10 @@ class WalletStateManager:
|
|||||||
removal_amount = 0
|
removal_amount = 0
|
||||||
|
|
||||||
for record in unconfirmed_tx:
|
for record in unconfirmed_tx:
|
||||||
for name, coin in record.additions.items():
|
for coin in record.additions:
|
||||||
addition_amount += coin.amount
|
if coin.puzzle_hash in self.puzzlehash_set:
|
||||||
for name, coin in record.removals.items():
|
addition_amount += coin.amount
|
||||||
|
for coin in record.removals:
|
||||||
removal_amount += coin.amount
|
removal_amount += coin.amount
|
||||||
|
|
||||||
result = confirmed - removal_amount + addition_amount
|
result = confirmed - removal_amount + addition_amount
|
||||||
@ -77,16 +83,16 @@ class WalletStateManager:
|
|||||||
additions: Dict[bytes32, Coin] = {}
|
additions: Dict[bytes32, Coin] = {}
|
||||||
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
||||||
for record in unconfirmed_tx:
|
for record in unconfirmed_tx:
|
||||||
for name, coin in record.additions.items():
|
for coin in record.additions:
|
||||||
additions[name] = coin
|
additions[coin.name()] = coin
|
||||||
return additions
|
return additions
|
||||||
|
|
||||||
async def unconfirmed_removals(self) -> Dict[bytes32, Coin]:
|
async def unconfirmed_removals(self) -> Dict[bytes32, Coin]:
|
||||||
removals: Dict[bytes32, Coin] = {}
|
removals: Dict[bytes32, Coin] = {}
|
||||||
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
unconfirmed_tx = await self.tx_store.get_not_confirmed()
|
||||||
for record in unconfirmed_tx:
|
for record in unconfirmed_tx:
|
||||||
for name, coin in record.removals.items():
|
for coin in record.removals:
|
||||||
removals[name] = coin
|
removals[coin.name()] = coin
|
||||||
return removals
|
return removals
|
||||||
|
|
||||||
async def select_coins(self, amount) -> Optional[Set[Coin]]:
|
async def select_coins(self, amount) -> Optional[Set[Coin]]:
|
||||||
@ -149,16 +155,16 @@ class WalletStateManager:
|
|||||||
Called from wallet_node before new transaction is sent to the full_node
|
Called from wallet_node before new transaction is sent to the full_node
|
||||||
"""
|
"""
|
||||||
now = uint64(int(time.time()))
|
now = uint64(int(time.time()))
|
||||||
add_dict: Dict[bytes32, Coin] = {}
|
add_list: List[Coin] = []
|
||||||
rem_dict: Dict[bytes32, Coin] = {}
|
rem_list: List[Coin] = []
|
||||||
for add in spend_bundle.additions():
|
for add in spend_bundle.additions():
|
||||||
add_dict[add.name()] = add
|
add_list.append(add)
|
||||||
for rem in spend_bundle.removals():
|
for rem in spend_bundle.removals():
|
||||||
rem_dict[rem.name()] = rem
|
rem_list.append(rem)
|
||||||
|
|
||||||
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
|
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
|
||||||
tx_record = TransactionRecord(
|
tx_record = TransactionRecord(
|
||||||
uint32(0), uint32(0), False, False, now, spend_bundle, add_dict, rem_dict
|
uint32(0), uint32(0), False, False, now, spend_bundle, add_list, rem_list
|
||||||
)
|
)
|
||||||
await self.tx_store.add_transaction_record(tx_record)
|
await self.tx_store.add_transaction_record(tx_record)
|
||||||
|
|
||||||
|
@ -5,6 +5,7 @@ from blspy import ExtendedPrivateKey
|
|||||||
|
|
||||||
from src.protocols.wallet_protocol import RespondBody
|
from src.protocols.wallet_protocol import RespondBody
|
||||||
from src.wallet.wallet import Wallet
|
from src.wallet.wallet import Wallet
|
||||||
|
from src.wallet.wallet_node import WalletNode
|
||||||
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
||||||
|
|
||||||
|
|
||||||
@ -25,8 +26,10 @@ class TestWallet:
|
|||||||
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
||||||
key_config = {"wallet_sk": sk}
|
key_config = {"wallet_sk": sk}
|
||||||
|
|
||||||
wallet = await Wallet.create({}, key_config)
|
wallet_node = await WalletNode.create({}, key_config)
|
||||||
await wallet.wallet_store._clear_database()
|
wallet = wallet_node.wallet
|
||||||
|
await wallet_node.wallet_store._clear_database()
|
||||||
|
await wallet_node.tx_store._clear_database()
|
||||||
|
|
||||||
num_blocks = 10
|
num_blocks = 10
|
||||||
blocks = bt.get_consecutive_blocks(
|
blocks = bt.get_consecutive_blocks(
|
||||||
@ -39,11 +42,12 @@ class TestWallet:
|
|||||||
|
|
||||||
for i in range(1, num_blocks):
|
for i in range(1, num_blocks):
|
||||||
a = RespondBody(blocks[i].body, blocks[i].height)
|
a = RespondBody(blocks[i].body, blocks[i].height)
|
||||||
await wallet.received_body(a)
|
await wallet_node.received_body(a)
|
||||||
|
|
||||||
assert await wallet.get_confirmed_balance() == 144000000000000
|
assert await wallet.get_confirmed_balance() == 144000000000000
|
||||||
|
|
||||||
await wallet.wallet_store.close()
|
await wallet_node.wallet_store.close()
|
||||||
|
await wallet_node.tx_store.close()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_wallet_make_transaction(self, two_nodes):
|
async def test_wallet_make_transaction(self, two_nodes):
|
||||||
@ -52,10 +56,15 @@ class TestWallet:
|
|||||||
key_config = {"wallet_sk": sk}
|
key_config = {"wallet_sk": sk}
|
||||||
key_config_b = {"wallet_sk": sk_b}
|
key_config_b = {"wallet_sk": sk_b}
|
||||||
|
|
||||||
wallet = await Wallet.create({}, key_config)
|
wallet_node = await WalletNode.create({}, key_config)
|
||||||
await wallet.wallet_store._clear_database()
|
wallet = wallet_node.wallet
|
||||||
wallet_b = await Wallet.create({}, key_config_b)
|
await wallet_node.wallet_store._clear_database()
|
||||||
await wallet_b.wallet_store._clear_database()
|
await wallet_node.tx_store._clear_database()
|
||||||
|
|
||||||
|
wallet_node_b = await WalletNode.create({}, key_config_b)
|
||||||
|
wallet_b = wallet_node_b.wallet
|
||||||
|
await wallet_node_b.wallet_store._clear_database()
|
||||||
|
await wallet_node_b.tx_store._clear_database()
|
||||||
|
|
||||||
num_blocks = 10
|
num_blocks = 10
|
||||||
blocks = bt.get_consecutive_blocks(
|
blocks = bt.get_consecutive_blocks(
|
||||||
@ -68,7 +77,7 @@ class TestWallet:
|
|||||||
|
|
||||||
for i in range(1, num_blocks):
|
for i in range(1, num_blocks):
|
||||||
a = RespondBody(blocks[i].body, blocks[i].height)
|
a = RespondBody(blocks[i].body, blocks[i].height)
|
||||||
await wallet.received_body(a)
|
await wallet_node.received_body(a)
|
||||||
|
|
||||||
assert await wallet.get_confirmed_balance() == 144000000000000
|
assert await wallet.get_confirmed_balance() == 144000000000000
|
||||||
|
|
||||||
@ -83,5 +92,8 @@ class TestWallet:
|
|||||||
assert confirmed_balance == 144000000000000
|
assert confirmed_balance == 144000000000000
|
||||||
assert unconfirmed_balance == confirmed_balance - 10
|
assert unconfirmed_balance == confirmed_balance - 10
|
||||||
|
|
||||||
await wallet.wallet_store.close()
|
await wallet_node.wallet_store.close()
|
||||||
await wallet_b.wallet_store.close()
|
await wallet_node.tx_store.close()
|
||||||
|
|
||||||
|
await wallet_node_b.wallet_store.close()
|
||||||
|
await wallet_node_b.tx_store.close()
|
||||||
|
@ -1,59 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import signal
|
|
||||||
import time
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from blspy import ExtendedPrivateKey
|
|
||||||
|
|
||||||
from src.protocols import full_node_protocol
|
|
||||||
from src.server.outbound_message import NodeType
|
|
||||||
from src.server.server import ChiaServer
|
|
||||||
from src.types.peer_info import PeerInfo
|
|
||||||
from src.wallet.wallet import Wallet
|
|
||||||
from tests.setup_nodes import setup_two_nodes, test_constants, bt
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def event_loop():
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
yield loop
|
|
||||||
|
|
||||||
|
|
||||||
class TestWalletProtocol:
|
|
||||||
@pytest.fixture(scope="function")
|
|
||||||
async def two_nodes(self):
|
|
||||||
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
|
|
||||||
yield _
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_wallet_connect(self, two_nodes):
|
|
||||||
num_blocks = 10
|
|
||||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
|
||||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
|
||||||
|
|
||||||
for i in range(1, num_blocks):
|
|
||||||
async for _ in full_node_1.respond_block(
|
|
||||||
full_node_protocol.RespondBlock(blocks[i])
|
|
||||||
):
|
|
||||||
pass
|
|
||||||
|
|
||||||
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
|
|
||||||
key_config = {"wallet_sk": sk}
|
|
||||||
|
|
||||||
wallet = await Wallet.create({}, key_config)
|
|
||||||
server = ChiaServer(8223, wallet, NodeType.WALLET)
|
|
||||||
|
|
||||||
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
|
|
||||||
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
|
|
||||||
|
|
||||||
_ = await server.start_server("127.0.0.1", wallet._on_connect)
|
|
||||||
await asyncio.sleep(2)
|
|
||||||
full_node_peer = PeerInfo(server_1._host, server_1._port)
|
|
||||||
_ = await server.start_client(full_node_peer, None)
|
|
||||||
|
|
||||||
start_unf = time.time()
|
|
||||||
while time.time() - start_unf < 3:
|
|
||||||
# TODO check if we've synced proof hashes and verified number of proofs
|
|
||||||
await asyncio.sleep(0.1)
|
|
||||||
|
|
||||||
await wallet.wallet_store.close()
|
|
Loading…
Reference in New Issue
Block a user