mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-09-20 16:08:51 +03:00
Refactor puzzle store, tests, generate future puzzle hashes
This commit is contained in:
parent
a6126b28f2
commit
b5ebf6c494
@ -106,7 +106,6 @@ function set_callbacks(socket) {
|
||||
|
||||
if (command == "start_server") {
|
||||
get_wallets();
|
||||
get_new_puzzlehash();
|
||||
get_transactions();
|
||||
get_wallet_balance(g_wallet_id);
|
||||
get_height_info();
|
||||
@ -266,6 +265,11 @@ copy.addEventListener("click", () => {
|
||||
})
|
||||
|
||||
async function get_new_puzzlehash() {
|
||||
if (global_syncing) {
|
||||
alert("Cannot create address while syncing.")
|
||||
return;
|
||||
}
|
||||
|
||||
/*
|
||||
Sends websocket request for new puzzle_hash
|
||||
*/
|
||||
|
@ -4,7 +4,6 @@ echo "Starting local blockchain simulation. Runs a local introducer and chia sys
|
||||
echo "Note that this simulation will not work if connected to external nodes."
|
||||
|
||||
# Starts a harvester, farmer, timelord, introducer, and 3 full nodes, locally.
|
||||
# Make sure to point the full node in config/config.yaml to the local introducer: 127.0.0.1:8445.
|
||||
# Please note that the simulation is meant to be run locally and not connected to external nodes.
|
||||
|
||||
_run_bg_cmd python -m src.server.start_harvester --logging.log_stdout=True
|
||||
|
@ -240,7 +240,7 @@ async def start_rpc_server(
|
||||
|
||||
runner = web.AppRunner(app, access_log=None)
|
||||
await runner.setup()
|
||||
site = web.TCPSite(runner, None, int(rpc_port))
|
||||
site = web.TCPSite(runner, "localhost", int(rpc_port))
|
||||
await site.start()
|
||||
|
||||
async def cleanup():
|
||||
|
@ -71,10 +71,9 @@ async def spawn_process(host, port, counter):
|
||||
|
||||
async def spawn_all_processes():
|
||||
await asyncio.sleep(15)
|
||||
host = config["host"]
|
||||
port = config["port"]
|
||||
process_count = config["process_count"]
|
||||
awaitables = [spawn_process(host, port, i) for i in range(process_count)]
|
||||
awaitables = [spawn_process("127.0.0.1", port, i) for i in range(process_count)]
|
||||
await asyncio.gather(*awaitables)
|
||||
|
||||
|
||||
|
@ -149,7 +149,7 @@ wallet:
|
||||
port: 8444
|
||||
|
||||
testing: False
|
||||
database_path: wallet/db/blockchain_wallet_v4.db
|
||||
database_path: wallet/db/blockchain_wallet_v5.db
|
||||
|
||||
logging: *logging
|
||||
|
||||
|
22
src/wallet/derivation_record.py
Normal file
22
src/wallet/derivation_record.py
Normal file
@ -0,0 +1,22 @@
|
||||
from dataclasses import dataclass
|
||||
from blspy import PublicKey
|
||||
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.streamable import Streamable, streamable
|
||||
from src.util.ints import uint32
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@streamable
|
||||
class DerivationRecord(Streamable):
|
||||
"""
|
||||
These are records representing a puzzle hash, which is generated from a
|
||||
public key, derivation index, and wallet type. Stored in the puzzle_store.
|
||||
"""
|
||||
|
||||
index: uint32
|
||||
puzzle_hash: bytes32
|
||||
pubkey: PublicKey
|
||||
wallet_type: WalletType
|
||||
wallet_id: uint32
|
@ -4,7 +4,7 @@ import logging
|
||||
from binascii import hexlify
|
||||
from dataclasses import dataclass
|
||||
from secrets import token_bytes
|
||||
from typing import Dict, Optional, List, Tuple
|
||||
from typing import Dict, Optional, List, Tuple, Any
|
||||
|
||||
import clvm
|
||||
import json
|
||||
@ -18,7 +18,7 @@ from src.types.coin_solution import CoinSolution
|
||||
from src.types.program import Program
|
||||
from src.types.spend_bundle import SpendBundle
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.ints import uint64
|
||||
from src.util.ints import uint64, uint32
|
||||
from src.util.streamable import streamable, Streamable
|
||||
from src.wallet.rl_wallet.rl_wallet_puzzles import (
|
||||
rl_puzzle_for_pk,
|
||||
@ -32,7 +32,7 @@ from src.wallet.util.wallet_types import WalletType
|
||||
from src.wallet.wallet import Wallet
|
||||
from src.wallet.wallet_coin_record import WalletCoinRecord
|
||||
from src.wallet.wallet_info import WalletInfo
|
||||
from src.wallet.wallet_state_manager import WalletStateManager
|
||||
from src.wallet.derivation_record import DerivationRecord
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -53,7 +53,7 @@ class RLWallet:
|
||||
key_config: Dict
|
||||
config: Dict
|
||||
server: Optional[ChiaServer]
|
||||
wallet_state_manager: WalletStateManager
|
||||
wallet_state_manager: Any
|
||||
log: logging.Logger
|
||||
wallet_info: WalletInfo
|
||||
rl_coin_record: WalletCoinRecord
|
||||
@ -64,19 +64,21 @@ class RLWallet:
|
||||
async def create_rl_admin(
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
wallet_state_manager: WalletStateManager,
|
||||
wallet_state_manager: Any,
|
||||
wallet: Wallet,
|
||||
name: str = None,
|
||||
):
|
||||
current_max_index = (
|
||||
await wallet_state_manager.puzzle_store.get_max_derivation_path() + 1
|
||||
)
|
||||
my_pubkey_index = current_max_index + 1
|
||||
unused: Optional[
|
||||
uint32
|
||||
] = await wallet_state_manager.puzzle_store.get_unused_derivation_path()
|
||||
if unused is None:
|
||||
await wallet_state_manager.create_more_puzzle_hashes()
|
||||
unused = await wallet_state_manager.puzzle_store.get_unused_derivation_path()
|
||||
assert unused is not None
|
||||
|
||||
sk_hex = key_config["wallet_sk"]
|
||||
private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex))
|
||||
pubkey_bytes: bytes = bytes(
|
||||
private_key.public_child(my_pubkey_index).get_public_key()
|
||||
)
|
||||
pubkey_bytes: bytes = bytes(private_key.public_child(unused).get_public_key())
|
||||
|
||||
rl_info = RLInfo("admin", pubkey_bytes, None, None, None, None, None, None)
|
||||
info_as_string = json.dumps(rl_info.to_json_dict())
|
||||
@ -86,13 +88,19 @@ class RLWallet:
|
||||
wallet_info = await wallet_state_manager.user_store.get_last_wallet()
|
||||
if wallet_info is None:
|
||||
raise
|
||||
await wallet_state_manager.puzzle_store.add_derivation_path_of_interest(
|
||||
my_pubkey_index,
|
||||
token_bytes(),
|
||||
pubkey_bytes,
|
||||
WalletType.RATE_LIMITED,
|
||||
wallet_info.id,
|
||||
|
||||
await wallet_state_manager.puzzle_store.add_derivation_paths(
|
||||
[
|
||||
DerivationRecord(
|
||||
unused,
|
||||
token_bytes(),
|
||||
pubkey_bytes,
|
||||
WalletType.RATE_LIMITED,
|
||||
wallet_info.id,
|
||||
)
|
||||
]
|
||||
)
|
||||
await wallet_state_manager.puzzle_store.set_used_up_to(unused)
|
||||
|
||||
self = await RLWallet.create(
|
||||
config, key_config, wallet_state_manager, wallet_info, wallet, name
|
||||
@ -103,46 +111,60 @@ class RLWallet:
|
||||
async def create_rl_user(
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
wallet_state_manager: WalletStateManager,
|
||||
wallet_state_manager: Any,
|
||||
wallet: Wallet,
|
||||
name: str = None,
|
||||
):
|
||||
current_max_index = (
|
||||
await wallet_state_manager.puzzle_store.get_max_derivation_path() + 1
|
||||
)
|
||||
my_pubkey_index = current_max_index + 1
|
||||
sk_hex = key_config["wallet_sk"]
|
||||
private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex))
|
||||
pubkey_bytes: bytes = bytes(
|
||||
private_key.public_child(my_pubkey_index).get_public_key()
|
||||
)
|
||||
async with wallet_state_manager.puzzle_store.lock:
|
||||
unused: Optional[
|
||||
uint32
|
||||
] = await wallet_state_manager.puzzle_store.get_unused_derivation_path()
|
||||
if unused is None:
|
||||
await wallet_state_manager.create_more_puzzle_hashes()
|
||||
unused = (
|
||||
await wallet_state_manager.puzzle_store.get_unused_derivation_path()
|
||||
)
|
||||
assert unused is not None
|
||||
|
||||
rl_info = RLInfo("user", None, pubkey_bytes, None, None, None, None, None)
|
||||
info_as_string = json.dumps(rl_info.to_json_dict())
|
||||
await wallet_state_manager.user_store.create_wallet(
|
||||
"RL User", WalletType.RATE_LIMITED, info_as_string
|
||||
)
|
||||
wallet_info = await wallet_state_manager.user_store.get_last_wallet()
|
||||
if wallet_info is None:
|
||||
raise
|
||||
await wallet_state_manager.puzzle_store.add_derivation_path_of_interest(
|
||||
my_pubkey_index,
|
||||
token_bytes(),
|
||||
pubkey_bytes,
|
||||
WalletType.RATE_LIMITED,
|
||||
wallet_info.id,
|
||||
)
|
||||
sk_hex = key_config["wallet_sk"]
|
||||
private_key = ExtendedPrivateKey.from_bytes(bytes.fromhex(sk_hex))
|
||||
pubkey_bytes: bytes = bytes(
|
||||
private_key.public_child(unused).get_public_key()
|
||||
)
|
||||
|
||||
rl_info = RLInfo("user", None, pubkey_bytes, None, None, None, None, None)
|
||||
info_as_string = json.dumps(rl_info.to_json_dict())
|
||||
await wallet_state_manager.user_store.create_wallet(
|
||||
"RL User", WalletType.RATE_LIMITED, info_as_string
|
||||
)
|
||||
wallet_info = await wallet_state_manager.user_store.get_last_wallet()
|
||||
if wallet_info is None:
|
||||
raise
|
||||
|
||||
self = await RLWallet.create(
|
||||
config, key_config, wallet_state_manager, wallet_info, wallet, name
|
||||
)
|
||||
|
||||
await wallet_state_manager.puzzle_store.add_derivation_paths(
|
||||
[
|
||||
DerivationRecord(
|
||||
unused,
|
||||
token_bytes(),
|
||||
pubkey_bytes,
|
||||
WalletType.RATE_LIMITED,
|
||||
wallet_info.id,
|
||||
)
|
||||
]
|
||||
)
|
||||
await wallet_state_manager.puzzle_store.set_used_up_to(unused)
|
||||
|
||||
self = await RLWallet.create(
|
||||
config, key_config, wallet_state_manager, wallet_info, wallet, name
|
||||
)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
wallet_state_manager: WalletStateManager,
|
||||
wallet_state_manager: Any,
|
||||
info: WalletInfo,
|
||||
wallet: Wallet,
|
||||
name: str = None,
|
||||
@ -194,13 +216,16 @@ class RLWallet:
|
||||
index = await self.wallet_state_manager.puzzle_store.index_for_pubkey(
|
||||
self.rl_info.admin_pubkey.hex()
|
||||
)
|
||||
await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest(
|
||||
|
||||
assert index is not None
|
||||
record = DerivationRecord(
|
||||
index,
|
||||
rl_puzzle_hash,
|
||||
self.rl_info.admin_pubkey,
|
||||
WalletType.RATE_LIMITED,
|
||||
self.wallet_info.id,
|
||||
)
|
||||
await self.wallet_state_manager.puzzle_store.add_derivation_paths([record])
|
||||
|
||||
spend_bundle = await self.standard_wallet.generate_signed_transaction(
|
||||
amount, rl_puzzle_hash, uint64(0), origin_id, coins
|
||||
@ -263,13 +288,15 @@ class RLWallet:
|
||||
index = await self.wallet_state_manager.puzzle_store.index_for_pubkey(
|
||||
self.rl_info.user_pubkey.hex()
|
||||
)
|
||||
await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest(
|
||||
assert index is not None
|
||||
record = DerivationRecord(
|
||||
index,
|
||||
rl_puzzle_hash,
|
||||
self.rl_info.user_pubkey,
|
||||
WalletType.RATE_LIMITED,
|
||||
self.wallet_info.id,
|
||||
)
|
||||
await self.wallet_state_manager.puzzle_store.add_derivation_paths([record])
|
||||
|
||||
data_str = json.dumps(new_rl_info.to_json_dict())
|
||||
new_wallet_info = WalletInfo(
|
||||
|
@ -26,7 +26,7 @@ class TransactionRecord(Streamable):
|
||||
spend_bundle: Optional[SpendBundle]
|
||||
additions: List[Coin]
|
||||
removals: List[Coin]
|
||||
wallet_id: uint64
|
||||
wallet_id: uint32
|
||||
|
||||
# Represents the list of peers that we sent the transaction to, whether each one
|
||||
# included it in the mempool, and what the error message (if any) was
|
||||
|
@ -24,19 +24,16 @@ from src.wallet.puzzles.puzzle_utils import (
|
||||
make_assert_coin_consumed_condition,
|
||||
make_create_coin_condition,
|
||||
)
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
from src.wallet.wallet_coin_record import WalletCoinRecord
|
||||
from src.wallet.transaction_record import TransactionRecord
|
||||
from src.wallet.wallet_info import WalletInfo
|
||||
|
||||
from src.wallet.wallet_state_manager import WalletStateManager
|
||||
|
||||
|
||||
class Wallet:
|
||||
private_key: ExtendedPrivateKey
|
||||
key_config: Dict
|
||||
config: Dict
|
||||
wallet_state_manager: WalletStateManager
|
||||
wallet_state_manager: Any
|
||||
|
||||
log: logging.Logger
|
||||
|
||||
@ -46,7 +43,7 @@ class Wallet:
|
||||
async def create(
|
||||
config: Dict,
|
||||
key_config: Dict,
|
||||
wallet_state_manager: WalletStateManager,
|
||||
wallet_state_manager: Any,
|
||||
info: WalletInfo,
|
||||
name: str = None,
|
||||
):
|
||||
@ -87,17 +84,11 @@ class Wallet:
|
||||
return puzzle_for_pk(pubkey)
|
||||
|
||||
async def get_new_puzzlehash(self) -> bytes32:
|
||||
index = await self.wallet_state_manager.puzzle_store.get_max_derivation_path()
|
||||
index += 1
|
||||
pubkey: bytes = bytes(self.get_public_key(index))
|
||||
puzzle: Program = self.puzzle_for_pk(pubkey)
|
||||
puzzlehash: bytes32 = puzzle.get_hash()
|
||||
|
||||
await self.wallet_state_manager.puzzle_store.add_derivation_path_of_interest(
|
||||
index, puzzlehash, pubkey, WalletType.STANDARD_WALLET, self.wallet_info.id
|
||||
)
|
||||
|
||||
return puzzlehash
|
||||
return (
|
||||
await self.wallet_state_manager.get_unused_derivation_record(
|
||||
self.wallet_info.id
|
||||
)
|
||||
).puzzle_hash
|
||||
|
||||
def make_solution(self, primaries=None, min_time=0, me=None, consumed=None):
|
||||
condition_list = []
|
||||
|
@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
|
||||
from src.util.streamable import streamable, Streamable
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
from src.util.ints import uint32
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -11,7 +12,7 @@ class WalletInfo(Streamable):
|
||||
# TODO(straya): describe
|
||||
"""
|
||||
|
||||
id: int
|
||||
id: uint32
|
||||
name: str
|
||||
type: WalletType
|
||||
data: str
|
||||
|
@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple, List, Any, AsyncGenerator
|
||||
from typing import Dict, Optional, Tuple, List, AsyncGenerator
|
||||
import concurrent
|
||||
import random
|
||||
import logging
|
||||
@ -21,11 +21,7 @@ from src.server.outbound_message import OutboundMessage, NodeType, Message, Deli
|
||||
from src.util.ints import uint32, uint64
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.api_decorators import api_request
|
||||
from src.wallet.rl_wallet.rl_wallet import RLWallet
|
||||
from src.wallet.transaction_record import TransactionRecord
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
from src.wallet.wallet import Wallet
|
||||
from src.wallet.wallet_info import WalletInfo
|
||||
from src.wallet.wallet_state_manager import WalletStateManager
|
||||
from src.wallet.block_record import BlockRecord
|
||||
from src.types.header_block import HeaderBlock
|
||||
@ -47,8 +43,6 @@ class WalletNode:
|
||||
|
||||
# Maintains the state of the wallet (blockchain and transactions), handles DB connections
|
||||
wallet_state_manager: WalletStateManager
|
||||
main_wallet: Wallet
|
||||
wallets: Dict[int, Any]
|
||||
|
||||
# Maintains headers recently received. Once the desired removals and additions are downloaded,
|
||||
# the data is persisted in the WalletStateManager. These variables are also used to store
|
||||
@ -95,44 +89,10 @@ class WalletNode:
|
||||
mkdir(path.parent)
|
||||
|
||||
self.wallet_state_manager = await WalletStateManager.create(
|
||||
config, path, self.constants
|
||||
key_config, config, path, self.constants
|
||||
)
|
||||
self.wallet_state_manager.set_pending_callback(self.pending_tx_handler)
|
||||
|
||||
main_wallet_info = await self.wallet_state_manager.get_main_wallet()
|
||||
assert main_wallet_info is not None
|
||||
|
||||
self.main_wallet = await Wallet.create(
|
||||
config, key_config, self.wallet_state_manager, main_wallet_info
|
||||
)
|
||||
|
||||
wallets: List[WalletInfo] = await self.wallet_state_manager.get_all_wallets()
|
||||
self.wallets = {}
|
||||
main_wallet = await Wallet.create(
|
||||
config, key_config, self.wallet_state_manager, main_wallet_info
|
||||
)
|
||||
self.main_wallet = main_wallet
|
||||
self.wallets[main_wallet_info.id] = main_wallet
|
||||
|
||||
for wallet_info in wallets:
|
||||
self.log.info(f"wallet_info {wallet_info}")
|
||||
if wallet_info.type == WalletType.STANDARD_WALLET:
|
||||
if wallet_info.id == 1:
|
||||
continue
|
||||
wallet = await Wallet.create(
|
||||
config, key_config, self.wallet_state_manager, main_wallet_info
|
||||
)
|
||||
self.wallets[wallet_info.id] = wallet
|
||||
elif wallet_info.type == WalletType.RATE_LIMITED:
|
||||
wallet = await RLWallet.create(
|
||||
config,
|
||||
key_config,
|
||||
self.wallet_state_manager,
|
||||
wallet_info,
|
||||
self.main_wallet,
|
||||
)
|
||||
self.wallets[wallet_info.id] = wallet
|
||||
|
||||
# Normal operation data
|
||||
self.cached_blocks = {}
|
||||
self.future_block_hashes = {}
|
||||
|
@ -1,9 +1,14 @@
|
||||
import asyncio
|
||||
from typing import Set, Tuple, Optional
|
||||
from blspy import PublicKey
|
||||
from typing import Set, Tuple, Optional, List
|
||||
import aiosqlite
|
||||
import logging
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.ints import uint32, uint64
|
||||
from src.util.ints import uint32
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
from src.wallet.derivation_record import DerivationRecord
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WalletPuzzleStore:
|
||||
@ -28,14 +33,17 @@ class WalletPuzzleStore:
|
||||
await self.db_connection.execute(
|
||||
(
|
||||
f"CREATE TABLE IF NOT EXISTS derivation_paths("
|
||||
f"id int PRIMARY KEY,"
|
||||
f"derivation_index int,"
|
||||
f" pubkey text,"
|
||||
f" puzzle_hash text,"
|
||||
f" puzzle_hash text PRIMARY_KEY,"
|
||||
f" wallet_type int,"
|
||||
f" wallet_id int,"
|
||||
f" used int)"
|
||||
f" used tinyint)"
|
||||
)
|
||||
)
|
||||
await self.db_connection.execute(
|
||||
"CREATE INDEX IF NOT EXISTS derivation_index_index on derivation_paths(derivation_index)"
|
||||
)
|
||||
|
||||
await self.db_connection.execute(
|
||||
"CREATE INDEX IF NOT EXISTS ph on derivation_paths(puzzle_hash)"
|
||||
@ -74,26 +82,66 @@ class WalletPuzzleStore:
|
||||
await cursor.close()
|
||||
await self.db_connection.commit()
|
||||
|
||||
async def add_derivation_path_of_interest(
|
||||
self,
|
||||
index: int,
|
||||
puzzlehash: bytes32,
|
||||
pubkey: bytes,
|
||||
wallet_type: WalletType,
|
||||
wallet_id: int,
|
||||
):
|
||||
async def add_derivation_paths(self, records: List[DerivationRecord]) -> None:
|
||||
"""
|
||||
Inserts new derivation path, puzzle, pubkey, wallet into DB.
|
||||
Insert many derivation paths into the database.
|
||||
"""
|
||||
sql_records = []
|
||||
for record in records:
|
||||
sql_records.append(
|
||||
(
|
||||
record.index,
|
||||
bytes(record.pubkey).hex(),
|
||||
record.puzzle_hash.hex(),
|
||||
record.wallet_type.value,
|
||||
record.wallet_id,
|
||||
0,
|
||||
),
|
||||
)
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
cursor = await self.db_connection.executemany(
|
||||
"INSERT OR REPLACE INTO derivation_paths VALUES(?, ?, ?, ?, ?, ?)",
|
||||
(index, pubkey.hex(), puzzlehash.hex(), wallet_type.value, wallet_id, 0),
|
||||
sql_records,
|
||||
)
|
||||
|
||||
await cursor.close()
|
||||
await self.db_connection.commit()
|
||||
|
||||
async def get_derivation_record(
|
||||
self, index: uint32, wallet_id: uint32
|
||||
) -> Optional[DerivationRecord]:
|
||||
"""
|
||||
Returns the derivation record by index and wallet id.
|
||||
"""
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * FROM derivation_paths WHERE derivation_index=? and wallet_id=?;",
|
||||
(index, wallet_id,),
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row is not None and row[0] is not None:
|
||||
return DerivationRecord(
|
||||
row[0],
|
||||
bytes.fromhex(row[2]),
|
||||
PublicKey.from_bytes(bytes.fromhex(row[1])),
|
||||
row[3],
|
||||
row[4],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
async def set_used_up_to(self, index: uint32) -> None:
|
||||
"""
|
||||
Sets a derivation path to used so we don't use it again.
|
||||
"""
|
||||
pass
|
||||
cursor = await self.db_connection.execute(
|
||||
"UPDATE derivation_paths SET used=1 WHERE derivation_index<=?", (index,),
|
||||
)
|
||||
await cursor.close()
|
||||
await self.db_connection.commit()
|
||||
|
||||
async def puzzle_hash_exists(self, puzzle_hash: bytes32) -> bool:
|
||||
"""
|
||||
Checks if passed puzzle_hash is present in the db.
|
||||
@ -102,54 +150,50 @@ class WalletPuzzleStore:
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * from derivation_paths WHERE puzzle_hash=?", (puzzle_hash.hex(),)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if len(list(rows)) > 0:
|
||||
return True
|
||||
return row is not None
|
||||
|
||||
return False
|
||||
|
||||
async def index_for_pubkey(self, pubkey: str) -> int:
|
||||
async def index_for_pubkey(self, pubkey: PublicKey) -> Optional[uint32]:
|
||||
"""
|
||||
Returns derivation path for the given pubkey.
|
||||
Returns -1 if not present.
|
||||
Returns derivation paths for the given pubkey.
|
||||
Returns None if not present.
|
||||
"""
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * from derivation_paths WHERE pubkey=?", (pubkey,)
|
||||
"SELECT * from derivation_paths WHERE pubkey=?", (bytes(pubkey).hex(),)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row:
|
||||
return row[0]
|
||||
if row is not None:
|
||||
return uint32(row[0])
|
||||
|
||||
return -1
|
||||
return None
|
||||
|
||||
async def index_for_puzzle_hash(self, puzzle_hash: bytes32) -> int:
|
||||
async def index_for_puzzle_hash(self, puzzle_hash: bytes32) -> Optional[uint32]:
|
||||
"""
|
||||
Returns the derivation path for the puzzle_hash.
|
||||
Returns -1 if not present.
|
||||
Returns None if not present.
|
||||
"""
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * from derivation_paths WHERE puzzle_hash=?", (puzzle_hash.hex(),)
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row:
|
||||
return row[0]
|
||||
if row is not None:
|
||||
return uint32(row[0])
|
||||
|
||||
return -1
|
||||
return None
|
||||
|
||||
async def wallet_info_for_puzzle_hash(
|
||||
self, puzzle_hash: bytes32
|
||||
) -> Optional[Tuple[uint64, WalletType]]:
|
||||
) -> Optional[Tuple[uint32, WalletType]]:
|
||||
"""
|
||||
Returns the derivation path for the puzzle_hash.
|
||||
Returns -1 if not present.
|
||||
Returns None if not present.
|
||||
"""
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
@ -158,7 +202,7 @@ class WalletPuzzleStore:
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row:
|
||||
if row is not None:
|
||||
return row[4], WalletType(row[3])
|
||||
|
||||
return None
|
||||
@ -178,18 +222,33 @@ class WalletPuzzleStore:
|
||||
|
||||
return result
|
||||
|
||||
async def get_max_derivation_path(self):
|
||||
async def get_last_derivation_path(self) -> Optional[uint32]:
|
||||
"""
|
||||
Returns the highest derivation path currently stored.
|
||||
Returns the last derivation path by derivation_index.
|
||||
"""
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT MAX(id) FROM derivation_paths;"
|
||||
"SELECT MAX(derivation_index) FROM derivation_paths;"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row[0] is not None:
|
||||
return row[0]
|
||||
if row is not None and row[0] is not None:
|
||||
return uint32(row[0])
|
||||
|
||||
return -1
|
||||
return None
|
||||
|
||||
async def get_unused_derivation_path(self) -> Optional[uint32]:
|
||||
"""
|
||||
Returns the first unused derivation path by derivation_index.
|
||||
"""
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT MIN(derivation_index) FROM derivation_paths WHERE used=0;"
|
||||
)
|
||||
row = await cursor.fetchone()
|
||||
await cursor.close()
|
||||
|
||||
if row is not None and row[0] is not None:
|
||||
return uint32(row[0])
|
||||
|
||||
return None
|
||||
|
@ -1,12 +1,13 @@
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
from typing import Dict, Optional, List, Set, Tuple, Callable
|
||||
from typing import Dict, Optional, List, Set, Tuple, Callable, Any
|
||||
import logging
|
||||
import asyncio
|
||||
|
||||
import aiosqlite
|
||||
from chiabip158 import PyBIP158
|
||||
from blspy import PublicKey
|
||||
|
||||
from src.types.coin import Coin
|
||||
from src.types.spend_bundle import SpendBundle
|
||||
@ -30,6 +31,11 @@ from src.util.significant_bits import truncate_to_significant_bits
|
||||
from src.wallet.wallet_user_store import WalletUserStore
|
||||
from src.types.mempool_inclusion_status import MempoolInclusionStatus
|
||||
from src.util.errors import Err
|
||||
from src.wallet.wallet import Wallet
|
||||
from src.wallet.rl_wallet.rl_wallet import RLWallet
|
||||
from src.types.program import Program
|
||||
from src.wallet.derivation_record import DerivationRecord
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
|
||||
|
||||
class WalletStateManager:
|
||||
@ -64,9 +70,16 @@ class WalletStateManager:
|
||||
db_path: Path
|
||||
db_connection: aiosqlite.Connection
|
||||
|
||||
main_wallet: Wallet
|
||||
wallets: Dict[uint32, Any]
|
||||
|
||||
@staticmethod
|
||||
async def create(
|
||||
config: Dict, db_path: Path, constants: Dict, name: str = None,
|
||||
key_config: Dict,
|
||||
config: Dict,
|
||||
db_path: Path,
|
||||
constants: Dict,
|
||||
name: str = None,
|
||||
):
|
||||
self = WalletStateManager()
|
||||
self.config = config
|
||||
@ -94,6 +107,33 @@ class WalletStateManager:
|
||||
self.difficulty_resets_prev = {}
|
||||
self.db_path = db_path
|
||||
|
||||
main_wallet_info = await self.user_store.get_wallet_by_id(1)
|
||||
assert main_wallet_info is not None
|
||||
|
||||
self.main_wallet = await Wallet.create(
|
||||
config, key_config, self, main_wallet_info
|
||||
)
|
||||
|
||||
self.wallets = {}
|
||||
main_wallet = await Wallet.create(config, key_config, self, main_wallet_info)
|
||||
self.wallets[main_wallet_info.id] = main_wallet
|
||||
|
||||
for wallet_info in await self.get_all_wallets():
|
||||
self.log.info(f"wallet_info {wallet_info}")
|
||||
if wallet_info.type == WalletType.STANDARD_WALLET:
|
||||
if wallet_info.id == 1:
|
||||
continue
|
||||
wallet = await Wallet.create(config, key_config, self, main_wallet_info)
|
||||
self.wallets[wallet_info.id] = wallet
|
||||
elif wallet_info.type == WalletType.RATE_LIMITED:
|
||||
wallet = await RLWallet.create(
|
||||
config, key_config, self, wallet_info, self.main_wallet,
|
||||
)
|
||||
self.wallets[wallet_info.id] = wallet
|
||||
|
||||
async with self.puzzle_store.lock:
|
||||
await self.create_more_puzzle_hashes()
|
||||
|
||||
if len(self.block_records) > 0:
|
||||
# Initializes the state based on the DB block records
|
||||
# Header hash with the highest weight
|
||||
@ -135,8 +175,80 @@ class WalletStateManager:
|
||||
),
|
||||
genesis_hb,
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
async def create_more_puzzle_hashes(self):
|
||||
"""
|
||||
For all wallets in the user store, generates the first few puzzle hashes so
|
||||
that we can restore the wallet from only the private keys.
|
||||
"""
|
||||
for wallet_id in self.wallets.keys():
|
||||
target_wallet = self.wallets[wallet_id]
|
||||
unused: Optional[
|
||||
uint32
|
||||
] = await self.puzzle_store.get_unused_derivation_path()
|
||||
last: Optional[uint32] = await self.puzzle_store.get_last_derivation_path()
|
||||
|
||||
to_generate = 200
|
||||
start_index = 0
|
||||
derivation_paths: List[DerivationRecord] = []
|
||||
|
||||
if last is None:
|
||||
assert unused is None
|
||||
if unused is not None:
|
||||
assert last is not None
|
||||
start_index = last + 1
|
||||
to_generate -= last - unused
|
||||
|
||||
for index in range(start_index, start_index + to_generate):
|
||||
pubkey: PublicKey = target_wallet.get_public_key(index)
|
||||
puzzle: Program = target_wallet.puzzle_for_pk(bytes(pubkey))
|
||||
puzzlehash: bytes32 = puzzle.get_hash()
|
||||
self.log.info(
|
||||
f"Generating public key at index {index} puzzle hash {puzzlehash.hex()}"
|
||||
)
|
||||
derivation_paths.append(
|
||||
DerivationRecord(
|
||||
uint32(index),
|
||||
puzzlehash,
|
||||
pubkey,
|
||||
WalletType.STANDARD_WALLET,
|
||||
uint32(target_wallet.wallet_info.id),
|
||||
)
|
||||
)
|
||||
|
||||
await self.puzzle_store.add_derivation_paths(derivation_paths)
|
||||
|
||||
async def get_unused_derivation_record(self, wallet_id: uint32) -> DerivationRecord:
|
||||
"""
|
||||
Creates a puzzle hash for the given wallet, and then makes more puzzle hashes
|
||||
for every wallet to ensure we always have more in the database. Never reusue the
|
||||
same public key more than once (for privacy).
|
||||
"""
|
||||
async with self.puzzle_store.lock:
|
||||
# If we have no unused public keys, we will create new ones
|
||||
unused: Optional[
|
||||
uint32
|
||||
] = await self.puzzle_store.get_unused_derivation_path()
|
||||
if unused is None:
|
||||
await self.create_more_puzzle_hashes()
|
||||
|
||||
# Now we must have unused public keys
|
||||
unused = await self.puzzle_store.get_unused_derivation_path()
|
||||
assert unused is not None
|
||||
record: Optional[
|
||||
DerivationRecord
|
||||
] = await self.puzzle_store.get_derivation_record(unused, wallet_id)
|
||||
assert record is not None
|
||||
|
||||
# Set this key to used so we never use it again
|
||||
await self.puzzle_store.set_used_up_to(record.index)
|
||||
|
||||
# Create more puzzle hashes / keys
|
||||
await self.create_more_puzzle_hashes()
|
||||
return record
|
||||
|
||||
def set_callback(self, callback: Callable):
|
||||
"""
|
||||
Callback to be called when the state of the wallet changes.
|
||||
@ -302,8 +414,6 @@ class WalletStateManager:
|
||||
if unconfirmed_record:
|
||||
await self.tx_store.set_confirmed(unconfirmed_record.name(), index)
|
||||
|
||||
self.state_changed("coin_removed")
|
||||
|
||||
async def coin_added(self, coin: Coin, index: uint32, coinbase: bool):
|
||||
"""
|
||||
Adding coin to the db
|
||||
@ -536,6 +646,18 @@ class WalletStateManager:
|
||||
self.block_records[block.header_hash] = block
|
||||
await self.wallet_store.add_block_record(block, False)
|
||||
|
||||
max_puzzle_index = uint32(0)
|
||||
async with self.puzzle_store.lock:
|
||||
for addition in block.additions:
|
||||
index = await self.puzzle_store.index_for_puzzle_hash(
|
||||
addition.puzzle_hash
|
||||
)
|
||||
assert index is not None
|
||||
if index > max_puzzle_index:
|
||||
max_puzzle_index = index
|
||||
await self.puzzle_store.set_used_up_to(max_puzzle_index)
|
||||
await self.create_more_puzzle_hashes()
|
||||
|
||||
# Genesis case
|
||||
if self.lca is None:
|
||||
assert block.height == 0
|
||||
@ -545,6 +667,8 @@ class WalletStateManager:
|
||||
await self.coin_added(coin, block.height, False)
|
||||
for coin_name in block.removals:
|
||||
await self.coin_removed(coin_name, block.height)
|
||||
self.state_changed("coin_added")
|
||||
self.state_changed("coin_removed")
|
||||
self.height_to_hash[uint32(0)] = block.header_hash
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
|
||||
@ -589,7 +713,10 @@ class WalletStateManager:
|
||||
await self.coin_added(coin, path_block.height, is_coinbase)
|
||||
for coin_name in path_block.removals:
|
||||
await self.coin_removed(coin_name, path_block.height)
|
||||
|
||||
self.lca = block.header_hash
|
||||
self.state_changed("coin_added")
|
||||
self.state_changed("coin_removed")
|
||||
self.state_changed("new_block")
|
||||
return ReceiveBlockResult.ADDED_TO_HEAD
|
||||
|
||||
@ -1133,9 +1260,6 @@ class WalletStateManager:
|
||||
async def get_all_wallets(self) -> List[WalletInfo]:
|
||||
return await self.user_store.get_all_wallets()
|
||||
|
||||
async def get_main_wallet(self):
|
||||
return await self.user_store.get_wallet_by_id(1)
|
||||
|
||||
async def get_coin_records_by_spent(self, spent: bool):
|
||||
return await self.wallet_store.get_coin_records_by_spent(spent)
|
||||
|
||||
|
@ -30,7 +30,7 @@ class WalletTransactionStore:
|
||||
f"CREATE TABLE IF NOT EXISTS transaction_record("
|
||||
f" transaction_record blob,"
|
||||
f" bundle_id text PRIMARY KEY,"
|
||||
f" confirmed_at_index int,"
|
||||
f" confirmed_at_index bigint,"
|
||||
f" created_at_time bigint,"
|
||||
f" to_puzzle_hash text,"
|
||||
f" amount bigint,"
|
||||
@ -38,7 +38,7 @@ class WalletTransactionStore:
|
||||
f" incoming int,"
|
||||
f" confirmed int,"
|
||||
f" sent int,"
|
||||
f" wallet_id int)"
|
||||
f" wallet_id bigint)"
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -80,7 +80,7 @@ class WebSocketServer:
|
||||
"""
|
||||
|
||||
wallet_id = int(request["wallet_id"])
|
||||
wallet = self.wallet_node.wallets[wallet_id]
|
||||
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
|
||||
puzzlehash = (await wallet.get_new_puzzlehash()).hex()
|
||||
|
||||
data = {
|
||||
@ -91,7 +91,7 @@ class WebSocketServer:
|
||||
|
||||
async def send_transaction(self, websocket, request, response_api):
|
||||
wallet_id = int(request["wallet_id"])
|
||||
wallet = self.wallet_node.wallets[wallet_id]
|
||||
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
|
||||
try:
|
||||
tx = await wallet.generate_signed_transaction_dict(request)
|
||||
except BaseException as e:
|
||||
@ -173,7 +173,7 @@ class WebSocketServer:
|
||||
|
||||
async def get_wallet_balance(self, websocket, request, response_api):
|
||||
wallet_id = int(request["wallet_id"])
|
||||
wallet = self.wallet_node.wallets[wallet_id]
|
||||
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
|
||||
balance = await wallet.get_confirmed_balance()
|
||||
pending_balance = await wallet.get_unconfirmed_balance()
|
||||
|
||||
@ -217,14 +217,18 @@ class WebSocketServer:
|
||||
rl_admin: RLWallet = await RLWallet.create_rl_admin(
|
||||
config, key_config, wallet_state_manager, main_wallet
|
||||
)
|
||||
self.wallet_node.wallets[rl_admin.wallet_info.id] = rl_admin
|
||||
self.wallet_node.wallet_state_manager.wallets[
|
||||
rl_admin.wallet_info.id
|
||||
] = rl_admin
|
||||
response = {"success": True, "type": "rl_wallet"}
|
||||
return await websocket.send(format_response(response_api, response))
|
||||
elif request["mode"] == "user":
|
||||
rl_user: RLWallet = await RLWallet.create_rl_user(
|
||||
config, key_config, wallet_state_manager, main_wallet
|
||||
)
|
||||
self.wallet_node.wallets[rl_user.wallet_info.id] = rl_user
|
||||
self.wallet_node.wallet_state_manager.wallets[
|
||||
rl_user.wallet_info.id
|
||||
] = rl_user
|
||||
response = {"success": True, "type": "rl_wallet"}
|
||||
return await websocket.send(format_response(response_api, response))
|
||||
elif request["wallet_type"] == "cc_wallet":
|
||||
@ -238,7 +242,7 @@ class WebSocketServer:
|
||||
self.wallet_node.config,
|
||||
self.wallet_node.key_config,
|
||||
self.wallet_node.wallet_state_manager,
|
||||
self.wallet_node.main_wallet,
|
||||
self.wallet_node.wallet_state_manager.main_wallet,
|
||||
)
|
||||
|
||||
async def get_wallets(self, websocket, response_api):
|
||||
@ -252,7 +256,7 @@ class WebSocketServer:
|
||||
|
||||
async def rl_set_admin_info(self, websocket, request, response_api):
|
||||
wallet_id = int(request["wallet_id"])
|
||||
wallet: RLWallet = self.wallet_node.wallets[wallet_id]
|
||||
wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
|
||||
user_pubkey = request["user_pubkey"]
|
||||
limit = uint64(int(request["limit"]))
|
||||
interval = uint64(int(request["interval"]))
|
||||
@ -266,7 +270,7 @@ class WebSocketServer:
|
||||
|
||||
async def rl_set_user_info(self, websocket, request, response_api):
|
||||
wallet_id = int(request["wallet_id"])
|
||||
wallet: RLWallet = self.wallet_node.wallets[wallet_id]
|
||||
wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
|
||||
admin_pubkey = request["admin_pubkey"]
|
||||
limit = uint64(int(request["limit"]))
|
||||
interval = uint64(int(request["interval"]))
|
||||
@ -396,7 +400,7 @@ async def start_websocket_server():
|
||||
|
||||
log.info("Starting websocket server.")
|
||||
websocket_server = await websockets.serve(
|
||||
handler.safe_handle, None, config["rpc_port"]
|
||||
handler.safe_handle, "localhost", config["rpc_port"]
|
||||
)
|
||||
log.info(f"Started websocket server at port {config['rpc_port']}.")
|
||||
|
||||
|
@ -78,9 +78,7 @@ class TestFullNodeProtocol:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
_, _, blocks = wallet_blocks
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("127.0.0.1", uint16(server_1._port)), None)
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
new_tip_1 = fnp.NewTip(
|
||||
@ -772,9 +770,7 @@ class TestFullNodeProtocol:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
wallet_a, wallet_receiver, blocks = wallet_blocks
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
msgs = [_ async for _ in full_node_1.request_peers(fnp.RequestPeers())]
|
||||
@ -787,9 +783,7 @@ class TestWalletProtocol:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
wallet_a, wallet_receiver, blocks = wallet_blocks
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
blocks_list = await get_block_path(full_node_1)
|
||||
|
||||
blocks_new = bt.get_consecutive_blocks(
|
||||
@ -981,9 +975,7 @@ class TestWalletProtocol:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
wallet_a, wallet_receiver, blocks = wallet_blocks
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
blocks_list = await get_block_path(full_node_1)
|
||||
blocks_new = bt.get_consecutive_blocks(
|
||||
test_constants, 5, seed=b"test_request_removals"
|
||||
@ -1166,9 +1158,7 @@ class TestWalletProtocol:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
wallet_a, wallet_receiver, blocks = wallet_blocks
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
blocks_list = await get_block_path(full_node_1)
|
||||
blocks_new = bt.get_consecutive_blocks(
|
||||
test_constants, 5, seed=b"test_request_additions"
|
||||
|
@ -33,9 +33,7 @@ class TestFullSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
start = time.time()
|
||||
@ -86,9 +84,7 @@ class TestFullSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
start = time.time()
|
||||
|
@ -35,9 +35,7 @@ class TestNodeLoad:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
@ -77,9 +75,7 @@ class TestNodeLoad:
|
||||
full_node_1, full_node_2, server_1, server_2 = two_nodes
|
||||
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
await asyncio.sleep(2) # Allow connections to get made
|
||||
|
||||
|
@ -45,12 +45,10 @@ class TestTransactions:
|
||||
async def test_wallet_coinbase(self, wallet_node):
|
||||
num_blocks = 10
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
wallet = wallet_node.main_wallet
|
||||
wallet = wallet_node.wallet_state_manager.main_wallet
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
|
||||
@ -74,23 +72,19 @@ class TestTransactions:
|
||||
full_node_1, server_1 = full_nodes[1]
|
||||
full_node_2, server_2 = full_nodes[2]
|
||||
|
||||
ph = await wallet_0.main_wallet.get_new_puzzlehash()
|
||||
ph1 = await wallet_1.main_wallet.get_new_puzzlehash()
|
||||
ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash()
|
||||
ph1 = await wallet_1.wallet_state_manager.main_wallet.get_new_puzzlehash()
|
||||
|
||||
#
|
||||
# wallet0 <-> sever0 <-> server1 <-> server2 <-> wallet1
|
||||
#
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_0._host, uint16(server_0._port)), None
|
||||
)
|
||||
await server_0.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_1.start_client(
|
||||
PeerInfo(server_2._host, uint16(server_2._port)), None
|
||||
PeerInfo("localhost", uint16(server_0._port)), None
|
||||
)
|
||||
await server_0.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
await server_1.start_client(PeerInfo("localhost", uint16(server_2._port)), None)
|
||||
await wallet_server_1.start_client(
|
||||
PeerInfo(server_2._host, uint16(server_2._port)), None
|
||||
PeerInfo("localhost", uint16(server_2._port)), None
|
||||
)
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
@ -103,12 +97,15 @@ class TestTransactions:
|
||||
for i in range(1, num_blocks - 2)
|
||||
]
|
||||
)
|
||||
assert await wallet_0.main_wallet.get_confirmed_balance() == funds
|
||||
assert (
|
||||
await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance()
|
||||
== funds
|
||||
)
|
||||
|
||||
spend_bundle = await wallet_0.main_wallet.generate_signed_transaction(
|
||||
spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction(
|
||||
10, ph1, 0
|
||||
)
|
||||
await wallet_0.main_wallet.push_transaction(spend_bundle)
|
||||
await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle)
|
||||
|
||||
await asyncio.sleep(3)
|
||||
|
||||
@ -132,8 +129,14 @@ class TestTransactions:
|
||||
]
|
||||
)
|
||||
|
||||
assert await wallet_0.main_wallet.get_confirmed_balance() == funds - 10
|
||||
assert await wallet_1.main_wallet.get_confirmed_balance() == 10
|
||||
assert (
|
||||
await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance()
|
||||
== funds - 10
|
||||
)
|
||||
assert (
|
||||
await wallet_1.wallet_state_manager.main_wallet.get_confirmed_balance()
|
||||
== 10
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mempool_tx_sync(self, three_nodes_two_wallets):
|
||||
@ -145,16 +148,14 @@ class TestTransactions:
|
||||
full_node_1, server_1 = full_nodes[1]
|
||||
full_node_2, server_2 = full_nodes[2]
|
||||
|
||||
ph = await wallet_0.main_wallet.get_new_puzzlehash()
|
||||
ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash()
|
||||
|
||||
# wallet0 <-> sever0 <-> server1
|
||||
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_0._host, uint16(server_0._port)), None
|
||||
)
|
||||
await server_0.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
PeerInfo("localhost", uint16(server_0._port)), None
|
||||
)
|
||||
await server_0.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_0.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
@ -174,12 +175,15 @@ class TestTransactions:
|
||||
for i in range(1, num_blocks - 2)
|
||||
]
|
||||
)
|
||||
assert await wallet_0.main_wallet.get_confirmed_balance() == funds
|
||||
assert (
|
||||
await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance()
|
||||
== funds
|
||||
)
|
||||
|
||||
spend_bundle = await wallet_0.main_wallet.generate_signed_transaction(
|
||||
spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction(
|
||||
10, token_bytes(), 0
|
||||
)
|
||||
await wallet_0.main_wallet.push_transaction(spend_bundle)
|
||||
await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
@ -194,9 +198,7 @@ class TestTransactions:
|
||||
# make a final connection.
|
||||
# wallet0 <-> sever0 <-> server1 <-> server2
|
||||
|
||||
await server_1.start_client(
|
||||
PeerInfo(server_2._host, uint16(server_2._port)), None
|
||||
)
|
||||
await server_1.start_client(PeerInfo("localhost", uint16(server_2._port)), None)
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
|
@ -68,7 +68,7 @@ class TestRpc:
|
||||
|
||||
assert len(await client.get_connections()) == 0
|
||||
|
||||
await client.open_connection(server_2._host, server_2._port)
|
||||
await client.open_connection("localhost", server_2._port)
|
||||
await asyncio.sleep(2)
|
||||
connections = await client.get_connections()
|
||||
assert len(connections) == 1
|
||||
|
@ -421,14 +421,14 @@ async def setup_full_system(dic={}):
|
||||
node2, node2_server = await node_iters[6].__anext__()
|
||||
|
||||
await harvester_server.start_client(
|
||||
PeerInfo(farmer_server._host, uint16(farmer_server._port)), None
|
||||
PeerInfo("127.0.0.1", uint16(farmer_server._port)), None
|
||||
)
|
||||
await farmer_server.start_client(
|
||||
PeerInfo(node1_server._host, uint16(node1_server._port)), None
|
||||
PeerInfo("127.0.0.1", uint16(node1_server._port)), None
|
||||
)
|
||||
|
||||
await timelord_server.start_client(
|
||||
PeerInfo(node1_server._host, uint16(node1_server._port)), None
|
||||
PeerInfo("127.0.0.1", uint16(node1_server._port)), None
|
||||
)
|
||||
|
||||
yield (node1, node2)
|
||||
|
@ -33,7 +33,7 @@ class TestFilter:
|
||||
config = load_config("config.yaml", "wallet")
|
||||
key_config = {"wallet_sk": sk}
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_and_node
|
||||
wallet = wallet_node.main_wallet
|
||||
wallet = wallet_node.wallet_state_manager.main_wallet
|
||||
|
||||
num_blocks = 2
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
99
tests/wallet/test_puzzle_store.py
Normal file
99
tests/wallet/test_puzzle_store.py
Normal file
@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
from secrets import token_bytes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
import sqlite3
|
||||
import random
|
||||
import pytest
|
||||
import aiosqlite
|
||||
from blspy import PrivateKey
|
||||
from src.full_node.store import FullNodeStore
|
||||
from src.types.full_block import FullBlock
|
||||
from src.types.sized_bytes import bytes32
|
||||
from src.util.ints import uint32, uint64
|
||||
from src.wallet.wallet_puzzle_store import WalletPuzzleStore
|
||||
from src.wallet.derivation_record import DerivationRecord
|
||||
from src.wallet.util.wallet_types import WalletType
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def event_loop():
|
||||
loop = asyncio.get_event_loop()
|
||||
yield loop
|
||||
|
||||
|
||||
class TestPuzzleStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_puzzle_store(self):
|
||||
db_filename = Path("puzzle_store_test.db")
|
||||
|
||||
if db_filename.exists():
|
||||
db_filename.unlink()
|
||||
|
||||
con = await aiosqlite.connect(db_filename)
|
||||
db = await WalletPuzzleStore.create(con)
|
||||
try:
|
||||
derivation_recs = []
|
||||
wallet_types = [t for t in WalletType]
|
||||
|
||||
for i in range(1000):
|
||||
derivation_recs.append(
|
||||
DerivationRecord(
|
||||
uint32(i),
|
||||
token_bytes(32),
|
||||
PrivateKey.from_seed(token_bytes(5)).get_public_key(),
|
||||
WalletType.STANDARD_WALLET,
|
||||
uint32(1),
|
||||
)
|
||||
)
|
||||
derivation_recs.append(
|
||||
DerivationRecord(
|
||||
uint32(i),
|
||||
token_bytes(32),
|
||||
PrivateKey.from_seed(token_bytes(5)).get_public_key(),
|
||||
WalletType.RATE_LIMITED,
|
||||
uint32(2),
|
||||
)
|
||||
)
|
||||
assert await db.puzzle_hash_exists(derivation_recs[0].puzzle_hash) == False
|
||||
assert await db.index_for_pubkey(derivation_recs[0].pubkey) == None
|
||||
assert (
|
||||
await db.index_for_puzzle_hash(derivation_recs[2].puzzle_hash) == None
|
||||
)
|
||||
assert (
|
||||
await db.wallet_info_for_puzzle_hash(derivation_recs[2].puzzle_hash)
|
||||
== None
|
||||
)
|
||||
assert len((await db.get_all_puzzle_hashes())) == 0
|
||||
assert await db.get_last_derivation_path() == None
|
||||
assert await db.get_unused_derivation_path() == None
|
||||
assert await db.get_derivation_record(0, 2) == None
|
||||
|
||||
await db.add_derivation_paths(derivation_recs)
|
||||
|
||||
assert await db.puzzle_hash_exists(derivation_recs[0].puzzle_hash) == True
|
||||
assert await db.index_for_pubkey(derivation_recs[4].pubkey) == 2
|
||||
assert await db.index_for_puzzle_hash(derivation_recs[2].puzzle_hash) == 1
|
||||
assert await db.wallet_info_for_puzzle_hash(
|
||||
derivation_recs[2].puzzle_hash
|
||||
) == (derivation_recs[2].wallet_id, derivation_recs[2].wallet_type)
|
||||
assert len((await db.get_all_puzzle_hashes())) == 2000
|
||||
assert await db.get_last_derivation_path() == 999
|
||||
assert await db.get_unused_derivation_path() == 0
|
||||
assert await db.get_derivation_record(0, 2) == derivation_recs[1]
|
||||
|
||||
# Indeces up to 250
|
||||
await db.set_used_up_to(249)
|
||||
|
||||
assert await db.get_unused_derivation_path() == 250
|
||||
|
||||
except Exception as e:
|
||||
print(e, type(e))
|
||||
await db._clear_database()
|
||||
await db.close()
|
||||
db_filename.unlink()
|
||||
raise e
|
||||
|
||||
await db._clear_database()
|
||||
await db.close()
|
||||
db_filename.unlink()
|
@ -52,12 +52,10 @@ class TestWalletSimulator:
|
||||
async def test_wallet_coinbase(self, wallet_node):
|
||||
num_blocks = 10
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
wallet = wallet_node.main_wallet
|
||||
wallet = wallet_node.wallet_state_manager.main_wallet
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
|
||||
@ -81,12 +79,10 @@ class TestWalletSimulator:
|
||||
server_2,
|
||||
server_3,
|
||||
) = two_wallet_nodes
|
||||
wallet = wallet_node.main_wallet
|
||||
wallet = wallet_node.wallet_state_manager.main_wallet
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
for i in range(0, num_blocks):
|
||||
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
@ -104,7 +100,9 @@ class TestWalletSimulator:
|
||||
assert await wallet.get_unconfirmed_balance() == funds
|
||||
|
||||
spend_bundle = await wallet.generate_signed_transaction(
|
||||
10, await wallet_node_2.main_wallet.get_new_puzzlehash(), 0
|
||||
10,
|
||||
await wallet_node_2.wallet_state_manager.main_wallet.get_new_puzzlehash(),
|
||||
0,
|
||||
)
|
||||
await wallet.push_transaction(spend_bundle)
|
||||
|
||||
@ -137,12 +135,10 @@ class TestWalletSimulator:
|
||||
async def test_wallet_coinbase_reorg(self, wallet_node):
|
||||
num_blocks = 10
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
wallet = wallet_node.main_wallet
|
||||
wallet = wallet_node.wallet_state_manager.main_wallet
|
||||
ph = await wallet.get_new_puzzlehash()
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
|
||||
@ -179,11 +175,11 @@ class TestWalletSimulator:
|
||||
full_node_1, server_1 = full_nodes[1]
|
||||
full_node_2, server_2 = full_nodes[2]
|
||||
|
||||
ph = await wallet_0.main_wallet.get_new_puzzlehash()
|
||||
ph = await wallet_0.wallet_state_manager.main_wallet.get_new_puzzlehash()
|
||||
|
||||
# wallet0 <-> sever0
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_0._host, uint16(server_0._port)), None
|
||||
PeerInfo("localhost", uint16(server_0._port)), None
|
||||
)
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
@ -208,12 +204,15 @@ class TestWalletSimulator:
|
||||
for i in range(1, num_blocks - 2)
|
||||
]
|
||||
)
|
||||
assert await wallet_0.main_wallet.get_confirmed_balance() == funds
|
||||
assert (
|
||||
await wallet_0.wallet_state_manager.main_wallet.get_confirmed_balance()
|
||||
== funds
|
||||
)
|
||||
|
||||
spend_bundle = await wallet_0.main_wallet.generate_signed_transaction(
|
||||
spend_bundle = await wallet_0.wallet_state_manager.main_wallet.generate_signed_transaction(
|
||||
10, token_bytes(), 0
|
||||
)
|
||||
await wallet_0.main_wallet.push_transaction(spend_bundle)
|
||||
await wallet_0.wallet_state_manager.main_wallet.push_transaction(spend_bundle)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@ -222,7 +221,7 @@ class TestWalletSimulator:
|
||||
|
||||
# wallet0 <-> sever1
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), wallet_0._on_connect
|
||||
PeerInfo("localhost", uint16(server_1._port)), wallet_0._on_connect
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@ -231,7 +230,7 @@ class TestWalletSimulator:
|
||||
|
||||
# wallet0 <-> sever2
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_2._host, uint16(server_2._port)), wallet_0._on_connect
|
||||
PeerInfo("localhost", uint16(server_2._port)), wallet_0._on_connect
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
@ -249,16 +248,16 @@ class TestWalletSimulator:
|
||||
wallet_0_server,
|
||||
wallet_1_server,
|
||||
) = two_wallet_nodes_five_freeze
|
||||
wallet_0 = wallet_node_0.main_wallet
|
||||
wallet_1 = wallet_node_1.main_wallet
|
||||
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
|
||||
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
|
||||
ph = await wallet_0.get_new_puzzlehash()
|
||||
|
||||
await wallet_0_server.start_client(
|
||||
PeerInfo(full_node_server._host, uint16(full_node_server._port)), None
|
||||
PeerInfo("localhost", uint16(full_node_server._port)), None
|
||||
)
|
||||
|
||||
await wallet_1_server.start_client(
|
||||
PeerInfo(full_node_server._host, uint16(full_node_server._port)), None
|
||||
PeerInfo("localhost", uint16(full_node_server._port)), None
|
||||
)
|
||||
|
||||
for i in range(0, num_blocks):
|
||||
@ -277,7 +276,9 @@ class TestWalletSimulator:
|
||||
assert await wallet_0.get_unconfirmed_balance() == funds
|
||||
|
||||
spend_bundle = await wallet_0.generate_signed_transaction(
|
||||
10, await wallet_node_1.main_wallet.get_new_puzzlehash(), 0
|
||||
10,
|
||||
await wallet_node_1.wallet_state_manager.main_wallet.get_new_puzzlehash(),
|
||||
0,
|
||||
)
|
||||
|
||||
await wallet_0.push_transaction(spend_bundle)
|
||||
|
@ -43,9 +43,7 @@ class TestWalletSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
start = time.time()
|
||||
found = False
|
||||
@ -104,9 +102,7 @@ class TestWalletSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
start = time.time()
|
||||
found = False
|
||||
@ -137,9 +133,7 @@ class TestWalletSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
start = time.time()
|
||||
while time.time() - start < 60:
|
||||
# The second node should eventually catch up to the first one, and have the
|
||||
@ -159,7 +153,7 @@ class TestWalletSync:
|
||||
@pytest.mark.asyncio
|
||||
async def test_short_sync_with_transactions_wallet(self, wallet_node):
|
||||
full_node_1, wallet_node, server_1, server_2 = wallet_node
|
||||
wallet_a = wallet_node.main_wallet
|
||||
wallet_a = wallet_node.wallet_state_manager.main_wallet
|
||||
wallet_a_dummy = WalletTool()
|
||||
wallet_b = WalletTool()
|
||||
coinbase_puzzlehash = await wallet_a.get_new_puzzlehash()
|
||||
@ -178,9 +172,7 @@ class TestWalletSync:
|
||||
)
|
||||
]
|
||||
await asyncio.sleep(2)
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
await asyncio.sleep(2)
|
||||
assert (
|
||||
wallet_node.wallet_state_manager.block_records[
|
||||
@ -215,9 +207,7 @@ class TestWalletSync:
|
||||
pass
|
||||
|
||||
# Do a short sync from 0 to 14
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
start = time.time()
|
||||
broke = False
|
||||
while time.time() - start < 60:
|
||||
@ -282,9 +272,7 @@ class TestWalletSync:
|
||||
pass
|
||||
|
||||
# Do a sync from 0 to 22
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
broke = False
|
||||
while time.time() - start < 60:
|
||||
@ -359,9 +347,7 @@ class TestWalletSync:
|
||||
):
|
||||
pass
|
||||
|
||||
await server_2.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), None
|
||||
)
|
||||
await server_2.start_client(PeerInfo("localhost", uint16(server_1._port)), None)
|
||||
|
||||
broke = False
|
||||
while time.time() - start < 60:
|
||||
|
Loading…
Reference in New Issue
Block a user