From b5ebf6c4946b04c973dd5d168297df4ca1fc994e Mon Sep 17 00:00:00 2001 From: Mariano Sorgente Date: Tue, 7 Apr 2020 18:17:44 +0900 Subject: [PATCH] Refactor puzzle store, tests, generate future puzzle hashes --- electron-ui/renderer.js | 6 +- scripts/chia-start-sim | 1 - src/rpc/rpc_server.py | 2 +- src/timelord_launcher.py | 3 +- src/util/initial-config.yaml | 2 +- src/wallet/derivation_record.py | 22 ++++ src/wallet/rl_wallet/rl_wallet.py | 125 ++++++++++++--------- src/wallet/transaction_record.py | 2 +- src/wallet/wallet.py | 23 ++-- src/wallet/wallet_info.py | 3 +- src/wallet/wallet_node.py | 44 +------- src/wallet/wallet_puzzle_store.py | 145 +++++++++++++++++-------- src/wallet/wallet_state_manager.py | 138 +++++++++++++++++++++-- src/wallet/wallet_transaction_store.py | 4 +- src/wallet/websocket_server.py | 22 ++-- tests/full_node/test_full_node.py | 20 +--- tests/full_node/test_full_sync.py | 8 +- tests/full_node/test_node_load.py | 8 +- tests/full_node/test_transactions.py | 62 ++++++----- tests/rpc/test_rpc.py | 2 +- tests/setup_nodes.py | 6 +- tests/test_filter.py | 2 +- tests/wallet/test_puzzle_store.py | 99 +++++++++++++++++ tests/wallet/test_wallet.py | 51 ++++----- tests/wallet/test_wallet_sync.py | 30 ++--- 25 files changed, 545 insertions(+), 285 deletions(-) create mode 100644 src/wallet/derivation_record.py create mode 100644 tests/wallet/test_puzzle_store.py diff --git a/electron-ui/renderer.js b/electron-ui/renderer.js index 719700994fe5..1b13f84f4384 100644 --- a/electron-ui/renderer.js +++ b/electron-ui/renderer.js @@ -106,7 +106,6 @@ function set_callbacks(socket) { if (command == "start_server") { get_wallets(); - get_new_puzzlehash(); get_transactions(); get_wallet_balance(g_wallet_id); get_height_info(); @@ -266,6 +265,11 @@ copy.addEventListener("click", () => { }) async function get_new_puzzlehash() { + if (global_syncing) { + alert("Cannot create address while syncing.") + return; + } + /* Sends websocket request for new puzzle_hash */ diff --git a/scripts/chia-start-sim b/scripts/chia-start-sim index 402e59f52550..fc0c03f089cc 100755 --- a/scripts/chia-start-sim +++ b/scripts/chia-start-sim @@ -4,7 +4,6 @@ echo "Starting local blockchain simulation. Runs a local introducer and chia sys echo "Note that this simulation will not work if connected to external nodes." # Starts a harvester, farmer, timelord, introducer, and 3 full nodes, locally. -# Make sure to point the full node in config/config.yaml to the local introducer: 127.0.0.1:8445. # Please note that the simulation is meant to be run locally and not connected to external nodes. _run_bg_cmd python -m src.server.start_harvester --logging.log_stdout=True diff --git a/src/rpc/rpc_server.py b/src/rpc/rpc_server.py index 8cb49c3c44a6..c54ef4afbdbe 100644 --- a/src/rpc/rpc_server.py +++ b/src/rpc/rpc_server.py @@ -240,7 +240,7 @@ async def start_rpc_server( runner = web.AppRunner(app, access_log=None) await runner.setup() - site = web.TCPSite(runner, None, int(rpc_port)) + site = web.TCPSite(runner, "localhost", int(rpc_port)) await site.start() async def cleanup(): diff --git a/src/timelord_launcher.py b/src/timelord_launcher.py index 6ac63714e115..a1f1ca50102f 100644 --- a/src/timelord_launcher.py +++ b/src/timelord_launcher.py @@ -71,10 +71,9 @@ async def spawn_process(host, port, counter): async def spawn_all_processes(): await asyncio.sleep(15) - host = config["host"] port = config["port"] process_count = config["process_count"] - awaitables = [spawn_process(host, port, i) for i in range(process_count)] + awaitables = [spawn_process("127.0.0.1", port, i) for i in range(process_count)] await asyncio.gather(*awaitables) diff --git a/src/util/initial-config.yaml b/src/util/initial-config.yaml index 8f51bca18982..a845abd1d564 100644 --- a/src/util/initial-config.yaml +++ b/src/util/initial-config.yaml @@ -149,7 +149,7 @@ wallet: port: 8444 testing: False - database_path: wallet/db/blockchain_wallet_v4.db + database_path: wallet/db/blockchain_wallet_v5.db logging: *logging diff --git a/src/wallet/derivation_record.py b/src/wallet/derivation_record.py new file mode 100644 index 000000000000..fc45ff2490ca --- /dev/null +++ b/src/wallet/derivation_record.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from blspy import PublicKey + +from src.types.sized_bytes import bytes32 +from src.util.streamable import Streamable, streamable +from src.util.ints import uint32 +from src.wallet.util.wallet_types import WalletType + + +@dataclass(frozen=True) +@streamable +class DerivationRecord(Streamable): + """ + These are records representing a puzzle hash, which is generated from a + public key, derivation index, and wallet type. Stored in the puzzle_store. + """ + + index: uint32 + puzzle_hash: bytes32 + pubkey: PublicKey + wallet_type: WalletType + wallet_id: uint32 diff --git a/src/wallet/rl_wallet/rl_wallet.py b/src/wallet/rl_wallet/rl_wallet.py index 503f623ddb1f..93e1e9b3c7a8 100644 --- a/src/wallet/rl_wallet/rl_wallet.py +++ b/src/wallet/rl_wallet/rl_wallet.py @@ -4,7 +4,7 @@ import logging from binascii import hexlify from dataclasses import dataclass from secrets import token_bytes -from typing import Dict, Optional, List, Tuple +from typing import Dict, Optional, List, Tuple, Any import clvm import json @@ -18,7 +18,7 @@ from src.types.coin_solution import CoinSolution from src.types.program import Program from src.types.spend_bundle import SpendBundle from src.types.sized_bytes import bytes32 -from src.util.ints import uint64 +from src.util.ints import uint64, uint32 from src.util.streamable import streamable, Streamable from src.wallet.rl_wallet.rl_wallet_puzzles import ( rl_puzzle_for_pk, @@ -32,7 +32,7 @@ from src.wallet.util.wallet_types import WalletType from src.wallet.wallet import Wallet from src.wallet.wallet_coin_record import WalletCoinRecord from src.wallet.wallet_info import WalletInfo -from src.wallet.wallet_state_manager import WalletStateManager +from src.wallet.derivation_record import DerivationRecord @dataclass(frozen=True) @@ -53,7 +53,7 @@ class RLWallet: key_config: Dict config: Dict server: Optional[ChiaServer] - wallet_state_manager: WalletStateManager + wallet_state_manager: Any log: logging.Logger wallet_info: WalletInfo rl_coin_record: WalletCoinRecord @@ -64,19 +64,21 @@ class RLWallet: async def create_rl_admin( config: Dict, key_config: Dict, - wallet_state_manager: WalletStateManager, + wallet_state_manager: Any, wallet: Wallet, name: str = None, ): - current_max_index = ( - await wallet_state_manager.puzzle_store.get_max_derivation_path() + 1 - ) - my_pubkey_index = current_max_index + 1 + unused: Optional[ + uint32 + ] = await wallet_state_manager.puzzle_store.get_unused_derivation_path() + if unused is None: + await wallet_state_manager.create_more_puzzle_hashes() + unused = await wallet_state_manager.puzzle_store.get_unused_derivation_path() + assert unused is not None + sk_hex = key_config["wallet_sk"] private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex)) - pubkey_bytes: bytes = bytes( - private_key.public_child(my_pubkey_index).get_public_key() - ) + pubkey_bytes: bytes = bytes(private_key.public_child(unused).get_public_key()) rl_info = RLInfo("admin", pubkey_bytes, None, None, None, None, None, None) info_as_string = json.dumps(rl_info.to_json_dict()) @@ -86,13 +88,19 @@ class RLWallet: wallet_info = await wallet_state_manager.user_store.get_last_wallet() if wallet_info is None: raise - await wallet_state_manager.puzzle_store.add_derivation_path_of_interest( - my_pubkey_index, - token_bytes(), - pubkey_bytes, - WalletType.RATE_LIMITED, - wallet_info.id, + + await wallet_state_manager.puzzle_store.add_derivation_paths( + [ + DerivationRecord( + unused, + token_bytes(), + pubkey_bytes, + WalletType.RATE_LIMITED, + wallet_info.id, + ) + ] ) + await wallet_state_manager.puzzle_store.set_used_up_to(unused) self = await RLWallet.create( config, key_config, wallet_state_manager, wallet_info, wallet, name @@ -103,46 +111,60 @@ class RLWallet: async def create_rl_user( config: Dict, key_config: Dict, - wallet_state_manager: WalletStateManager, + wallet_state_manager: Any, wallet: Wallet, name: str = None, ): - current_max_index = ( - await wallet_state_manager.puzzle_store.get_max_derivation_path() + 1 - ) - my_pubkey_index = current_max_index + 1 - sk_hex = key_config["wallet_sk"] - private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex)) - pubkey_bytes: bytes = bytes( - private_key.public_child(my_pubkey_index).get_public_key() - ) + async with wallet_state_manager.puzzle_store.lock: + unused: Optional[ + uint32 + ] = await wallet_state_manager.puzzle_store.get_unused_derivation_path() + if unused is None: + await wallet_state_manager.create_more_puzzle_hashes() + unused = ( + await wallet_state_manager.puzzle_store.get_unused_derivation_path() + ) + assert unused is not None - rl_info = RLInfo("user", None, pubkey_bytes, None, None, None, None, None) - info_as_string = json.dumps(rl_info.to_json_dict()) - await wallet_state_manager.user_store.create_wallet( - "RL User", WalletType.RATE_LIMITED, info_as_string - ) - wallet_info = await wallet_state_manager.user_store.get_last_wallet() - if wallet_info is None: - raise - await wallet_state_manager.puzzle_store.add_derivation_path_of_interest( - my_pubkey_index, - token_bytes(), - pubkey_bytes, - WalletType.RATE_LIMITED, - wallet_info.id, - ) + sk_hex = key_config["wallet_sk"] + private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex)) + pubkey_bytes: bytes = bytes( + private_key.public_child(unused).get_public_key() + ) + + rl_info = RLInfo("user", None, pubkey_bytes, None, None, None, None, None) + info_as_string = json.dumps(rl_info.to_json_dict()) + await wallet_state_manager.user_store.create_wallet( + "RL User", WalletType.RATE_LIMITED, info_as_string + ) + wallet_info = await wallet_state_manager.user_store.get_last_wallet() + if wallet_info is None: + raise + + self = await RLWallet.create( + config, key_config, wallet_state_manager, wallet_info, wallet, name + ) + + await wallet_state_manager.puzzle_store.add_derivation_paths( + [ + DerivationRecord( + unused, + token_bytes(), + pubkey_bytes, + WalletType.RATE_LIMITED, + wallet_info.id, + ) + ] + ) + await wallet_state_manager.puzzle_store.set_used_up_to(unused) - self = await RLWallet.create( - config, key_config, wallet_state_manager, wallet_info, wallet, name - ) return self @staticmethod async def create( config: Dict, key_config: Dict, - wallet_state_manager: WalletStateManager, + wallet_state_manager: Any, info: WalletInfo, wallet: Wallet, name: str = None, @@ -194,13 +216,16 @@ class RLWallet: index = await self.wallet_state_manager.puzzle_store.index_for_pubkey( self.rl_info.admin_pubkey.hex() ) - await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest( + + assert index is not None + record = DerivationRecord( index, rl_puzzle_hash, self.rl_info.admin_pubkey, WalletType.RATE_LIMITED, self.wallet_info.id, ) + await self.wallet_state_manager.puzzle_store.add_derivation_paths([record]) spend_bundle = await self.standard_wallet.generate_signed_transaction( amount, rl_puzzle_hash, uint64(0), origin_id, coins @@ -263,13 +288,15 @@ class RLWallet: index = await self.wallet_state_manager.puzzle_store.index_for_pubkey( self.rl_info.user_pubkey.hex() ) - await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest( + assert index is not None + record = DerivationRecord( index, rl_puzzle_hash, self.rl_info.user_pubkey, WalletType.RATE_LIMITED, self.wallet_info.id, ) + await self.wallet_state_manager.puzzle_store.add_derivation_paths([record]) data_str = json.dumps(new_rl_info.to_json_dict()) new_wallet_info = WalletInfo( diff --git a/src/wallet/transaction_record.py b/src/wallet/transaction_record.py index f0fb7a5207ce..cf853223ef1c 100644 --- a/src/wallet/transaction_record.py +++ b/src/wallet/transaction_record.py @@ -26,7 +26,7 @@ class TransactionRecord(Streamable): spend_bundle: Optional[SpendBundle] additions: List[Coin] removals: List[Coin] - wallet_id: uint64 + wallet_id: uint32 # Represents the list of peers that we sent the transaction to, whether each one # included it in the mempool, and what the error message (if any) was diff --git a/src/wallet/wallet.py b/src/wallet/wallet.py index 8c2f51619bac..5b095e4b527c 100644 --- a/src/wallet/wallet.py +++ b/src/wallet/wallet.py @@ -24,19 +24,16 @@ from src.wallet.puzzles.puzzle_utils import ( make_assert_coin_consumed_condition, make_create_coin_condition, ) -from src.wallet.util.wallet_types import WalletType from src.wallet.wallet_coin_record import WalletCoinRecord from src.wallet.transaction_record import TransactionRecord from src.wallet.wallet_info import WalletInfo -from src.wallet.wallet_state_manager import WalletStateManager - class Wallet: private_key: ExtendedPrivateKey key_config: Dict config: Dict - wallet_state_manager: WalletStateManager + wallet_state_manager: Any log: logging.Logger @@ -46,7 +43,7 @@ class Wallet: async def create( config: Dict, key_config: Dict, - wallet_state_manager: WalletStateManager, + wallet_state_manager: Any, info: WalletInfo, name: str = None, ): @@ -87,17 +84,11 @@ class Wallet: return puzzle_for_pk(pubkey) async def get_new_puzzlehash(self) -> bytes32: - index = await self.wallet_state_manager.puzzle_store.get_max_derivation_path() - index += 1 - pubkey: bytes = bytes(self.get_public_key(index)) - puzzle: Program = self.puzzle_for_pk(pubkey) - puzzlehash: bytes32 = puzzle.get_hash() - - await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest( - index, puzzlehash, pubkey, WalletType.STANDARD_WALLET, self.wallet_info.id - ) - - return puzzlehash + return ( + await self.wallet_state_manager.get_unused_derivation_record( + self.wallet_info.id + ) + ).puzzle_hash def make_solution(self, primaries=None, min_time=0, me=None, consumed=None): condition_list = [] diff --git a/src/wallet/wallet_info.py b/src/wallet/wallet_info.py index d6c373b32d19..29b54520c3f2 100644 --- a/src/wallet/wallet_info.py +++ b/src/wallet/wallet_info.py @@ -2,6 +2,7 @@ from dataclasses import dataclass from src.util.streamable import streamable, Streamable from src.wallet.util.wallet_types import WalletType +from src.util.ints import uint32 @dataclass(frozen=True) @@ -11,7 +12,7 @@ class WalletInfo(Streamable): # TODO(straya): describe """ - id: int + id: uint32 name: str type: WalletType data: str diff --git a/src/wallet/wallet_node.py b/src/wallet/wallet_node.py index 60f3d241a839..0360e3dd4b77 100644 --- a/src/wallet/wallet_node.py +++ b/src/wallet/wallet_node.py @@ -1,6 +1,6 @@ import asyncio import time -from typing import Dict, Optional, Tuple, List, Any, AsyncGenerator +from typing import Dict, Optional, Tuple, List, AsyncGenerator import concurrent import random import logging @@ -21,11 +21,7 @@ from src.server.outbound_message import OutboundMessage, NodeType, Message, Deli from src.util.ints import uint32, uint64 from src.types.sized_bytes import bytes32 from src.util.api_decorators import api_request -from src.wallet.rl_wallet.rl_wallet import RLWallet from src.wallet.transaction_record import TransactionRecord -from src.wallet.util.wallet_types import WalletType -from src.wallet.wallet import Wallet -from src.wallet.wallet_info import WalletInfo from src.wallet.wallet_state_manager import WalletStateManager from src.wallet.block_record import BlockRecord from src.types.header_block import HeaderBlock @@ -47,8 +43,6 @@ class WalletNode: # Maintains the state of the wallet (blockchain and transactions), handles DB connections wallet_state_manager: WalletStateManager - main_wallet: Wallet - wallets: Dict[int, Any] # Maintains headers recently received. Once the desired removals and additions are downloaded, # the data is persisted in the WalletStateManager. These variables are also used to store @@ -95,44 +89,10 @@ class WalletNode: mkdir(path.parent) self.wallet_state_manager = await WalletStateManager.create( - config, path, self.constants + key_config, config, path, self.constants ) self.wallet_state_manager.set_pending_callback(self.pending_tx_handler) - main_wallet_info = await self.wallet_state_manager.get_main_wallet() - assert main_wallet_info is not None - - self.main_wallet = await Wallet.create( - config, key_config, self.wallet_state_manager, main_wallet_info - ) - - wallets: List[WalletInfo] = await self.wallet_state_manager.get_all_wallets() - self.wallets = {} - main_wallet = await Wallet.create( - config, key_config, self.wallet_state_manager, main_wallet_info - ) - self.main_wallet = main_wallet - self.wallets[main_wallet_info.id] = main_wallet - - for wallet_info in wallets: - self.log.info(f"wallet_info {wallet_info}") - if wallet_info.type == WalletType.STANDARD_WALLET: - if wallet_info.id == 1: - continue - wallet = await Wallet.create( - config, key_config, self.wallet_state_manager, main_wallet_info - ) - self.wallets[wallet_info.id] = wallet - elif wallet_info.type == WalletType.RATE_LIMITED: - wallet = await RLWallet.create( - config, - key_config, - self.wallet_state_manager, - wallet_info, - self.main_wallet, - ) - self.wallets[wallet_info.id] = wallet - # Normal operation data self.cached_blocks = {} self.future_block_hashes = {} diff --git a/src/wallet/wallet_puzzle_store.py b/src/wallet/wallet_puzzle_store.py index 376dadecf801..0fe8e05f8463 100644 --- a/src/wallet/wallet_puzzle_store.py +++ b/src/wallet/wallet_puzzle_store.py @@ -1,9 +1,14 @@ import asyncio -from typing import Set, Tuple, Optional +from blspy import PublicKey +from typing import Set, Tuple, Optional, List import aiosqlite +import logging from src.types.sized_bytes import bytes32 -from src.util.ints import uint32, uint64 +from src.util.ints import uint32 from src.wallet.util.wallet_types import WalletType +from src.wallet.derivation_record import DerivationRecord + +log = logging.getLogger(__name__) class WalletPuzzleStore: @@ -28,14 +33,17 @@ class WalletPuzzleStore: await self.db_connection.execute( ( f"CREATE TABLE IF NOT EXISTS derivation_paths(" - f"id int PRIMARY KEY," + f"derivation_index int," f" pubkey text," - f" puzzle_hash text," + f" puzzle_hash text PRIMARY_KEY," f" wallet_type int," f" wallet_id int," - f" used int)" + f" used tinyint)" ) ) + await self.db_connection.execute( + "CREATE INDEX IF NOT EXISTS derivation_index_index on derivation_paths(derivation_index)" + ) await self.db_connection.execute( "CREATE INDEX IF NOT EXISTS ph on derivation_paths(puzzle_hash)" @@ -74,26 +82,66 @@ class WalletPuzzleStore: await cursor.close() await self.db_connection.commit() - async def add_derivation_path_of_interest( - self, - index: int, - puzzlehash: bytes32, - pubkey: bytes, - wallet_type: WalletType, - wallet_id: int, - ): + async def add_derivation_paths(self, records: List[DerivationRecord]) -> None: """ - Inserts new derivation path, puzzle, pubkey, wallet into DB. + Insert many derivation paths into the database. """ + sql_records = [] + for record in records: + sql_records.append( + ( + record.index, + bytes(record.pubkey).hex(), + record.puzzle_hash.hex(), + record.wallet_type.value, + record.wallet_id, + 0, + ), + ) - cursor = await self.db_connection.execute( + cursor = await self.db_connection.executemany( "INSERT OR REPLACE INTO derivation_paths VALUES(?, ?, ?, ?, ?, ?)", - (index, pubkey.hex(), puzzlehash.hex(), wallet_type.value, wallet_id, 0), + sql_records, ) await cursor.close() await self.db_connection.commit() + async def get_derivation_record( + self, index: uint32, wallet_id: uint32 + ) -> Optional[DerivationRecord]: + """ + Returns the derivation record by index and wallet id. + """ + cursor = await self.db_connection.execute( + "SELECT * FROM derivation_paths WHERE derivation_index=? and wallet_id=?;", + (index, wallet_id,), + ) + row = await cursor.fetchone() + await cursor.close() + + if row is not None and row[0] is not None: + return DerivationRecord( + row[0], + bytes.fromhex(row[2]), + PublicKey.from_bytes(bytes.fromhex(row[1])), + row[3], + row[4], + ) + + return None + + async def set_used_up_to(self, index: uint32) -> None: + """ + Sets a derivation path to used so we don't use it again. + """ + pass + cursor = await self.db_connection.execute( + "UPDATE derivation_paths SET used=1 WHERE derivation_index<=?", (index,), + ) + await cursor.close() + await self.db_connection.commit() + async def puzzle_hash_exists(self, puzzle_hash: bytes32) -> bool: """ Checks if passed puzzle_hash is present in the db. @@ -102,54 +150,50 @@ class WalletPuzzleStore: cursor = await self.db_connection.execute( "SELECT * from derivation_paths WHERE puzzle_hash=?", (puzzle_hash.hex(),) ) - rows = await cursor.fetchall() + row = await cursor.fetchone() await cursor.close() - if len(list(rows)) > 0: - return True + return row is not None - return False - - async def index_for_pubkey(self, pubkey: str) -> int: + async def index_for_pubkey(self, pubkey: PublicKey) -> Optional[uint32]: """ - Returns derivation path for the given pubkey. - Returns -1 if not present. + Returns derivation paths for the given pubkey. + Returns None if not present. """ cursor = await self.db_connection.execute( - "SELECT * from derivation_paths WHERE pubkey=?", (pubkey,) + "SELECT * from derivation_paths WHERE pubkey=?", (bytes(pubkey).hex(),) ) row = await cursor.fetchone() await cursor.close() - if row: - return row[0] + if row is not None: + return uint32(row[0]) - return -1 + return None - async def index_for_puzzle_hash(self, puzzle_hash: bytes32) -> int: + async def index_for_puzzle_hash(self, puzzle_hash: bytes32) -> Optional[uint32]: """ Returns the derivation path for the puzzle_hash. - Returns -1 if not present. + Returns None if not present. """ - cursor = await self.db_connection.execute( "SELECT * from derivation_paths WHERE puzzle_hash=?", (puzzle_hash.hex(),) ) row = await cursor.fetchone() await cursor.close() - if row: - return row[0] + if row is not None: + return uint32(row[0]) - return -1 + return None async def wallet_info_for_puzzle_hash( self, puzzle_hash: bytes32 - ) -> Optional[Tuple[uint64, WalletType]]: + ) -> Optional[Tuple[uint32, WalletType]]: """ Returns the derivation path for the puzzle_hash. - Returns -1 if not present. + Returns None if not present. """ cursor = await self.db_connection.execute( @@ -158,7 +202,7 @@ class WalletPuzzleStore: row = await cursor.fetchone() await cursor.close() - if row: + if row is not None: return row[4], WalletType(row[3]) return None @@ -178,18 +222,33 @@ class WalletPuzzleStore: return result - async def get_max_derivation_path(self): + async def get_last_derivation_path(self) -> Optional[uint32]: """ - Returns the highest derivation path currently stored. + Returns the last derivation path by derivation_index. """ cursor = await self.db_connection.execute( - "SELECT MAX(id) FROM derivation_paths;" + "SELECT MAX(derivation_index) FROM derivation_paths;" ) row = await cursor.fetchone() await cursor.close() - if row[0] is not None: - return row[0] + if row is not None and row[0] is not None: + return uint32(row[0]) - return -1 + return None + + async def get_unused_derivation_path(self) -> Optional[uint32]: + """ + Returns the first unused derivation path by derivation_index. + """ + cursor = await self.db_connection.execute( + "SELECT MIN(derivation_index) FROM derivation_paths WHERE used=0;" + ) + row = await cursor.fetchone() + await cursor.close() + + if row is not None and row[0] is not None: + return uint32(row[0]) + + return None diff --git a/src/wallet/wallet_state_manager.py b/src/wallet/wallet_state_manager.py index 8ff1eb53285b..30ad4fc1ccb3 100644 --- a/src/wallet/wallet_state_manager.py +++ b/src/wallet/wallet_state_manager.py @@ -1,12 +1,13 @@ import time from pathlib import Path -from typing import Dict, Optional, List, Set, Tuple, Callable +from typing import Dict, Optional, List, Set, Tuple, Callable, Any import logging import asyncio import aiosqlite from chiabip158 import PyBIP158 +from blspy import PublicKey from src.types.coin import Coin from src.types.spend_bundle import SpendBundle @@ -30,6 +31,11 @@ from src.util.significant_bits import truncate_to_significant_bits from src.wallet.wallet_user_store import WalletUserStore from src.types.mempool_inclusion_status import MempoolInclusionStatus from src.util.errors import Err +from src.wallet.wallet import Wallet +from src.wallet.rl_wallet.rl_wallet import RLWallet +from src.types.program import Program +from src.wallet.derivation_record import DerivationRecord +from src.wallet.util.wallet_types import WalletType class WalletStateManager: @@ -64,9 +70,16 @@ class WalletStateManager: db_path: Path db_connection: aiosqlite.Connection + main_wallet: Wallet + wallets: Dict[uint32, Any] + @staticmethod async def create( - config: Dict, db_path: Path, constants: Dict, name: str = None, + key_config: Dict, + config: Dict, + db_path: Path, + constants: Dict, + name: str = None, ): self = WalletStateManager() self.config = config @@ -94,6 +107,33 @@ class WalletStateManager: self.difficulty_resets_prev = {} self.db_path = db_path + main_wallet_info = await self.user_store.get_wallet_by_id(1) + assert main_wallet_info is not None + + self.main_wallet = await Wallet.create( + config, key_config, self, main_wallet_info + ) + + self.wallets = {} + main_wallet = await Wallet.create(config, key_config, self, main_wallet_info) + self.wallets[main_wallet_info.id] = main_wallet + + for wallet_info in await self.get_all_wallets(): + self.log.info(f"wallet_info {wallet_info}") + if wallet_info.type == WalletType.STANDARD_WALLET: + if wallet_info.id == 1: + continue + wallet = await Wallet.create(config, key_config, self, main_wallet_info) + self.wallets[wallet_info.id] = wallet + elif wallet_info.type == WalletType.RATE_LIMITED: + wallet = await RLWallet.create( + config, key_config, self, wallet_info, self.main_wallet, + ) + self.wallets[wallet_info.id] = wallet + + async with self.puzzle_store.lock: + await self.create_more_puzzle_hashes() + if len(self.block_records) > 0: # Initializes the state based on the DB block records # Header hash with the highest weight @@ -135,8 +175,80 @@ class WalletStateManager: ), genesis_hb, ) + return self + async def create_more_puzzle_hashes(self): + """ + For all wallets in the user store, generates the first few puzzle hashes so + that we can restore the wallet from only the private keys. + """ + for wallet_id in self.wallets.keys(): + target_wallet = self.wallets[wallet_id] + unused: Optional[ + uint32 + ] = await self.puzzle_store.get_unused_derivation_path() + last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path() + + to_generate = 200 + start_index = 0 + derivation_paths: List[DerivationRecord] = [] + + if last is None: + assert unused is None + if unused is not None: + assert last is not None + start_index = last + 1 + to_generate -= last - unused + + for index in range(start_index, start_index + to_generate): + pubkey: PublicKey = target_wallet.get_public_key(index) + puzzle: Program = target_wallet.puzzle_for_pk(bytes(pubkey)) + puzzlehash: bytes32 = puzzle.get_hash() + self.log.info( + f"Generating public key at index {index} puzzle hash {puzzlehash.hex()}" + ) + derivation_paths.append( + DerivationRecord( + uint32(index), + puzzlehash, + pubkey, + WalletType.STANDARD_WALLET, + uint32(target_wallet.wallet_info.id), + ) + ) + + await self.puzzle_store.add_derivation_paths(derivation_paths) + + async def get_unused_derivation_record(self, wallet_id: uint32) -> DerivationRecord: + """ + Creates a puzzle hash for the given wallet, and then makes more puzzle hashes + for every wallet to ensure we always have more in the database. Never reusue the + same public key more than once (for privacy). + """ + async with self.puzzle_store.lock: + # If we have no unused public keys, we will create new ones + unused: Optional[ + uint32 + ] = await self.puzzle_store.get_unused_derivation_path() + if unused is None: + await self.create_more_puzzle_hashes() + + # Now we must have unused public keys + unused = await self.puzzle_store.get_unused_derivation_path() + assert unused is not None + record: Optional[ + DerivationRecord + ] = await self.puzzle_store.get_derivation_record(unused, wallet_id) + assert record is not None + + # Set this key to used so we never use it again + await self.puzzle_store.set_used_up_to(record.index) + + # Create more puzzle hashes / keys + await self.create_more_puzzle_hashes() + return record + def set_callback(self, callback: Callable): """ Callback to be called when the state of the wallet changes. @@ -302,8 +414,6 @@ class WalletStateManager: if unconfirmed_record: await self.tx_store.set_confirmed(unconfirmed_record.name(), index) - self.state_changed("coin_removed") - async def coin_added(self, coin: Coin, index: uint32, coinbase: bool): """ Adding coin to the db @@ -536,6 +646,18 @@ class WalletStateManager: self.block_records[block.header_hash] = block await self.wallet_store.add_block_record(block, False) + max_puzzle_index = uint32(0) + async with self.puzzle_store.lock: + for addition in block.additions: + index = await self.puzzle_store.index_for_puzzle_hash( + addition.puzzle_hash + ) + assert index is not None + if index > max_puzzle_index: + max_puzzle_index = index + await self.puzzle_store.set_used_up_to(max_puzzle_index) + await self.create_more_puzzle_hashes() + # Genesis case if self.lca is None: assert block.height == 0 @@ -545,6 +667,8 @@ class WalletStateManager: await self.coin_added(coin, block.height, False) for coin_name in block.removals: await self.coin_removed(coin_name, block.height) + self.state_changed("coin_added") + self.state_changed("coin_removed") self.height_to_hash[uint32(0)] = block.header_hash return ReceiveBlockResult.ADDED_TO_HEAD @@ -589,7 +713,10 @@ class WalletStateManager: await self.coin_added(coin, path_block.height, is_coinbase) for coin_name in path_block.removals: await self.coin_removed(coin_name, path_block.height) + self.lca = block.header_hash + self.state_changed("coin_added") + self.state_changed("coin_removed") self.state_changed("new_block") return ReceiveBlockResult.ADDED_TO_HEAD @@ -1133,9 +1260,6 @@ class WalletStateManager: async def get_all_wallets(self) -> List[WalletInfo]: return await self.user_store.get_all_wallets() - async def get_main_wallet(self): - return await self.user_store.get_wallet_by_id(1) - async def get_coin_records_by_spent(self, spent: bool): return await self.wallet_store.get_coin_records_by_spent(spent) diff --git a/src/wallet/wallet_transaction_store.py b/src/wallet/wallet_transaction_store.py index 6059c6a9b266..fcef9b169f44 100644 --- a/src/wallet/wallet_transaction_store.py +++ b/src/wallet/wallet_transaction_store.py @@ -30,7 +30,7 @@ class WalletTransactionStore: f"CREATE TABLE IF NOT EXISTS transaction_record(" f" transaction_record blob," f" bundle_id text PRIMARY KEY," - f" confirmed_at_index int," + f" confirmed_at_index bigint," f" created_at_time bigint," f" to_puzzle_hash text," f" amount bigint," @@ -38,7 +38,7 @@ class WalletTransactionStore: f" incoming int," f" confirmed int," f" sent int," - f" wallet_id int)" + f" wallet_id bigint)" ) ) diff --git a/src/wallet/websocket_server.py b/src/wallet/websocket_server.py index e6f9a710ae87..b6500df71f46 100644 --- a/src/wallet/websocket_server.py +++ b/src/wallet/websocket_server.py @@ -80,7 +80,7 @@ class WebSocketServer: """ wallet_id = int(request["wallet_id"]) - wallet = self.wallet_node.wallets[wallet_id] + wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id] puzzlehash = (await wallet.get_new_puzzlehash()).hex() data = { @@ -91,7 +91,7 @@ class WebSocketServer: async def send_transaction(self, websocket, request, response_api): wallet_id = int(request["wallet_id"]) - wallet = self.wallet_node.wallets[wallet_id] + wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id] try: tx = await wallet.generate_signed_transaction_dict(request) except BaseException as e: @@ -173,7 +173,7 @@ class WebSocketServer: async def get_wallet_balance(self, websocket, request, response_api): wallet_id = int(request["wallet_id"]) - wallet = self.wallet_node.wallets[wallet_id] + wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id] balance = await wallet.get_confirmed_balance() pending_balance = await wallet.get_unconfirmed_balance() @@ -217,14 +217,18 @@ class WebSocketServer: rl_admin: RLWallet = await RLWallet.create_rl_admin( config, key_config, wallet_state_manager, main_wallet ) - self.wallet_node.wallets[rl_admin.wallet_info.id] = rl_admin + self.wallet_node.wallet_state_manager.wallets[ + rl_admin.wallet_info.id + ] = rl_admin response = {"success": True, "type": "rl_wallet"} return await websocket.send(format_response(response_api, response)) elif request["mode"] == "user": rl_user: RLWallet = await RLWallet.create_rl_user( config, key_config, wallet_state_manager, main_wallet ) - self.wallet_node.wallets[rl_user.wallet_info.id] = rl_user + self.wallet_node.wallet_state_manager.wallets[ + rl_user.wallet_info.id + ] = rl_user response = {"success": True, "type": "rl_wallet"} return await websocket.send(format_response(response_api, response)) elif request["wallet_type"] == "cc_wallet": @@ -238,7 +242,7 @@ class WebSocketServer: self.wallet_node.config, self.wallet_node.key_config, self.wallet_node.wallet_state_manager, - self.wallet_node.main_wallet, + self.wallet_node.wallet_state_manager.main_wallet, ) async def get_wallets(self, websocket, response_api): @@ -252,7 +256,7 @@ class WebSocketServer: async def rl_set_admin_info(self, websocket, request, response_api): wallet_id = int(request["wallet_id"]) - wallet: RLWallet = self.wallet_node.wallets[wallet_id] + wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id] user_pubkey = request["user_pubkey"] limit = uint64(int(request["limit"])) interval = uint64(int(request["interval"])) @@ -266,7 +270,7 @@ class WebSocketServer: async def rl_set_user_info(self, websocket, request, response_api): wallet_id = int(request["wallet_id"]) - wallet: RLWallet = self.wallet_node.wallets[wallet_id] + wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id] admin_pubkey = request["admin_pubkey"] limit = uint64(int(request["limit"])) interval = uint64(int(request["interval"])) @@ -396,7 +400,7 @@ async def start_websocket_server(): log.info("Starting websocket server.") websocket_server = await websockets.serve( - handler.safe_handle, None, config["rpc_port"] + handler.safe_handle, "localhost", config["rpc_port"] ) log.info(f"Started websocket server at port {config['rpc_port']}.") diff --git a/tests/full_node/test_full_node.py b/tests/full_node/test_full_node.py index 98fdf6a2975b..95fffcc9fdf1 100644 --- a/tests/full_node/test_full_node.py +++ b/tests/full_node/test_full_node.py @@ -78,9 +78,7 @@ class TestFullNodeProtocol: full_node_1, full_node_2, server_1, server_2 = two_nodes _, _, blocks = wallet_blocks - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("127.0.0.1", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made new_tip_1 = fnp.NewTip( @@ -772,9 +770,7 @@ class TestFullNodeProtocol: full_node_1, full_node_2, server_1, server_2 = two_nodes wallet_a, wallet_receiver, blocks = wallet_blocks - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made msgs = [_ async for _ in full_node_1.request_peers(fnp.RequestPeers())] @@ -787,9 +783,7 @@ class TestWalletProtocol: full_node_1, full_node_2, server_1, server_2 = two_nodes wallet_a, wallet_receiver, blocks = wallet_blocks - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) blocks_list = await get_block_path(full_node_1) blocks_new = bt.get_consecutive_blocks( @@ -981,9 +975,7 @@ class TestWalletProtocol: full_node_1, full_node_2, server_1, server_2 = two_nodes wallet_a, wallet_receiver, blocks = wallet_blocks - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) blocks_list = await get_block_path(full_node_1) blocks_new = bt.get_consecutive_blocks( test_constants, 5, seed=b"test_request_removals" @@ -1166,9 +1158,7 @@ class TestWalletProtocol: full_node_1, full_node_2, server_1, server_2 = two_nodes wallet_a, wallet_receiver, blocks = wallet_blocks - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) blocks_list = await get_block_path(full_node_1) blocks_new = bt.get_consecutive_blocks( test_constants, 5, seed=b"test_request_additions" diff --git a/tests/full_node/test_full_sync.py b/tests/full_node/test_full_sync.py index a46d8f47d15b..88650088b16b 100644 --- a/tests/full_node/test_full_sync.py +++ b/tests/full_node/test_full_sync.py @@ -33,9 +33,7 @@ class TestFullSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made start = time.time() @@ -86,9 +84,7 @@ class TestFullSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made start = time.time() diff --git a/tests/full_node/test_node_load.py b/tests/full_node/test_node_load.py index 14e71be25592..6891d4e293c2 100644 --- a/tests/full_node/test_node_load.py +++ b/tests/full_node/test_node_load.py @@ -35,9 +35,7 @@ class TestNodeLoad: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made @@ -77,9 +75,7 @@ class TestNodeLoad: full_node_1, full_node_2, server_1, server_2 = two_nodes blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10) - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) # Allow connections to get made diff --git a/tests/full_node/test_transactions.py b/tests/full_node/test_transactions.py index 6a37226cd366..2b23b2a4dea3 100644 --- a/tests/full_node/test_transactions.py +++ b/tests/full_node/test_transactions.py @@ -45,12 +45,10 @@ class TestTransactions: async def test_wallet_coinbase(self, wallet_node): num_blocks = 10 full_node_1, wallet_node, server_1, server_2 = wallet_node - wallet = wallet_node.main_wallet + wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) for i in range(1, num_blocks): await full_node_1.farm_new_block(FarmNewBlockProtocol(ph)) @@ -74,23 +72,19 @@ class TestTransactions: full_node_1, server_1 = full_nodes[1] full_node_2, server_2 = full_nodes[2] - ph = await wallet_0.main_wallet.get_new_puzzlehash() - ph1 = await wallet_1.main_wallet.get_new_puzzlehash() + ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash() + ph1 = await wallet_1.wallet_state_manager.main_wallet.get_new_puzzlehash() # # wallet0 <-> sever0 <-> server1 <-> server2 <-> wallet1 # await wallet_server_0.start_client( - PeerInfo(server_0._host, uint16(server_0._port)), None - ) - await server_0.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) - await server_1.start_client( - PeerInfo(server_2._host, uint16(server_2._port)), None + PeerInfo("localhost", uint16(server_0._port)), None ) + await server_0.start_client(PeerInfo("localhost", uint16(server_1._port)), None) + await server_1.start_client(PeerInfo("localhost", uint16(server_2._port)), None) await wallet_server_1.start_client( - PeerInfo(server_2._host, uint16(server_2._port)), None + PeerInfo("localhost", uint16(server_2._port)), None ) for i in range(1, num_blocks): @@ -103,12 +97,15 @@ class TestTransactions: for i in range(1, num_blocks - 2) ] ) - assert await wallet_0.main_wallet.get_confirmed_balance() == funds + assert ( + await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance() + == funds + ) - spend_bundle = await wallet_0.main_wallet.generate_signed_transaction( + spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction( 10, ph1, 0 ) - await wallet_0.main_wallet.push_transaction(spend_bundle) + await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle) await asyncio.sleep(3) @@ -132,8 +129,14 @@ class TestTransactions: ] ) - assert await wallet_0.main_wallet.get_confirmed_balance() == funds - 10 - assert await wallet_1.main_wallet.get_confirmed_balance() == 10 + assert ( + await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance() + == funds - 10 + ) + assert ( + await wallet_1.wallet_state_manager.main_wallet.get_confirmed_balance() + == 10 + ) @pytest.mark.asyncio async def test_mempool_tx_sync(self, three_nodes_two_wallets): @@ -145,16 +148,14 @@ class TestTransactions: full_node_1, server_1 = full_nodes[1] full_node_2, server_2 = full_nodes[2] - ph = await wallet_0.main_wallet.get_new_puzzlehash() + ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash() # wallet0 <-> sever0 <-> server1 await wallet_server_0.start_client( - PeerInfo(server_0._host, uint16(server_0._port)), None - ) - await server_0.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None + PeerInfo("localhost", uint16(server_0._port)), None ) + await server_0.start_client(PeerInfo("localhost", uint16(server_1._port)), None) for i in range(1, num_blocks): await full_node_0.farm_new_block(FarmNewBlockProtocol(ph)) @@ -174,12 +175,15 @@ class TestTransactions: for i in range(1, num_blocks - 2) ] ) - assert await wallet_0.main_wallet.get_confirmed_balance() == funds + assert ( + await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance() + == funds + ) - spend_bundle = await wallet_0.main_wallet.generate_signed_transaction( + spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction( 10, token_bytes(), 0 ) - await wallet_0.main_wallet.push_transaction(spend_bundle) + await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle) await asyncio.sleep(2) @@ -194,9 +198,7 @@ class TestTransactions: # make a final connection. # wallet0 <-> sever0 <-> server1 <-> server2 - await server_1.start_client( - PeerInfo(server_2._host, uint16(server_2._port)), None - ) + await server_1.start_client(PeerInfo("localhost", uint16(server_2._port)), None) await asyncio.sleep(2) diff --git a/tests/rpc/test_rpc.py b/tests/rpc/test_rpc.py index d62a3cc1acfc..74f90fc1e83f 100644 --- a/tests/rpc/test_rpc.py +++ b/tests/rpc/test_rpc.py @@ -68,7 +68,7 @@ class TestRpc: assert len(await client.get_connections()) == 0 - await client.open_connection(server_2._host, server_2._port) + await client.open_connection("localhost", server_2._port) await asyncio.sleep(2) connections = await client.get_connections() assert len(connections) == 1 diff --git a/tests/setup_nodes.py b/tests/setup_nodes.py index 3be80172008e..420c972d8867 100644 --- a/tests/setup_nodes.py +++ b/tests/setup_nodes.py @@ -421,14 +421,14 @@ async def setup_full_system(dic={}): node2, node2_server = await node_iters[6].__anext__() await harvester_server.start_client( - PeerInfo(farmer_server._host, uint16(farmer_server._port)), None + PeerInfo("127.0.0.1", uint16(farmer_server._port)), None ) await farmer_server.start_client( - PeerInfo(node1_server._host, uint16(node1_server._port)), None + PeerInfo("127.0.0.1", uint16(node1_server._port)), None ) await timelord_server.start_client( - PeerInfo(node1_server._host, uint16(node1_server._port)), None + PeerInfo("127.0.0.1", uint16(node1_server._port)), None ) yield (node1, node2) diff --git a/tests/test_filter.py b/tests/test_filter.py index eae5438126ac..d4a72fb74603 100644 --- a/tests/test_filter.py +++ b/tests/test_filter.py @@ -33,7 +33,7 @@ class TestFilter: config = load_config("config.yaml", "wallet") key_config = {"wallet_sk": sk} full_node_1, wallet_node, server_1, server_2 = wallet_and_node - wallet = wallet_node.main_wallet + wallet = wallet_node.wallet_state_manager.main_wallet num_blocks = 2 ph = await wallet.get_new_puzzlehash() diff --git a/tests/wallet/test_puzzle_store.py b/tests/wallet/test_puzzle_store.py new file mode 100644 index 000000000000..df865b0d5996 --- /dev/null +++ b/tests/wallet/test_puzzle_store.py @@ -0,0 +1,99 @@ +import asyncio +from secrets import token_bytes +from pathlib import Path +from typing import Any, Dict +import sqlite3 +import random +import pytest +import aiosqlite +from blspy import PrivateKey +from src.full_node.store import FullNodeStore +from src.types.full_block import FullBlock +from src.types.sized_bytes import bytes32 +from src.util.ints import uint32, uint64 +from src.wallet.wallet_puzzle_store import WalletPuzzleStore +from src.wallet.derivation_record import DerivationRecord +from src.wallet.util.wallet_types import WalletType + + +@pytest.fixture(scope="module") +def event_loop(): + loop = asyncio.get_event_loop() + yield loop + + +class TestPuzzleStore: + @pytest.mark.asyncio + async def test_puzzle_store(self): + db_filename = Path("puzzle_store_test.db") + + if db_filename.exists(): + db_filename.unlink() + + con = await aiosqlite.connect(db_filename) + db = await WalletPuzzleStore.create(con) + try: + derivation_recs = [] + wallet_types = [t for t in WalletType] + + for i in range(1000): + derivation_recs.append( + DerivationRecord( + uint32(i), + token_bytes(32), + PrivateKey.from_seed(token_bytes(5)).get_public_key(), + WalletType.STANDARD_WALLET, + uint32(1), + ) + ) + derivation_recs.append( + DerivationRecord( + uint32(i), + token_bytes(32), + PrivateKey.from_seed(token_bytes(5)).get_public_key(), + WalletType.RATE_LIMITED, + uint32(2), + ) + ) + assert await db.puzzle_hash_exists(derivation_recs[0].puzzle_hash) == False + assert await db.index_for_pubkey(derivation_recs[0].pubkey) == None + assert ( + await db.index_for_puzzle_hash(derivation_recs[2].puzzle_hash) == None + ) + assert ( + await db.wallet_info_for_puzzle_hash(derivation_recs[2].puzzle_hash) + == None + ) + assert len((await db.get_all_puzzle_hashes())) == 0 + assert await db.get_last_derivation_path() == None + assert await db.get_unused_derivation_path() == None + assert await db.get_derivation_record(0, 2) == None + + await db.add_derivation_paths(derivation_recs) + + assert await db.puzzle_hash_exists(derivation_recs[0].puzzle_hash) == True + assert await db.index_for_pubkey(derivation_recs[4].pubkey) == 2 + assert await db.index_for_puzzle_hash(derivation_recs[2].puzzle_hash) == 1 + assert await db.wallet_info_for_puzzle_hash( + derivation_recs[2].puzzle_hash + ) == (derivation_recs[2].wallet_id, derivation_recs[2].wallet_type) + assert len((await db.get_all_puzzle_hashes())) == 2000 + assert await db.get_last_derivation_path() == 999 + assert await db.get_unused_derivation_path() == 0 + assert await db.get_derivation_record(0, 2) == derivation_recs[1] + + # Indeces up to 250 + await db.set_used_up_to(249) + + assert await db.get_unused_derivation_path() == 250 + + except Exception as e: + print(e, type(e)) + await db._clear_database() + await db.close() + db_filename.unlink() + raise e + + await db._clear_database() + await db.close() + db_filename.unlink() diff --git a/tests/wallet/test_wallet.py b/tests/wallet/test_wallet.py index fbece970f179..8dc2c5bb4721 100644 --- a/tests/wallet/test_wallet.py +++ b/tests/wallet/test_wallet.py @@ -52,12 +52,10 @@ class TestWalletSimulator: async def test_wallet_coinbase(self, wallet_node): num_blocks = 10 full_node_1, wallet_node, server_1, server_2 = wallet_node - wallet = wallet_node.main_wallet + wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) for i in range(1, num_blocks): await full_node_1.farm_new_block(FarmNewBlockProtocol(ph)) @@ -81,12 +79,10 @@ class TestWalletSimulator: server_2, server_3, ) = two_wallet_nodes - wallet = wallet_node.main_wallet + wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) for i in range(0, num_blocks): await full_node_1.farm_new_block(FarmNewBlockProtocol(ph)) @@ -104,7 +100,9 @@ class TestWalletSimulator: assert await wallet.get_unconfirmed_balance() == funds spend_bundle = await wallet.generate_signed_transaction( - 10, await wallet_node_2.main_wallet.get_new_puzzlehash(), 0 + 10, + await wallet_node_2.wallet_state_manager.main_wallet.get_new_puzzlehash(), + 0, ) await wallet.push_transaction(spend_bundle) @@ -137,12 +135,10 @@ class TestWalletSimulator: async def test_wallet_coinbase_reorg(self, wallet_node): num_blocks = 10 full_node_1, wallet_node, server_1, server_2 = wallet_node - wallet = wallet_node.main_wallet + wallet = wallet_node.wallet_state_manager.main_wallet ph = await wallet.get_new_puzzlehash() - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) for i in range(1, num_blocks): await full_node_1.farm_new_block(FarmNewBlockProtocol(ph)) @@ -179,11 +175,11 @@ class TestWalletSimulator: full_node_1, server_1 = full_nodes[1] full_node_2, server_2 = full_nodes[2] - ph = await wallet_0.main_wallet.get_new_puzzlehash() + ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash() # wallet0 <-> sever0 await wallet_server_0.start_client( - PeerInfo(server_0._host, uint16(server_0._port)), None + PeerInfo("localhost", uint16(server_0._port)), None ) for i in range(1, num_blocks): @@ -208,12 +204,15 @@ class TestWalletSimulator: for i in range(1, num_blocks - 2) ] ) - assert await wallet_0.main_wallet.get_confirmed_balance() == funds + assert ( + await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance() + == funds + ) - spend_bundle = await wallet_0.main_wallet.generate_signed_transaction( + spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction( 10, token_bytes(), 0 ) - await wallet_0.main_wallet.push_transaction(spend_bundle) + await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle) await asyncio.sleep(1) @@ -222,7 +221,7 @@ class TestWalletSimulator: # wallet0 <-> sever1 await wallet_server_0.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), wallet_0._on_connect + PeerInfo("localhost", uint16(server_1._port)), wallet_0._on_connect ) await asyncio.sleep(1) @@ -231,7 +230,7 @@ class TestWalletSimulator: # wallet0 <-> sever2 await wallet_server_0.start_client( - PeerInfo(server_2._host, uint16(server_2._port)), wallet_0._on_connect + PeerInfo("localhost", uint16(server_2._port)), wallet_0._on_connect ) await asyncio.sleep(1) @@ -249,16 +248,16 @@ class TestWalletSimulator: wallet_0_server, wallet_1_server, ) = two_wallet_nodes_five_freeze - wallet_0 = wallet_node_0.main_wallet - wallet_1 = wallet_node_1.main_wallet + wallet_0 = wallet_node_0.wallet_state_manager.main_wallet + wallet_1 = wallet_node_1.wallet_state_manager.main_wallet ph = await wallet_0.get_new_puzzlehash() await wallet_0_server.start_client( - PeerInfo(full_node_server._host, uint16(full_node_server._port)), None + PeerInfo("localhost", uint16(full_node_server._port)), None ) await wallet_1_server.start_client( - PeerInfo(full_node_server._host, uint16(full_node_server._port)), None + PeerInfo("localhost", uint16(full_node_server._port)), None ) for i in range(0, num_blocks): @@ -277,7 +276,9 @@ class TestWalletSimulator: assert await wallet_0.get_unconfirmed_balance() == funds spend_bundle = await wallet_0.generate_signed_transaction( - 10, await wallet_node_1.main_wallet.get_new_puzzlehash(), 0 + 10, + await wallet_node_1.wallet_state_manager.main_wallet.get_new_puzzlehash(), + 0, ) await wallet_0.push_transaction(spend_bundle) diff --git a/tests/wallet/test_wallet_sync.py b/tests/wallet/test_wallet_sync.py index 1068bb28db27..7c468da901b7 100644 --- a/tests/wallet/test_wallet_sync.py +++ b/tests/wallet/test_wallet_sync.py @@ -43,9 +43,7 @@ class TestWalletSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) start = time.time() found = False @@ -104,9 +102,7 @@ class TestWalletSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) start = time.time() found = False @@ -137,9 +133,7 @@ class TestWalletSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) start = time.time() while time.time() - start < 60: # The second node should eventually catch up to the first one, and have the @@ -159,7 +153,7 @@ class TestWalletSync: @pytest.mark.asyncio async def test_short_sync_with_transactions_wallet(self, wallet_node): full_node_1, wallet_node, server_1, server_2 = wallet_node - wallet_a = wallet_node.main_wallet + wallet_a = wallet_node.wallet_state_manager.main_wallet wallet_a_dummy = WalletTool() wallet_b = WalletTool() coinbase_puzzlehash = await wallet_a.get_new_puzzlehash() @@ -178,9 +172,7 @@ class TestWalletSync: ) ] await asyncio.sleep(2) - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) await asyncio.sleep(2) assert ( wallet_node.wallet_state_manager.block_records[ @@ -215,9 +207,7 @@ class TestWalletSync: pass # Do a short sync from 0 to 14 - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) start = time.time() broke = False while time.time() - start < 60: @@ -282,9 +272,7 @@ class TestWalletSync: pass # Do a sync from 0 to 22 - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) broke = False while time.time() - start < 60: @@ -359,9 +347,7 @@ class TestWalletSync: ): pass - await server_2.start_client( - PeerInfo(server_1._host, uint16(server_1._port)), None - ) + await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None) broke = False while time.time() - start < 60: