wallet changes

This commit is contained in:
Yostra 2020-02-13 11:57:40 -08:00
parent c907cf6431
commit aecf339cc9
2 changed files with 220 additions and 74 deletions

View File

@ -1,26 +1,41 @@
from pathlib import Path
from typing import Dict, Optional, List, Set, Tuple
import clvm
from blspy import ExtendedPrivateKey, PublicKey
import logging
import src.protocols.wallet_protocol
from src.full_node import OutboundMessageGenerator
from src.protocols.wallet_protocol import ProofHash
from src.server.outbound_message import OutboundMessage, NodeType, Message, Delivery
from src.server.server import ChiaServer
from src.types.full_block import additions_for_npc
from src.types.hashable.BLSSignature import BLSSignature
from src.types.hashable.Coin import Coin
from src.types.hashable.CoinRecord import CoinRecord
from src.types.hashable.CoinSolution import CoinSolution
from src.types.hashable.Program import Program
from src.types.hashable.SpendBundle import SpendBundle
from src.types.name_puzzle_condition import NPC
from src.types.sized_bytes import bytes32
from src.util.Hash import std_hash
from src.util.api_decorators import api_request
from src.util.condition_tools import (
conditions_for_solution,
conditions_by_opcode,
hash_key_pairs_for_conditions_dict,
)
from src.util.ints import uint32, uint64
from src.util.mempool_check_conditions import get_name_puzzle_conditions
from src.wallet.BLSPrivateKey import BLSPrivateKey
from src.wallet.puzzles.p2_conditions import puzzle_for_conditions
from src.wallet.puzzles.p2_delegated_puzzle import puzzle_for_pk
from src.wallet.puzzles.puzzle_utils import make_assert_my_coin_id_condition, make_assert_time_exceeds_condition, \
make_assert_coin_consumed_condition, make_create_coin_condition
from src.wallet.puzzles.puzzle_utils import (
make_assert_my_coin_id_condition,
make_assert_time_exceeds_condition,
make_assert_coin_consumed_condition,
make_create_coin_condition,
)
from src.wallet.wallet_store import WalletStore
class Wallet:
@ -32,8 +47,26 @@ class Wallet:
pubkey_num_lookup: Dict[bytes, int]
tmp_balance: int
tmp_coins: Set[Coin]
wallet_store: WalletStore
header_hash: List[bytes32]
start_index: int
def __init__(self, config: Dict, key_config: Dict, name: str = None):
unconfirmed_removals: List[Coin]
unconfirmed_removal_amount: int
unconfirmed_additions: List[Coin]
unconfirmed_addition_amount: int
# This dict maps coin_id to SpendBundle, it will contain duplicate values by design
coin_spend_bundle_map: Dict[bytes32, SpendBundle]
# Spendbundle_ID : Spendbundle
pending_spend_bundles: Dict[bytes32, SpendBundle]
log: logging.Logger
@staticmethod
async def create(config: Dict, key_config: Dict, name: str = None):
self = Wallet()
print("init wallet")
self.config = config
self.key_config = key_config
@ -46,19 +79,46 @@ class Wallet:
self.pubkey_num_lookup = {}
self.tmp_balance = 0
self.tmp_utxos = set()
self.tmp_coins = set()
pub_hex = self.private_key.get_public_key().serialize().hex()
path = Path(f"wallet_db_{pub_hex}.db")
self.wallet_store = await WalletStore.create(path)
self.header_hash = []
self.unconfirmed_additions = []
self.unconfirmed_removals = []
self.pending_spend_bundles = {}
self.coin_spend_bundle_map = {}
self.unconfirmed_addition_amount = 0
self.unconfirmed_removal_amount = 0
return self
def get_next_public_key(self) -> PublicKey:
pubkey = self.private_key.public_child(
self.next_address).get_public_key()
pubkey = self.private_key.public_child(self.next_address).get_public_key()
self.pubkey_num_lookup[pubkey.serialize()] = self.next_address
self.next_address = self.next_address + 1
return pubkey
async def get_balance(self) -> uint64:
record_list: Set[
CoinRecord
] = await self.wallet_store.get_coin_records_by_spent(False)
amount: uint64 = uint64(0)
for record in record_list:
amount = uint64(amount + record.coin.amount)
for removal in self.unconfirmed_removals:
amount = uint64(amount - removal.amount)
return uint64(amount)
def can_generate_puzzle_hash(self, hash: bytes32) -> bool:
return any(map(lambda child: hash == puzzle_for_pk(
self.private_key.public_child(child).get_public_key().serialize()).get_hash(),
reversed(range(self.next_address))))
return any(
map(
lambda child: hash
== puzzle_for_pk(
self.private_key.public_child(child).get_public_key().serialize()
).get_hash(),
reversed(range(self.next_address)),
)
)
def puzzle_for_pk(self, pubkey) -> Program:
return puzzle_for_pk(pubkey)
@ -73,73 +133,124 @@ class Wallet:
puzzlehash: bytes32 = puzzle.get_hash()
return puzzlehash
def select_coins(self, amount) -> Optional[Set[Coin]]:
if amount > self.tmp_balance:
return None
used_coins: Set[Coin] = set()
while sum(map(lambda coin: coin.amount, used_coins)) < amount:
used_coins.add(self.tmp_utxos.pop())
return used_coins
async def select_coins(self, amount) -> Optional[Set[Coin]]:
# TODO pick proper coins
return None
def set_server(self, server: ChiaServer):
self.server = server
def sign(self, value: bytes32, pubkey: PublicKey):
private_key = self.private_key.private_child(
self.pubkey_num_lookup[pubkey]).get_private_key()
self.pubkey_num_lookup[pubkey]
).get_private_key()
bls_key = BLSPrivateKey(private_key)
return bls_key.sign(value)
def make_solution(self, primaries=[], min_time=0, me={}, consumed=[]):
def make_solution(self, primaries=None, min_time=0, me=None, consumed=None):
ret = []
for primary in primaries:
ret.append(make_create_coin_condition(
primary['puzzlehash'], primary['amount']))
for coin in consumed:
ret.append(make_assert_coin_consumed_condition(coin))
if primaries:
for primary in primaries:
ret.append(
make_create_coin_condition(primary["puzzlehash"], primary["amount"])
)
if consumed:
for coin in consumed:
ret.append(make_assert_coin_consumed_condition(coin))
if min_time > 0:
ret.append(make_assert_time_exceeds_condition(min_time))
if me:
ret.append(make_assert_my_coin_id_condition(me['id']))
ret.append(make_assert_my_coin_id_condition(me["id"]))
return clvm.to_sexp_f([puzzle_for_conditions(ret), []])
def get_keys(self, hash: bytes32) -> Optional[Tuple[PublicKey, ExtendedPrivateKey]]:
for child in range(self.next_address):
pubkey = self.private_key.public_child(
child).get_public_key()
pubkey = self.private_key.public_child(child).get_public_key()
if hash == puzzle_for_pk(pubkey.serialize()).get_hash():
return pubkey, self.private_key.private_child(child).get_private_key()
return None
def generate_unsigned_transaction(self, amount: int, newpuzzlehash: bytes32, fee: int = 0) -> List[Tuple[Program, CoinSolution]]:
async def generate_unsigned_transaction(
self, amount: int, newpuzzlehash: bytes32, fee: int = 0
) -> List[Tuple[Program, CoinSolution]]:
if self.tmp_balance < amount:
return None
utxos = self.select_coins(amount + fee)
return []
utxos = await self.select_coins(amount + fee)
if utxos is None:
return []
spends: List[Tuple[Program, CoinSolution]] = []
output_created = False
spend_value = sum([coin.amount for coin in utxos])
change = spend_value - amount - fee
for coin in utxos:
puzzle_hash = coin.puzzle_hash
pubkey, secretkey = self.get_keys(puzzle_hash)
maybe = self.get_keys(puzzle_hash)
if not maybe:
return []
pubkey, secretkey = maybe
puzzle: Program = puzzle_for_pk(pubkey.serialize())
if output_created is False:
primaries = [{'puzzlehash': newpuzzlehash, 'amount': amount}]
primaries = [{"puzzlehash": newpuzzlehash, "amount": amount}]
if change > 0:
changepuzzlehash = self.get_new_puzzlehash()
primaries.append(
{'puzzlehash': changepuzzlehash, 'amount': change})
primaries.append({"puzzlehash": changepuzzlehash, "amount": change})
# add change coin into temp_utxo set
self.tmp_utxos.add(Coin(coin, changepuzzlehash, change))
self.tmp_coins.add(Coin(coin, changepuzzlehash, uint64(change)))
solution = self.make_solution(primaries=primaries)
output_created = True
else:
solution = self.make_solution(consumed=[coin.name()])
spends.append((puzzle, CoinSolution(coin, solution)))
self.tmp_balance -= (amount + fee)
self.tmp_balance -= amount + fee
return spends
def sign_transaction(self, spends: List[Tuple[Program, CoinSolution]]):
sigs = []
for puzzle, solution in spends:
keys = self.get_keys(solution.coin.puzzle_hash)
if not keys:
return None
pubkey, secretkey = keys
secretkey = BLSPrivateKey(secretkey)
code_ = [puzzle, solution.solution]
sexp = clvm.to_sexp_f(code_)
err, con = conditions_for_solution(sexp)
if err or not con:
return None
conditions_dict = conditions_by_opcode(con)
for _ in hash_key_pairs_for_conditions_dict(conditions_dict):
signature = secretkey.sign(_.message_hash)
sigs.append(signature)
aggsig = BLSSignature.aggregate(sigs)
solution_list: List[CoinSolution] = [
CoinSolution(
coin_solution.coin, clvm.to_sexp_f([puzzle, coin_solution.solution])
)
for (puzzle, coin_solution) in spends
]
spend_bundle = SpendBundle(solution_list, aggsig)
return spend_bundle
async def generate_signed_transaction(
self, amount, newpuzzlehash, fee: int = 0
) -> Optional[SpendBundle]:
transaction = await self.generate_unsigned_transaction(
amount, newpuzzlehash, fee
)
if len(transaction) == 0:
return None
return self.sign_transaction(transaction)
async def coin_removed(self, coin_name: bytes32, index: uint32):
self.log.info("remove coin")
await self.wallet_store.set_spent(coin_name, index)
async def coin_added(self, coin: Coin, index: uint32, coinbase: bool):
self.log.info("add coin")
coin_record: CoinRecord = CoinRecord(coin, index, uint32(0), False, coinbase)
await self.wallet_store.add_coin_record(coin_record)
async def _on_connect(self) -> OutboundMessageGenerator:
"""
Whenever we connect to a FullNode we request new proof_hashes by sending last proof hash we have
@ -147,11 +258,15 @@ class Wallet:
self.log.info(f"Requesting proof hashes")
request = ProofHash(std_hash(b"deadbeef"))
yield OutboundMessage(
NodeType.FULL_NODE, Message("request_proof_hashes", request), Delivery.BROADCAST
NodeType.FULL_NODE,
Message("request_proof_hashes", request),
Delivery.BROADCAST,
)
@api_request
async def proof_hash(self, request: src.protocols.wallet_protocol.ProofHash) -> OutboundMessageGenerator:
async def proof_hash(
self, request: src.protocols.wallet_protocol.ProofHash
) -> OutboundMessageGenerator:
"""
Received a proof hash from the FullNode
"""
@ -159,17 +274,54 @@ class Wallet:
reply_request = ProofHash(std_hash(b"a"))
# TODO Store and decide if we want full proof for this proof hash
yield OutboundMessage(
NodeType.FULL_NODE, Message("request_full_proof_for_hash", reply_request), Delivery.RESPOND
NodeType.FULL_NODE,
Message("request_full_proof_for_hash", reply_request),
Delivery.RESPOND,
)
@api_request
async def full_proof_for_hash(self, request: src.protocols.wallet_protocol.FullProofForHash):
async def full_proof_for_hash(
self, request: src.protocols.wallet_protocol.FullProofForHash
):
"""
We've received a full proof for hash we requested
"""
# TODO Validate full proof
self.log.info(f"Received new proof: {request}")
@api_request
async def received_body(self, response: src.protocols.wallet_protocol.RespondBody):
"""
Called when body is received from the FullNode
"""
additions: List[Coin] = []
if self.can_generate_puzzle_hash(response.body.coinbase.puzzle_hash):
await self.coin_added(response.body.coinbase, response.height, True)
if self.can_generate_puzzle_hash(response.body.fees_coin.puzzle_hash):
await self.coin_added(response.body.fees_coin, response.height, True)
npc_list: List[NPC]
if response.body.transactions:
error, npc_list, cost = await get_name_puzzle_conditions(
response.body.transactions
)
additions.extend(additions_for_npc(npc_list))
for added_coin in additions:
if self.can_generate_puzzle_hash(added_coin.puzzle_hash):
await self.coin_added(added_coin, response.height, False)
for npc in npc_list:
if self.can_generate_puzzle_hash(npc.puzzle_hash):
await self.coin_removed(npc.coin_name, response.height)
@api_request
async def new_tip(self, header: src.protocols.wallet_protocol.Header):
self.log.info("new tip received")
async def send_transaction(self, spend_bundle: SpendBundle):
msg = OutboundMessage(
NodeType.FULL_NODE,
@ -180,6 +332,16 @@ class Wallet:
self.log.info(reply)
@api_request
async def transaction_ack(self, id: bytes32):
async def transaction_ack(self, ack: src.protocols.wallet_protocol.TransactionAck):
# TODO Remove from retry queue
print(f"tx has been received by the fullnode {}")
if ack.status:
self.log.info(f"SpendBundle has been received by the FullNode. id: {id}")
else:
self.log.info(f"SpendBundle has been rejected by the FullNode. id: {id}")
async def requestLCA(self):
msg = OutboundMessage(
NodeType.FULL_NODE, Message("request_lca", None), Delivery.BROADCAST,
)
async for reply in self.server.push_message(msg):
self.log.info(reply)

View File

@ -1,14 +1,10 @@
import asyncio
from typing import Dict, Optional, List
from typing import Dict, Optional, List, Set
from pathlib import Path
import aiosqlite
from src.types.body import Body
from src.types.full_block import FullBlock
from src.types.hashable.Coin import Coin
from src.types.hashable.CoinRecord import CoinRecord
from src.types.sized_bytes import bytes32
from src.types.header import Header
from src.util.ints import uint32
@ -113,13 +109,7 @@ class WalletStore:
await self.add_coin_record(spent)
# Checks DB and DiffStores for CoinRecord with coin_name and returns it
async def get_coin_record(
self, coin_name: bytes32, header: Header = None
) -> Optional[CoinRecord]:
if header is not None and header.header_hash in self.head_diffs:
diff_store = self.head_diffs[header.header_hash]
if coin_name.hex() in diff_store.diffs:
return diff_store.diffs[coin_name.hex()]
async def get_coin_record(self, coin_name: bytes32) -> Optional[CoinRecord]:
if coin_name.hex() in self.lca_coin_records:
return self.lca_coin_records[coin_name.hex()]
cursor = await self.coin_record_db.execute(
@ -128,16 +118,14 @@ class WalletStore:
row = await cursor.fetchone()
await cursor.close()
if row is not None:
coin = Coin(bytes32(bytes.fromhex(row[6])),
bytes32(bytes.fromhex(row[5])),
row[7])
coin = Coin(
bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), row[7]
)
return CoinRecord(coin, row[1], row[2], row[3], row[4])
return None
# Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
async def get_coin_records_by_spent(
self, spent: bool
) -> List[CoinRecord]:
async def get_coin_records_by_spent(self, spent: bool) -> Set[CoinRecord]:
coins = set()
cursor = await self.coin_record_db.execute(
@ -146,17 +134,15 @@ class WalletStore:
rows = await cursor.fetchall()
await cursor.close()
for row in rows:
coin = Coin(bytes32(bytes.fromhex(row[6])),
bytes32(bytes.fromhex(row[5])),
row[7])
coins.add(
CoinRecord(coin, row[1], row[2], row[3], row[4])
coin = Coin(
bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), row[7]
)
return list(coins)
coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4]))
return coins
# Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
async def get_coin_records_by_puzzle_hash(
self, puzzle_hash: bytes32, header: Header = None
self, puzzle_hash: bytes32
) -> List[CoinRecord]:
coins = set()
cursor = await self.coin_record_db.execute(
@ -165,12 +151,10 @@ class WalletStore:
rows = await cursor.fetchall()
await cursor.close()
for row in rows:
coin = Coin(bytes32(bytes.fromhex(row[6])),
bytes32(bytes.fromhex(row[5])),
row[7])
coins.add(
CoinRecord(coin, row[1], row[2], row[3], row[4])
coin = Coin(
bytes32(bytes.fromhex(row[6])), bytes32(bytes.fromhex(row[5])), row[7]
)
coins.add(CoinRecord(coin, row[1], row[2], row[3], row[4]))
return list(coins)
async def rollback_lca_to_block(self, block_index):