correct balance lookup, streamable fix

This commit is contained in:
Yostra 2020-02-24 12:41:17 -08:00
parent c8d32a5958
commit d9b2685411
7 changed files with 50 additions and 88 deletions

View File

@ -181,7 +181,7 @@ class Streamable:
f.write(uint32(len(item)).to_bytes(4, "big")) f.write(uint32(len(item)).to_bytes(4, "big"))
f.write(item.encode("utf-8")) f.write(item.encode("utf-8"))
elif f_type is bool: elif f_type is bool:
f.write(bytes(item)) f.write(int(item).to_bytes(4, "big"))
else: else:
raise NotImplementedError(f"can't stream {item}, {f_type}") raise NotImplementedError(f"can't stream {item}, {f_type}")

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional from typing import Optional, List
from src.types.hashable.Coin import Coin from src.types.hashable.Coin import Coin
from src.types.hashable.SpendBundle import SpendBundle from src.types.hashable.SpendBundle import SpendBundle
@ -21,10 +21,9 @@ class TransactionRecord(Streamable):
sent: bool sent: bool
created_at_time: uint64 created_at_time: uint64
spend_bundle: Optional[SpendBundle] spend_bundle: Optional[SpendBundle]
additions: Dict[bytes32, Coin] additions: List[Coin]
removals: Dict[bytes32, Coin] removals: List[Coin]
@property
def name(self) -> bytes32: def name(self) -> bytes32:
if self.spend_bundle: if self.spend_bundle:
return self.spend_bundle.name() return self.spend_bundle.name()

View File

@ -67,12 +67,15 @@ class Wallet:
self.wallet_state_manager = wallet_state_manager self.wallet_state_manager = wallet_state_manager
self.pubkey_num_lookup = {} self.pubkey_num_lookup = {}
self.server = None
return self return self
def get_next_public_key(self) -> PublicKey: def get_next_public_key(self) -> PublicKey:
pubkey = self.private_key.public_child(self.next_address).get_public_key() pubkey = self.private_key.public_child(self.next_address).get_public_key()
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
self.next_address = self.next_address + 1 self.next_address = self.next_address + 1
self.wallet_state_manager.next_address = self.next_address
return pubkey return pubkey
async def get_confirmed_balance(self) -> uint64: async def get_confirmed_balance(self) -> uint64:
@ -103,6 +106,7 @@ class Wallet:
def get_new_puzzlehash(self) -> bytes32: def get_new_puzzlehash(self) -> bytes32:
puzzle: Program = self.get_new_puzzle() puzzle: Program = self.get_new_puzzle()
puzzlehash: bytes32 = puzzle.get_hash() puzzlehash: bytes32 = puzzle.get_hash()
self.wallet_state_manager.puzzlehash_set.add(puzzlehash)
return puzzlehash return puzzlehash
def set_server(self, server: ChiaServer): def set_server(self, server: ChiaServer):

View File

@ -54,7 +54,7 @@ class WalletNode:
self.tx_store = await WalletTransactionStore.create(path) self.tx_store = await WalletTransactionStore.create(path)
self.wallet_state_manager = await WalletStateManager.create( self.wallet_state_manager = await WalletStateManager.create(
config, self.wallet_store, self.tx_store config, key_config, self.wallet_store, self.tx_store
) )
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager) self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)

View File

@ -18,15 +18,18 @@ class WalletStateManager:
tx_store: WalletTransactionStore tx_store: WalletTransactionStore
header_hash: List[bytes32] header_hash: List[bytes32]
start_index: int start_index: int
next_address: int
log: logging.Logger log: logging.Logger
# TODO Don't allow user to send tx until wallet is synced # TODO Don't allow user to send tx until wallet is synced
synced: bool synced: bool
puzzlehash_set: set
@staticmethod @staticmethod
async def create( async def create(
config: Dict, config: Dict,
key_config: Dict,
wallet_store: WalletStore, wallet_store: WalletStore,
tx_store: WalletTransactionStore, tx_store: WalletTransactionStore,
name: str = None, name: str = None,
@ -44,7 +47,9 @@ class WalletStateManager:
self.wallet_store = wallet_store self.wallet_store = wallet_store
self.tx_store = tx_store self.tx_store = tx_store
self.synced = False self.synced = False
self.next_address = 0
self.puzzlehash_set = set()
return self return self
async def get_confirmed_balance(self) -> uint64: async def get_confirmed_balance(self) -> uint64:
@ -65,9 +70,10 @@ class WalletStateManager:
removal_amount = 0 removal_amount = 0
for record in unconfirmed_tx: for record in unconfirmed_tx:
for name, coin in record.additions.items(): for coin in record.additions:
addition_amount += coin.amount if coin.puzzle_hash in self.puzzlehash_set:
for name, coin in record.removals.items(): addition_amount += coin.amount
for coin in record.removals:
removal_amount += coin.amount removal_amount += coin.amount
result = confirmed - removal_amount + addition_amount result = confirmed - removal_amount + addition_amount
@ -77,16 +83,16 @@ class WalletStateManager:
additions: Dict[bytes32, Coin] = {} additions: Dict[bytes32, Coin] = {}
unconfirmed_tx = await self.tx_store.get_not_confirmed() unconfirmed_tx = await self.tx_store.get_not_confirmed()
for record in unconfirmed_tx: for record in unconfirmed_tx:
for name, coin in record.additions.items(): for coin in record.additions:
additions[name] = coin additions[coin.name()] = coin
return additions return additions
async def unconfirmed_removals(self) -> Dict[bytes32, Coin]: async def unconfirmed_removals(self) -> Dict[bytes32, Coin]:
removals: Dict[bytes32, Coin] = {} removals: Dict[bytes32, Coin] = {}
unconfirmed_tx = await self.tx_store.get_not_confirmed() unconfirmed_tx = await self.tx_store.get_not_confirmed()
for record in unconfirmed_tx: for record in unconfirmed_tx:
for name, coin in record.removals.items(): for coin in record.removals:
removals[name] = coin removals[coin.name()] = coin
return removals return removals
async def select_coins(self, amount) -> Optional[Set[Coin]]: async def select_coins(self, amount) -> Optional[Set[Coin]]:
@ -149,16 +155,16 @@ class WalletStateManager:
Called from wallet_node before new transaction is sent to the full_node Called from wallet_node before new transaction is sent to the full_node
""" """
now = uint64(int(time.time())) now = uint64(int(time.time()))
add_dict: Dict[bytes32, Coin] = {} add_list: List[Coin] = []
rem_dict: Dict[bytes32, Coin] = {} rem_list: List[Coin] = []
for add in spend_bundle.additions(): for add in spend_bundle.additions():
add_dict[add.name()] = add add_list.append(add)
for rem in spend_bundle.removals(): for rem in spend_bundle.removals():
rem_dict[rem.name()] = rem rem_list.append(rem)
# Wallet node will use this queue to retry sending this transaction until full nodes receives it # Wallet node will use this queue to retry sending this transaction until full nodes receives it
tx_record = TransactionRecord( tx_record = TransactionRecord(
uint32(0), uint32(0), False, False, now, spend_bundle, add_dict, rem_dict uint32(0), uint32(0), False, False, now, spend_bundle, add_list, rem_list
) )
await self.tx_store.add_transaction_record(tx_record) await self.tx_store.add_transaction_record(tx_record)

View File

@ -5,6 +5,7 @@ from blspy import ExtendedPrivateKey
from src.protocols.wallet_protocol import RespondBody from src.protocols.wallet_protocol import RespondBody
from src.wallet.wallet import Wallet from src.wallet.wallet import Wallet
from src.wallet.wallet_node import WalletNode
from tests.setup_nodes import setup_two_nodes, test_constants, bt from tests.setup_nodes import setup_two_nodes, test_constants, bt
@ -25,8 +26,10 @@ class TestWallet:
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex() sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
key_config = {"wallet_sk": sk} key_config = {"wallet_sk": sk}
wallet = await Wallet.create({}, key_config) wallet_node = await WalletNode.create({}, key_config)
await wallet.wallet_store._clear_database() wallet = wallet_node.wallet
await wallet_node.wallet_store._clear_database()
await wallet_node.tx_store._clear_database()
num_blocks = 10 num_blocks = 10
blocks = bt.get_consecutive_blocks( blocks = bt.get_consecutive_blocks(
@ -39,11 +42,12 @@ class TestWallet:
for i in range(1, num_blocks): for i in range(1, num_blocks):
a = RespondBody(blocks[i].body, blocks[i].height) a = RespondBody(blocks[i].body, blocks[i].height)
await wallet.received_body(a) await wallet_node.received_body(a)
assert await wallet.get_confirmed_balance() == 144000000000000 assert await wallet.get_confirmed_balance() == 144000000000000
await wallet.wallet_store.close() await wallet_node.wallet_store.close()
await wallet_node.tx_store.close()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_wallet_make_transaction(self, two_nodes): async def test_wallet_make_transaction(self, two_nodes):
@ -52,10 +56,15 @@ class TestWallet:
key_config = {"wallet_sk": sk} key_config = {"wallet_sk": sk}
key_config_b = {"wallet_sk": sk_b} key_config_b = {"wallet_sk": sk_b}
wallet = await Wallet.create({}, key_config) wallet_node = await WalletNode.create({}, key_config)
await wallet.wallet_store._clear_database() wallet = wallet_node.wallet
wallet_b = await Wallet.create({}, key_config_b) await wallet_node.wallet_store._clear_database()
await wallet_b.wallet_store._clear_database() await wallet_node.tx_store._clear_database()
wallet_node_b = await WalletNode.create({}, key_config_b)
wallet_b = wallet_node_b.wallet
await wallet_node_b.wallet_store._clear_database()
await wallet_node_b.tx_store._clear_database()
num_blocks = 10 num_blocks = 10
blocks = bt.get_consecutive_blocks( blocks = bt.get_consecutive_blocks(
@ -68,7 +77,7 @@ class TestWallet:
for i in range(1, num_blocks): for i in range(1, num_blocks):
a = RespondBody(blocks[i].body, blocks[i].height) a = RespondBody(blocks[i].body, blocks[i].height)
await wallet.received_body(a) await wallet_node.received_body(a)
assert await wallet.get_confirmed_balance() == 144000000000000 assert await wallet.get_confirmed_balance() == 144000000000000
@ -83,5 +92,8 @@ class TestWallet:
assert confirmed_balance == 144000000000000 assert confirmed_balance == 144000000000000
assert unconfirmed_balance == confirmed_balance - 10 assert unconfirmed_balance == confirmed_balance - 10
await wallet.wallet_store.close() await wallet_node.wallet_store.close()
await wallet_b.wallet_store.close() await wallet_node.tx_store.close()
await wallet_node_b.wallet_store.close()
await wallet_node_b.tx_store.close()

View File

@ -1,59 +0,0 @@
import asyncio
import signal
import time
import pytest
from blspy import ExtendedPrivateKey
from src.protocols import full_node_protocol
from src.server.outbound_message import NodeType
from src.server.server import ChiaServer
from src.types.peer_info import PeerInfo
from src.wallet.wallet import Wallet
from tests.setup_nodes import setup_two_nodes, test_constants, bt
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
class TestWalletProtocol:
@pytest.fixture(scope="function")
async def two_nodes(self):
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
yield _
@pytest.mark.asyncio
async def test_wallet_connect(self, two_nodes):
num_blocks = 10
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, full_node_2, server_1, server_2 = two_nodes
for i in range(1, num_blocks):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
key_config = {"wallet_sk": sk}
wallet = await Wallet.create({}, key_config)
server = ChiaServer(8223, wallet, NodeType.WALLET)
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
_ = await server.start_server("127.0.0.1", wallet._on_connect)
await asyncio.sleep(2)
full_node_peer = PeerInfo(server_1._host, server_1._port)
_ = await server.start_client(full_node_peer, None)
start_unf = time.time()
while time.time() - start_unf < 3:
# TODO check if we've synced proof hashes and verified number of proofs
await asyncio.sleep(0.1)
await wallet.wallet_store.close()