mirror of
https://github.com/Chia-Network/chia-blockchain.git
synced 2024-11-10 12:29:49 +03:00
send to three different peers
This commit is contained in:
parent
433b765aff
commit
a59f35b142
@ -497,11 +497,18 @@ class ChiaServer:
|
||||
elif full_message.function == "pong":
|
||||
return
|
||||
|
||||
f = getattr(api, full_message.function)
|
||||
if f is None:
|
||||
raise InvalidProtocolMessage(full_message.function)
|
||||
f_with_peer_name = getattr(api, full_message.function + "_with_peer_name", None)
|
||||
|
||||
if f_with_peer_name is not None:
|
||||
result = f_with_peer_name(full_message.data, connection.get_peername())
|
||||
else:
|
||||
f = getattr(api, full_message.function, None)
|
||||
|
||||
if f is None:
|
||||
raise InvalidProtocolMessage(full_message.function)
|
||||
|
||||
result = f(full_message.data)
|
||||
|
||||
result = f(full_message.data)
|
||||
if isinstance(result, AsyncGenerator):
|
||||
async for outbound_message in result:
|
||||
yield connection, outbound_message
|
||||
|
@ -22,11 +22,12 @@ class TransactionRecord(Streamable):
|
||||
fee_amount: uint64
|
||||
incoming: bool
|
||||
confirmed: bool
|
||||
sent: bool
|
||||
sent: uint32
|
||||
spend_bundle: Optional[SpendBundle]
|
||||
additions: List[Coin]
|
||||
removals: List[Coin]
|
||||
wallet_id: uint64
|
||||
sent_to: List[str]
|
||||
|
||||
def name(self) -> bytes32:
|
||||
if self.spend_bundle:
|
||||
|
@ -38,8 +38,6 @@ class Wallet:
|
||||
|
||||
log: logging.Logger
|
||||
|
||||
# TODO Don't allow user to send tx until wallet is synced
|
||||
synced: bool
|
||||
wallet_info: WalletInfo
|
||||
|
||||
@staticmethod
|
||||
@ -50,7 +48,6 @@ class Wallet:
|
||||
info: WalletInfo,
|
||||
name: str = None,
|
||||
):
|
||||
# TODO(straya): consider loading farmer keys as well
|
||||
self = Wallet()
|
||||
self.config = config
|
||||
self.key_config = key_config
|
||||
|
@ -137,6 +137,13 @@ class WalletNode:
|
||||
if self.server is None:
|
||||
return
|
||||
|
||||
messages = await self.messages_to_resend()
|
||||
for msg in messages:
|
||||
self.server.push_message(msg)
|
||||
|
||||
async def messages_to_resend(self) -> List[OutboundMessage]:
|
||||
messages: List[OutboundMessage] = []
|
||||
|
||||
records: List[
|
||||
TransactionRecord
|
||||
] = await self.wallet_state_manager.tx_store.get_not_sent()
|
||||
@ -152,11 +159,19 @@ class WalletNode:
|
||||
),
|
||||
Delivery.BROADCAST,
|
||||
)
|
||||
self.server.push_message(msg)
|
||||
messages.append(msg)
|
||||
|
||||
return messages
|
||||
|
||||
def set_server(self, server: ChiaServer):
|
||||
self.server = server
|
||||
|
||||
async def _on_connect(self) -> OutboundMessageGenerator:
|
||||
messages = await self.messages_to_resend()
|
||||
|
||||
for msg in messages:
|
||||
yield msg
|
||||
|
||||
def _shutdown(self):
|
||||
print("Shutting down")
|
||||
self._shut_down = True
|
||||
@ -220,7 +235,7 @@ class WalletNode:
|
||||
tasks = []
|
||||
for peer in to_connect:
|
||||
tasks.append(
|
||||
asyncio.create_task(self.server.start_client(peer, None, self.config))
|
||||
asyncio.create_task(self.server.start_client(peer, self._on_connect(), self.config))
|
||||
)
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
@ -587,9 +602,9 @@ class WalletNode:
|
||||
return
|
||||
|
||||
@api_request
|
||||
async def transaction_ack(self, ack: wallet_protocol.TransactionAck):
|
||||
async def transaction_ack_with_peer_name(self, ack: wallet_protocol.TransactionAck, name: str):
|
||||
if ack.status:
|
||||
await self.wallet_state_manager.remove_from_queue(ack.txid)
|
||||
await self.wallet_state_manager.remove_from_queue(ack.txid, name)
|
||||
self.log.info(f"SpendBundle has been received by the FullNode. id: {id}")
|
||||
else:
|
||||
self.log.info(f"SpendBundle has been rejected by the FullNode. id: {id}")
|
||||
|
@ -321,11 +321,12 @@ class WalletStateManager:
|
||||
fee_amount=uint64(0),
|
||||
incoming=True,
|
||||
confirmed=True,
|
||||
sent=True,
|
||||
sent=uint32(0),
|
||||
spend_bundle=None,
|
||||
additions=[coin],
|
||||
removals=[],
|
||||
wallet_id=wallet_id,
|
||||
sent_to=[],
|
||||
)
|
||||
await self.tx_store.add_transaction_record(tx_record)
|
||||
else:
|
||||
@ -346,11 +347,12 @@ class WalletStateManager:
|
||||
fee_amount=uint64(0),
|
||||
incoming=True,
|
||||
confirmed=True,
|
||||
sent=True,
|
||||
sent=uint32(0),
|
||||
spend_bundle=None,
|
||||
additions=[coin],
|
||||
removals=[],
|
||||
wallet_id=wallet_id,
|
||||
sent_to=[],
|
||||
)
|
||||
await self.tx_store.add_transaction_record(tx_record)
|
||||
|
||||
@ -401,22 +403,23 @@ class WalletStateManager:
|
||||
fee_amount=uint64(fee_amount),
|
||||
incoming=False,
|
||||
confirmed=False,
|
||||
sent=False,
|
||||
sent=uint32(0),
|
||||
spend_bundle=spend_bundle,
|
||||
additions=add_list,
|
||||
removals=rem_list,
|
||||
wallet_id=wallet_id,
|
||||
sent_to=[],
|
||||
)
|
||||
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
|
||||
await self.tx_store.add_transaction_record(tx_record)
|
||||
self.state_changed("pending_transaction")
|
||||
self.tx_pending_changed()
|
||||
|
||||
async def remove_from_queue(self, spendbundle_id: bytes32):
|
||||
async def remove_from_queue(self, spendbundle_id: bytes32, name: str):
|
||||
"""
|
||||
Full node received our transaction, no need to keep it in queue anymore
|
||||
"""
|
||||
await self.tx_store.set_sent(spendbundle_id, True)
|
||||
await self.tx_store.increment_sent(spendbundle_id, name)
|
||||
self.state_changed("tx_sent")
|
||||
|
||||
async def get_send_queue(self) -> List[TransactionRecord]:
|
||||
@ -1027,7 +1030,7 @@ class WalletStateManager:
|
||||
return
|
||||
|
||||
for record in records:
|
||||
await self.tx_store.set_sent(record.name(), False)
|
||||
await self.tx_store.set_not_sent(record.name())
|
||||
|
||||
self.tx_pending_changed()
|
||||
|
||||
|
@ -140,6 +140,7 @@ class WalletTransactionStore:
|
||||
additions=current.additions,
|
||||
removals=current.removals,
|
||||
wallet_id=current.wallet_id,
|
||||
sent_to=current.sent_to,
|
||||
)
|
||||
await self.add_transaction_record(tx)
|
||||
|
||||
@ -169,9 +170,43 @@ class WalletTransactionStore:
|
||||
|
||||
return None
|
||||
|
||||
async def set_sent(self, id: bytes32, sent: bool):
|
||||
async def increment_sent(self, id: bytes32, name: str):
|
||||
"""
|
||||
Updates transaction to be sent. (Full Node has received spend_bundle and sent ack).
|
||||
Updates transaction sent count (Full Node has received spend_bundle and sent ack).
|
||||
"""
|
||||
|
||||
current: Optional[TransactionRecord] = await self.get_transaction_record(id)
|
||||
if current is None:
|
||||
return
|
||||
|
||||
# Don't increment count if it's already sent to othis peer
|
||||
if name in current.sent_to:
|
||||
return
|
||||
|
||||
sent_to = current.sent_to.copy()
|
||||
sent_to.append(name)
|
||||
|
||||
tx: TransactionRecord = TransactionRecord(
|
||||
confirmed_at_index=current.confirmed_at_index,
|
||||
created_at_time=current.created_at_time,
|
||||
to_puzzle_hash=current.to_puzzle_hash,
|
||||
amount=current.amount,
|
||||
fee_amount=current.fee_amount,
|
||||
incoming=current.incoming,
|
||||
confirmed=current.confirmed,
|
||||
sent=uint32(current.sent + 1),
|
||||
spend_bundle=current.spend_bundle,
|
||||
additions=current.additions,
|
||||
removals=current.removals,
|
||||
wallet_id=current.wallet_id,
|
||||
sent_to=sent_to,
|
||||
)
|
||||
|
||||
await self.add_transaction_record(tx)
|
||||
|
||||
async def set_not_sent(self, id: bytes32):
|
||||
"""
|
||||
Updates transaction sent count to 0.
|
||||
"""
|
||||
|
||||
current: Optional[TransactionRecord] = await self.get_transaction_record(id)
|
||||
@ -185,11 +220,12 @@ class WalletTransactionStore:
|
||||
fee_amount=current.fee_amount,
|
||||
incoming=current.incoming,
|
||||
confirmed=current.confirmed,
|
||||
sent=sent,
|
||||
sent=uint32(0),
|
||||
spend_bundle=current.spend_bundle,
|
||||
additions=current.additions,
|
||||
removals=current.removals,
|
||||
wallet_id=current.wallet_id,
|
||||
sent_to=[],
|
||||
)
|
||||
await self.add_transaction_record(tx)
|
||||
|
||||
@ -216,7 +252,7 @@ class WalletTransactionStore:
|
||||
"""
|
||||
|
||||
cursor = await self.db_connection.execute(
|
||||
"SELECT * from transaction_record WHERE sent=?", (0,)
|
||||
"SELECT * from transaction_record WHERE sent<? and confirmed=?", (4, 0,)
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
await cursor.close()
|
||||
|
@ -3,13 +3,14 @@ from secrets import token_bytes
|
||||
|
||||
import pytest
|
||||
|
||||
from src.protocols import full_node_protocol
|
||||
from src.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
|
||||
from src.types.peer_info import PeerInfo
|
||||
from src.util.ints import uint16, uint32
|
||||
from tests.setup_nodes import (
|
||||
setup_node_simulator_and_two_wallets,
|
||||
setup_node_simulator_and_wallet,
|
||||
)
|
||||
setup_three_simulators_and_two_wallets)
|
||||
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
|
||||
|
||||
|
||||
@ -32,6 +33,13 @@ class TestWalletSimulator:
|
||||
):
|
||||
yield _
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def three_sim_two_wallets(self):
|
||||
async for _ in setup_three_simulators_and_two_wallets(
|
||||
{"COINBASE_FREEZE_PERIOD": 0}
|
||||
):
|
||||
yield _
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_coinbase(self, wallet_node):
|
||||
num_blocks = 10
|
||||
@ -152,3 +160,72 @@ class TestWalletSimulator:
|
||||
)
|
||||
|
||||
assert await wallet.get_confirmed_balance() == funds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_wallet_send_to_three_peers(self, three_sim_two_wallets):
|
||||
num_blocks = 10
|
||||
full_nodes, wallets = three_sim_two_wallets
|
||||
|
||||
wallet_0, wallet_server_0 = wallets[0]
|
||||
full_node_0, server_0 = full_nodes[0]
|
||||
full_node_1, server_1 = full_nodes[1]
|
||||
full_node_2, server_2 = full_nodes[2]
|
||||
|
||||
ph = await wallet_0.main_wallet.get_new_puzzlehash()
|
||||
|
||||
# wallet0 <-> sever0
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_0._host, uint16(server_0._port)), None
|
||||
)
|
||||
|
||||
for i in range(1, num_blocks):
|
||||
await full_node_0.farm_new_block(FarmNewBlockProtocol(ph))
|
||||
|
||||
all_blocks = await full_node_0.get_current_blocks(full_node_0.get_tip())
|
||||
|
||||
for block in all_blocks:
|
||||
async for _ in full_node_1.respond_block(
|
||||
full_node_protocol.RespondBlock(block)
|
||||
):
|
||||
pass
|
||||
async for _ in full_node_2.respond_block(
|
||||
full_node_protocol.RespondBlock(block)
|
||||
):
|
||||
pass
|
||||
|
||||
await asyncio.sleep(2)
|
||||
funds = sum(
|
||||
[
|
||||
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
|
||||
for i in range(1, num_blocks - 2)
|
||||
]
|
||||
)
|
||||
assert await wallet_0.main_wallet.get_confirmed_balance() == funds
|
||||
|
||||
spend_bundle = await wallet_0.main_wallet.generate_signed_transaction(
|
||||
10, token_bytes(), 0
|
||||
)
|
||||
await wallet_0.main_wallet.push_transaction(spend_bundle)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
bundle0 = full_node_0.mempool_manager.get_spendbundle(spend_bundle.name())
|
||||
assert bundle0 is not None
|
||||
|
||||
# wallet0 <-> sever1
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_1._host, uint16(server_1._port)), wallet_0._on_connect
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
bundle1 = full_node_1.mempool_manager.get_spendbundle(spend_bundle.name())
|
||||
assert bundle1 is not None
|
||||
|
||||
# wallet0 <-> sever2
|
||||
await wallet_server_0.start_client(
|
||||
PeerInfo(server_2._host, uint16(server_2._port)), wallet_0._on_connect
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
bundle2 = full_node_2.mempool_manager.get_spendbundle(spend_bundle.name())
|
||||
assert bundle2 is not None
|
||||
|
Loading…
Reference in New Issue
Block a user