Add bulk cancel API (#12427)

* Add bulk cancel API

* Add cancel_all flag

* Add asset ID

* Resolve comments

* Fix pre-commit
This commit is contained in:
Kronus91 2022-08-04 14:31:27 -07:00 committed by GitHub
parent fad414132e
commit fba150627c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 264 additions and 2 deletions

View File

@ -122,6 +122,7 @@ class WalletRpcApi:
"/get_all_offers": self.get_all_offers,
"/get_offers_count": self.get_offers_count,
"/cancel_offer": self.cancel_offer,
"/cancel_offers": self.cancel_offers,
"/get_cat_list": self.get_cat_list,
# DID Wallet
"/did_set_wallet_name": self.did_set_wallet_name,
@ -1128,7 +1129,6 @@ class WalletRpcApi:
secure = request["secure"]
trade_id = bytes32.from_hexstr(request["trade_id"])
fee: uint64 = uint64(request.get("fee", 0))
async with self.service.wallet_state_manager.lock:
if secure:
await wsm.trade_manager.cancel_pending_offer_safely(bytes32(trade_id), fee=fee)
@ -1136,6 +1136,54 @@ class WalletRpcApi:
await wsm.trade_manager.cancel_pending_offer(bytes32(trade_id))
return {}
async def cancel_offers(self, request: Dict) -> EndpointResult:
secure = request["secure"]
batch_fee: uint64 = uint64(request.get("batch_fee", 0))
batch_size = request.get("batch_size", 5)
cancel_all = request.get("cancel_all", False)
if cancel_all:
asset_id = None
else:
asset_id = request.get("asset_id", "xch")
start: int = 0
end: int = start + batch_size
trade_mgr = self.service.wallet_state_manager.trade_manager
log.info(f"Start cancelling offers for {'asset_id: '+asset_id if asset_id is not None else 'all'} ...")
# Traverse offers page by page
key = None
if asset_id is not None and asset_id != "xch":
key = bytes32.from_hexstr(asset_id)
while True:
records: List[TradeRecord] = []
trades = await trade_mgr.trade_store.get_trades_between(
start,
end,
reverse=True,
exclude_my_offers=False,
exclude_taken_offers=True,
include_completed=False,
)
for trade in trades:
if cancel_all:
records.append(trade)
continue
if trade.offer and trade.offer != b"":
offer = Offer.from_bytes(trade.offer)
if key in offer.driver_dict:
records.append(trade)
continue
async with self.service.wallet_state_manager.lock:
await trade_mgr.cancel_pending_offers(records, batch_fee, secure)
log.info(f"Cancelled offers {start} to {end} ...")
# If fewer records were returned than requested, we're done
if len(trades) < batch_size:
break
start = end
end += batch_size
return {"success": True}
##########################################################################################
# Distributed Identities
##########################################################################################

View File

@ -289,6 +289,99 @@ class TradeManager:
return all_txs
async def cancel_pending_offers(
self, trades: List[TradeRecord], fee: uint64 = uint64(0), secure: bool = True
) -> Optional[List[TransactionRecord]]:
"""This will create a transaction that includes coins that were offered"""
all_txs: List[TransactionRecord] = []
bundles: List[SpendBundle] = []
fee_to_pay: uint64 = fee
for trade in trades:
if trade is None:
self.log.error("Cannot find offer, skip cancellation.")
continue
for coin in Offer.from_bytes(trade.offer).get_primary_coins():
wallet = await self.wallet_state_manager.get_wallet_for_coin(coin.name())
if wallet is None:
self.log.error(f"Cannot find wallet for offer {trade.trade_id}, skip cancellation.")
continue
if wallet.type() == WalletType.NFT:
new_ph = await wallet.wallet_state_manager.main_wallet.get_new_puzzlehash()
else:
new_ph = await wallet.get_new_puzzlehash()
# This should probably not switch on whether or not we're spending a XCH but it has to for now
if wallet.type() == WalletType.STANDARD_WALLET:
if fee_to_pay > coin.amount:
selected_coins: Set[Coin] = await wallet.select_coins(
uint64(fee_to_pay - coin.amount),
exclude=[coin],
)
selected_coins.add(coin)
else:
selected_coins = {coin}
tx: TransactionRecord = await wallet.generate_signed_transaction(
uint64(sum([c.amount for c in selected_coins]) - fee_to_pay),
new_ph,
fee=fee_to_pay,
coins=selected_coins,
ignore_max_send_amount=True,
)
if tx is not None and tx.spend_bundle is not None:
bundles.append(tx.spend_bundle)
all_txs.append(dataclasses.replace(tx, spend_bundle=None))
else:
# ATTENTION: new_wallets
txs = await wallet.generate_signed_transaction(
[coin.amount], [new_ph], fee=fee_to_pay, coins={coin}, ignore_max_send_amount=True
)
for tx in txs:
if tx is not None and tx.spend_bundle is not None:
bundles.append(tx.spend_bundle)
all_txs.append(dataclasses.replace(tx, spend_bundle=None))
fee_to_pay = uint64(0)
cancellation_addition = Coin(coin.name(), new_ph, coin.amount)
all_txs.append(
TransactionRecord(
confirmed_at_height=uint32(0),
created_at_time=uint64(int(time.time())),
to_puzzle_hash=new_ph,
amount=uint64(coin.amount),
fee_amount=fee,
confirmed=False,
sent=uint32(10),
spend_bundle=None,
additions=[cancellation_addition],
removals=[coin],
wallet_id=wallet.id(),
sent_to=[],
trade_id=None,
type=uint32(TransactionType.INCOMING_TX.value),
name=cancellation_addition.name(),
memos=[],
)
)
# Aggregate spend bundles to the first tx
if len(all_txs) > 0:
all_txs[0] = dataclasses.replace(all_txs[0], spend_bundle=SpendBundle.aggregate(bundles))
if secure:
for tx in all_txs:
await self.wallet_state_manager.add_pending_transaction(
tx_record=dataclasses.replace(tx, fee_amount=fee)
)
else:
self.wallet_state_manager.state_changed("offer_cancelled")
for trade in trades:
if secure:
await self.trade_store.set_status(trade.trade_id, TradeStatus.PENDING_CANCEL)
else:
await self.trade_store.set_status(trade.trade_id, TradeStatus.CANCELLED)
return all_txs
async def save_trade(self, trade: TradeRecord):
await self.trade_store.add_trade_record(trade)
self.wallet_state_manager.state_changed("offer_added")

View File

@ -1036,5 +1036,126 @@ async def test_nft_offer_sell_cancel(two_wallet_nodes: Any, trusted: Any) -> Non
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await time_out_assert(20, wallets_are_synced, True, [wallet_node_maker], full_node_api)
await time_out_assert(15, get_trade_and_status, TradeStatus.CANCELLED, trade_manager_maker, trade_make)
await time_out_assert(20, get_trade_and_status, TradeStatus.CANCELLED, trade_manager_maker, trade_make)
@pytest.mark.parametrize(
"trusted",
[True],
)
@pytest.mark.asyncio
async def test_nft_offer_sell_cancel_in_batch(two_wallet_nodes: Any, trusted: Any) -> None:
num_blocks = 3
full_nodes, wallets, _ = two_wallet_nodes
full_node_api: FullNodeSimulator = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_maker, server_0 = wallets[0]
wallet_maker = wallet_node_maker.wallet_state_manager.main_wallet
ph_maker = await wallet_maker.get_new_puzzlehash()
ph_token = bytes32(token_bytes())
if trusted:
wallet_node_maker.config["trusted_peers"] = {
full_node_api.full_node.server.node_id.hex(): full_node_api.full_node.server.node_id.hex()
}
else:
wallet_node_maker.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
for _ in range(1, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph_maker))
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph_token))
funds = sum(
[calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) for i in range(1, num_blocks)]
)
await time_out_assert(10, wallet_maker.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_maker.get_confirmed_balance, funds)
did_wallet_maker: DIDWallet = await DIDWallet.create_new_did_wallet(
wallet_node_maker.wallet_state_manager, wallet_maker, uint64(1)
)
spend_bundle_list = await wallet_node_maker.wallet_state_manager.tx_store.get_unconfirmed_for_wallet(
did_wallet_maker.id()
)
spend_bundle = spend_bundle_list[0].spend_bundle
await time_out_assert_not_none(5, full_node_api.full_node.mempool_manager.get_spendbundle, spend_bundle.name())
for _ in range(1, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph_token))
await time_out_assert(15, wallet_maker.get_pending_change_balance, 0)
await time_out_assert(10, wallet_maker.get_unconfirmed_balance, funds - 1)
await time_out_assert(10, wallet_maker.get_confirmed_balance, funds - 1)
hex_did_id = did_wallet_maker.get_my_DID()
did_id = bytes32.fromhex(hex_did_id)
target_puzhash = ph_maker
royalty_puzhash = ph_maker
royalty_basis_pts = uint16(200)
nft_wallet_maker = await NFTWallet.create_new_nft_wallet(
wallet_node_maker.wallet_state_manager, wallet_maker, name="NFT WALLET DID 1", did_id=did_id
)
metadata = Program.to(
[
("u", ["https://www.chia.net/img/branding/chia-logo.svg"]),
("h", "0xD4584AD463139FA8C0D9F68F4B59F185"),
]
)
sb = await nft_wallet_maker.generate_new_nft(
metadata,
target_puzhash,
royalty_puzhash,
royalty_basis_pts,
did_id,
)
assert sb
# ensure hints are generated
assert compute_memos(sb)
await time_out_assert_not_none(5, full_node_api.full_node.mempool_manager.get_spendbundle, sb.name())
for i in range(1, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph_token))
await time_out_assert(10, len, 1, nft_wallet_maker.my_nft_coins)
# maker create offer: NFT for xch
trade_manager_maker = wallet_maker.wallet_state_manager.trade_manager
coins_maker = nft_wallet_maker.my_nft_coins
assert len(coins_maker) == 1
nft_to_offer = coins_maker[0]
nft_to_offer_info: Optional[PuzzleInfo] = match_puzzle(nft_to_offer.full_puzzle)
nft_to_offer_asset_id: bytes32 = create_asset_id(nft_to_offer_info) # type: ignore
xch_requested = 1000
maker_fee = uint64(433)
offer_did_nft_for_xch = {nft_to_offer_asset_id: -1, wallet_maker.id(): xch_requested}
success, trade_make, error = await trade_manager_maker.create_offer_for_ids(
offer_did_nft_for_xch, {}, fee=maker_fee
)
FEE = uint64(2000000000000)
txs = await trade_manager_maker.cancel_pending_offers([trade_make], fee=FEE, secure=True)
async def get_trade_and_status(trade_manager: Any, trade: Any) -> TradeStatus:
trade_rec = await trade_manager.get_trade_by_id(trade.trade_id)
return TradeStatus(trade_rec.status)
await time_out_assert(15, get_trade_and_status, TradeStatus.PENDING_CANCEL, trade_manager_maker, trade_make)
for tx in txs:
if tx.spend_bundle is not None:
await time_out_assert(15, tx_in_pool, True, full_node_api.full_node.mempool_manager, tx.spend_bundle.name())
for i in range(1, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await time_out_assert(15, get_trade_and_status, TradeStatus.CANCELLED, trade_manager_maker, trade_make)