cancel trades

This commit is contained in:
Yostra 2020-06-14 21:34:37 -07:00 committed by Gene Hoffman
parent bb6efc9f56
commit a211200bb6
4 changed files with 120 additions and 37 deletions

View File

@ -26,6 +26,8 @@ from src.wallet.types.key_val_types import PendingOffers, AcceptedOffers
from src.wallet.wallet import Wallet
from clvm_tools import binutils
from src.wallet.wallet_coin_record import WalletCoinRecord
PENDING_OFFERS = "pending_offers"
ACCEPTED_OFFERS = "accepted_offers"
@ -74,7 +76,8 @@ class TradeManager:
accepted = AcceptedOffers.from_bytes((hexstr_to_bytes(accepted_offers_hex)))
return accepted.trades
async def get_locked_coins(self) -> Dict[bytes32, Coin]:
async def get_locked_coins(self, wallet_id: int = None) -> Dict[bytes32, Coin]:
""" Returns a dictionary of confirmed coins that are locked by a trade. """
current_trades = await self.get_pending_offers()
if current_trades is None:
return {}
@ -82,20 +85,57 @@ class TradeManager:
result = {}
for trade_offer in current_trades:
spend_bundle = trade_offer.spend_bundle
additions = spend_bundle.additions()
removals = spend_bundle.removals()
for coin in additions:
result[coin.name()] = coin
for coin in removals:
result[coin.name()] = coin
record: Optional[
WalletCoinRecord
] = await self.wallet_state_manager.wallet_store.get_coin_record_by_coin_id(
coin.name()
)
if record is None:
continue
if wallet_id is None or wallet_id == record.wallet_id:
result[coin.name()] = coin
return result
async def cancel_trade(self, trade_id: bytes32):
self.log.info("Need to cancel this trade")
async def cancel_pending_offer(self, trade_id: bytes32):
self.log.info(f"Cancel pending offer with id trade_id {trade_id.hex()}")
offers: List[TradeOffer] = await self.get_pending_offers()
filtered_offers: List[TradeOffer] = []
for offer in offers:
if offer.trade_id != trade_id:
filtered_offers.append(offer)
async def cancel_trade_safe(self, name: str):
self.log.info("Need to cancel this trade")
to_store = PendingOffers(filtered_offers)
await self.wallet_state_manager.basic_store.set(PENDING_OFFERS, to_store)
async def cancel_pending_offer_safely(self, trade_id: bytes32):
""" This will create a transaction that includes coins that were offered"""
self.log.info(f"Secure-Cancel pending offer with id trade_id {trade_id.hex()}")
offers: List[TradeOffer] = await self.get_pending_offers()
to_cancel: Optional[TradeOffer] = None
for offer in offers:
if offer.trade_id == trade_id:
to_cancel = offer
break
if to_cancel is None:
return
all_coins = to_cancel.spend_bundle.additions()
all_coins.extend(to_cancel.spend_bundle.removals())
for coin in all_coins:
wallet = self.wallet_state_manager.get_wallet_for_coin(coin)
if wallet is None:
continue
new_ph = await wallet.get_new_puzzlehash()
tx = wallet.generate_signed_transaction(
coin.amount, new_ph, 0, coins={coin}
)
await self.wallet_state_manager.add_pending_transaction(tx_record=tx)
return
async def add_pending_offer(self, trade_offer: TradeOffer):
pending_offers: List[TradeOffer] = await self.get_pending_offers()
@ -195,7 +235,9 @@ class TradeManager:
return False, None, None
trade_offer = TradeOffer(
created_at_time=uint64(int(time.time())), spend_bundle=spend_bundle
created_at_time=uint64(int(time.time())),
spend_bundle=spend_bundle,
trade_id=spend_bundle.name(),
)
return True, trade_offer, None
except Exception as e:

View File

@ -26,9 +26,6 @@ class TradeRecord(Streamable):
trade_id: bytes32
sent_to: List[Tuple[str, uint8, Optional[str]]]
def name(self) -> bytes32:
return self.spend_bundle.name()
@dataclass(frozen=True)
@streamable
@ -39,6 +36,4 @@ class TradeOffer(Streamable):
created_at_time: uint64
spend_bundle: SpendBundle
def name(self) -> bytes32:
return self.spend_bundle.name()
trade_id: bytes32

View File

@ -342,11 +342,6 @@ class WalletStateManager:
wallet_id
)
# All coins locked by pending orders
offer_locked_coins: Dict[
bytes32, Coin
] = await self.trade_manager.get_locked_coins()
amount: uint64 = uint64(0)
for record in spendable:
@ -362,10 +357,6 @@ class WalletStateManager:
if await self.does_coin_belong_to_wallet(coin, wallet_id):
removal_amount += coin.amount
for name, coin in offer_locked_coins.items():
if await self.does_coin_belong_to_wallet(coin, wallet_id):
removal_amount += coin.amount
result = amount - removal_amount
return uint64(result)
@ -1213,6 +1204,14 @@ class WalletStateManager:
result = await self.puzzle_store.puzzle_hash_exists(addition.puzzle_hash)
return result
async def get_wallet_for_coin(self, coin_id: bytes32) -> Any:
coin_record = await self.wallet_store.get_coin_record(coin_id)
if coin_record is None:
return None
wallet_id = uint32(coin_record.wallet_id)
wallet = self.wallets[wallet_id]
return wallet
async def get_relevant_removals(self, removals: List[Coin]) -> List[Coin]:
""" Returns a list of our unspent coins that are in the passed list. """
@ -1299,7 +1298,21 @@ class WalletStateManager:
valid_index = current_index - coinbase_freeze_period
return await self.wallet_store.get_spendable_for_index(valid_index, wallet_id)
records = await self.wallet_store.get_spendable_for_index(
valid_index, wallet_id
)
offer_locked_coins: Dict[
bytes32, Coin
] = await self.trade_manager.get_locked_coins()
filtered = set()
for record in records:
if record.coin.name() in offer_locked_coins:
continue
filtered.add(record)
return filtered
async def create_action(
self,

View File

@ -2,6 +2,7 @@ import asyncio
import time
from pathlib import Path
from secrets import token_bytes
from typing import Optional
import pytest
@ -9,6 +10,7 @@ from src.simulator.simulator_protocol import FarmNewBlockProtocol
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32, uint64
from src.wallet.trade_manager import TradeManager
from src.wallet.wallet_coin_record import WalletCoinRecord
from tests.setup_nodes import setup_simulators_and_wallets
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
from src.wallet.cc_wallet.cc_wallet import CCWallet
@ -55,6 +57,7 @@ class TestCCTrades:
):
yield _
"""
@pytest.mark.asyncio
async def test_cc_trade(self, two_wallet_nodes):
num_blocks = 10
@ -353,9 +356,10 @@ class TestCCTrades:
await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 30)
await time_out_assert(15, cc_wallet_2.get_confirmed_balance, 30)
"""
@pytest.mark.asyncio
async def test_cc_trade_history(self, two_wallet_nodes):
async def test_cc_trade_all(self, two_wallet_nodes):
num_blocks = 10
full_nodes, wallets = two_wallet_nodes
full_node_1, server_1 = full_nodes[0]
@ -395,18 +399,11 @@ class TestCCTrades:
assert cc_wallet.cc_info.my_core is not None
colour = cc_wallet_puzzles.get_genesis_from_core(cc_wallet.cc_info.my_core)
cc_wallet_2: CCWallet = await CCWallet.create_wallet_for_cc(
wallet_node_2.wallet_state_manager, wallet2, colour
)
assert cc_wallet.cc_info.my_core == cc_wallet_2.cc_info.my_core
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph2))
for i in range(1, num_blocks):
await full_node_1.farm_new_block(FarmNewBlockProtocol(ph))
await full_node_1.farm_new_block(FarmNewBlockProtocol(token_bytes()))
trade_manager_1 = await TradeManager.create(wallet_node.wallet_state_manager)
trade_manager_2 = await TradeManager.create(wallet_node_2.wallet_state_manager)
file = "test_offer_file.offer"
file_path = Path(file)
@ -414,12 +411,22 @@ class TestCCTrades:
if file_path.exists():
file_path.unlink()
spendable = await wallet.get_spendable_balance()
offer_dict = {1: 10, 2: -30}
success, trade_offer, error = await trade_manager_1.create_offer_for_ids(
offer_dict, file
)
# Wallet spendable balance should be reduced by D[10] after creating this offer
locked_coin = await trade_manager_1.get_locked_coins(wallet.wallet_info.id)
locked_sum = 0
for name, coin in locked_coin.items():
locked_sum += coin.amount
spendable_after = await wallet.get_spendable_balance()
assert spendable == spendable_after + locked_sum
# await time_out_assert(15, wallet.get_confirmed_balance, 100)
assert success is True
assert trade_offer is not None
@ -451,20 +458,39 @@ class TestCCTrades:
if file_path_1.exists():
file_path_1.unlink()
spendable_before_offer_1 = await wallet.get_spendable_balance()
offer_dict_1 = {1: 11, 2: -33}
success, trade_offer, error = await trade_manager_1.create_offer_for_ids(
success, trade_offer_1, error = await trade_manager_1.create_offer_for_ids(
offer_dict_1, file_1
)
assert success is True
assert trade_offer is not None
assert trade_offer_1 is not None
assert error is None
spendable_after_offer_1 = await wallet.get_spendable_balance()
removal = trade_offer_1.spend_bundle.removals()
locked_sum = 0
for coin in removal:
record: Optional[
WalletCoinRecord
] = await trade_manager_1.wallet_state_manager.wallet_store.get_coin_record_by_coin_id(
coin.name()
)
if record is None:
continue
if record.wallet_id == wallet.wallet_info.id:
locked_sum += coin.amount
assert spendable_before_offer_1 == spendable_after_offer_1 + locked_sum
success, offer_1, error = await trade_manager_1.get_discrepancies_for_offer(
file_path_1
)
pending_offers = await trade_manager_1.get_pending_offers()
pending_bundle_1 = pending_offers[1].spend_bundle
assert len(pending_offers) == 2
(
@ -473,3 +499,10 @@ class TestCCTrades:
error,
) = await trade_manager_1.get_discrepancies_for_spend_bundle(pending_bundle_1)
assert history_offer_1 == offer_1
# Cancel 2d trade offer by just deleting
await trade_manager_1.cancel_pending_offer(pending_offers[1].trade_id)
spendable_after_cancel_1 = await wallet.get_spendable_balance()
assert spendable_before_offer_1 == spendable_after_cancel_1