correct balance lookup, streamable fix

This commit is contained in:
Yostra 2020-02-24 12:41:17 -08:00
parent c8d32a5958
commit d9b2685411
7 changed files with 50 additions and 88 deletions

View File

@ -181,7 +181,7 @@ class Streamable:
f.write(uint32(len(item)).to_bytes(4, "big"))
f.write(item.encode("utf-8"))
elif f_type is bool:
f.write(bytes(item))
f.write(int(item).to_bytes(4, "big"))
else:
raise NotImplementedError(f"can't stream {item}, {f_type}")

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional
from typing import Optional, List
from src.types.hashable.Coin import Coin
from src.types.hashable.SpendBundle import SpendBundle
@ -21,10 +21,9 @@ class TransactionRecord(Streamable):
sent: bool
created_at_time: uint64
spend_bundle: Optional[SpendBundle]
additions: Dict[bytes32, Coin]
removals: Dict[bytes32, Coin]
additions: List[Coin]
removals: List[Coin]
@property
def name(self) -> bytes32:
if self.spend_bundle:
return self.spend_bundle.name()

View File

@ -67,12 +67,15 @@ class Wallet:
self.wallet_state_manager = wallet_state_manager
self.pubkey_num_lookup = {}
self.server = None
return self
def get_next_public_key(self) -> PublicKey:
pubkey = self.private_key.public_child(self.next_address).get_public_key()
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
self.next_address = self.next_address + 1
self.wallet_state_manager.next_address = self.next_address
return pubkey
async def get_confirmed_balance(self) -> uint64:
@ -103,6 +106,7 @@ class Wallet:
def get_new_puzzlehash(self) -> bytes32:
puzzle: Program = self.get_new_puzzle()
puzzlehash: bytes32 = puzzle.get_hash()
self.wallet_state_manager.puzzlehash_set.add(puzzlehash)
return puzzlehash
def set_server(self, server: ChiaServer):

View File

@ -54,7 +54,7 @@ class WalletNode:
self.tx_store = await WalletTransactionStore.create(path)
self.wallet_state_manager = await WalletStateManager.create(
config, self.wallet_store, self.tx_store
config, key_config, self.wallet_store, self.tx_store
)
self.wallet = await Wallet.create(config, key_config, self.wallet_state_manager)

View File

@ -18,15 +18,18 @@ class WalletStateManager:
tx_store: WalletTransactionStore
header_hash: List[bytes32]
start_index: int
next_address: int
log: logging.Logger
# TODO Don't allow user to send tx until wallet is synced
synced: bool
puzzlehash_set: set
@staticmethod
async def create(
config: Dict,
key_config: Dict,
wallet_store: WalletStore,
tx_store: WalletTransactionStore,
name: str = None,
@ -44,7 +47,9 @@ class WalletStateManager:
self.wallet_store = wallet_store
self.tx_store = tx_store
self.synced = False
self.next_address = 0
self.puzzlehash_set = set()
return self
async def get_confirmed_balance(self) -> uint64:
@ -65,9 +70,10 @@ class WalletStateManager:
removal_amount = 0
for record in unconfirmed_tx:
for name, coin in record.additions.items():
addition_amount += coin.amount
for name, coin in record.removals.items():
for coin in record.additions:
if coin.puzzle_hash in self.puzzlehash_set:
addition_amount += coin.amount
for coin in record.removals:
removal_amount += coin.amount
result = confirmed - removal_amount + addition_amount
@ -77,16 +83,16 @@ class WalletStateManager:
additions: Dict[bytes32, Coin] = {}
unconfirmed_tx = await self.tx_store.get_not_confirmed()
for record in unconfirmed_tx:
for name, coin in record.additions.items():
additions[name] = coin
for coin in record.additions:
additions[coin.name()] = coin
return additions
async def unconfirmed_removals(self) -> Dict[bytes32, Coin]:
removals: Dict[bytes32, Coin] = {}
unconfirmed_tx = await self.tx_store.get_not_confirmed()
for record in unconfirmed_tx:
for name, coin in record.removals.items():
removals[name] = coin
for coin in record.removals:
removals[coin.name()] = coin
return removals
async def select_coins(self, amount) -> Optional[Set[Coin]]:
@ -149,16 +155,16 @@ class WalletStateManager:
Called from wallet_node before new transaction is sent to the full_node
"""
now = uint64(int(time.time()))
add_dict: Dict[bytes32, Coin] = {}
rem_dict: Dict[bytes32, Coin] = {}
add_list: List[Coin] = []
rem_list: List[Coin] = []
for add in spend_bundle.additions():
add_dict[add.name()] = add
add_list.append(add)
for rem in spend_bundle.removals():
rem_dict[rem.name()] = rem
rem_list.append(rem)
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
tx_record = TransactionRecord(
uint32(0), uint32(0), False, False, now, spend_bundle, add_dict, rem_dict
uint32(0), uint32(0), False, False, now, spend_bundle, add_list, rem_list
)
await self.tx_store.add_transaction_record(tx_record)

View File

@ -5,6 +5,7 @@ from blspy import ExtendedPrivateKey
from src.protocols.wallet_protocol import RespondBody
from src.wallet.wallet import Wallet
from src.wallet.wallet_node import WalletNode
from tests.setup_nodes import setup_two_nodes, test_constants, bt
@ -25,8 +26,10 @@ class TestWallet:
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
key_config = {"wallet_sk": sk}
wallet = await Wallet.create({}, key_config)
await wallet.wallet_store._clear_database()
wallet_node = await WalletNode.create({}, key_config)
wallet = wallet_node.wallet
await wallet_node.wallet_store._clear_database()
await wallet_node.tx_store._clear_database()
num_blocks = 10
blocks = bt.get_consecutive_blocks(
@ -39,11 +42,12 @@ class TestWallet:
for i in range(1, num_blocks):
a = RespondBody(blocks[i].body, blocks[i].height)
await wallet.received_body(a)
await wallet_node.received_body(a)
assert await wallet.get_confirmed_balance() == 144000000000000
await wallet.wallet_store.close()
await wallet_node.wallet_store.close()
await wallet_node.tx_store.close()
@pytest.mark.asyncio
async def test_wallet_make_transaction(self, two_nodes):
@ -52,10 +56,15 @@ class TestWallet:
key_config = {"wallet_sk": sk}
key_config_b = {"wallet_sk": sk_b}
wallet = await Wallet.create({}, key_config)
await wallet.wallet_store._clear_database()
wallet_b = await Wallet.create({}, key_config_b)
await wallet_b.wallet_store._clear_database()
wallet_node = await WalletNode.create({}, key_config)
wallet = wallet_node.wallet
await wallet_node.wallet_store._clear_database()
await wallet_node.tx_store._clear_database()
wallet_node_b = await WalletNode.create({}, key_config_b)
wallet_b = wallet_node_b.wallet
await wallet_node_b.wallet_store._clear_database()
await wallet_node_b.tx_store._clear_database()
num_blocks = 10
blocks = bt.get_consecutive_blocks(
@ -68,7 +77,7 @@ class TestWallet:
for i in range(1, num_blocks):
a = RespondBody(blocks[i].body, blocks[i].height)
await wallet.received_body(a)
await wallet_node.received_body(a)
assert await wallet.get_confirmed_balance() == 144000000000000
@ -83,5 +92,8 @@ class TestWallet:
assert confirmed_balance == 144000000000000
assert unconfirmed_balance == confirmed_balance - 10
await wallet.wallet_store.close()
await wallet_b.wallet_store.close()
await wallet_node.wallet_store.close()
await wallet_node.tx_store.close()
await wallet_node_b.wallet_store.close()
await wallet_node_b.tx_store.close()

View File

@ -1,59 +0,0 @@
import asyncio
import signal
import time
import pytest
from blspy import ExtendedPrivateKey
from src.protocols import full_node_protocol
from src.server.outbound_message import NodeType
from src.server.server import ChiaServer
from src.types.peer_info import PeerInfo
from src.wallet.wallet import Wallet
from tests.setup_nodes import setup_two_nodes, test_constants, bt
@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
class TestWalletProtocol:
@pytest.fixture(scope="function")
async def two_nodes(self):
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 0}):
yield _
@pytest.mark.asyncio
async def test_wallet_connect(self, two_nodes):
num_blocks = 10
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, full_node_2, server_1, server_2 = two_nodes
for i in range(1, num_blocks):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):
pass
sk = bytes(ExtendedPrivateKey.from_seed(b"")).hex()
key_config = {"wallet_sk": sk}
wallet = await Wallet.create({}, key_config)
server = ChiaServer(8223, wallet, NodeType.WALLET)
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, server.close_all)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, server.close_all)
_ = await server.start_server("127.0.0.1", wallet._on_connect)
await asyncio.sleep(2)
full_node_peer = PeerInfo(server_1._host, server_1._port)
_ = await server.start_client(full_node_peer, None)
start_unf = time.time()
while time.time() - start_unf < 3:
# TODO check if we've synced proof hashes and verified number of proofs
await asyncio.sleep(0.1)
await wallet.wallet_store.close()