send to three different peers

This commit is contained in:
Yostra 2020-03-27 00:56:47 -07:00
parent 433b765aff
commit a59f35b142
7 changed files with 159 additions and 23 deletions

View File

@ -497,11 +497,18 @@ class ChiaServer:
elif full_message.function == "pong":
return
f = getattr(api, full_message.function)
if f is None:
raise InvalidProtocolMessage(full_message.function)
f_with_peer_name = getattr(api, full_message.function + "_with_peer_name", None)
if f_with_peer_name is not None:
result = f_with_peer_name(full_message.data, connection.get_peername())
else:
f = getattr(api, full_message.function, None)
if f is None:
raise InvalidProtocolMessage(full_message.function)
result = f(full_message.data)
result = f(full_message.data)
if isinstance(result, AsyncGenerator):
async for outbound_message in result:
yield connection, outbound_message

View File

@ -22,11 +22,12 @@ class TransactionRecord(Streamable):
fee_amount: uint64
incoming: bool
confirmed: bool
sent: bool
sent: uint32
spend_bundle: Optional[SpendBundle]
additions: List[Coin]
removals: List[Coin]
wallet_id: uint64
sent_to: List[str]
def name(self) -> bytes32:
if self.spend_bundle:

View File

@ -38,8 +38,6 @@ class Wallet:
log: logging.Logger
# TODO Don't allow user to send tx until wallet is synced
synced: bool
wallet_info: WalletInfo
@staticmethod
@ -50,7 +48,6 @@ class Wallet:
info: WalletInfo,
name: str = None,
):
# TODO(straya): consider loading farmer keys as well
self = Wallet()
self.config = config
self.key_config = key_config

View File

@ -137,6 +137,13 @@ class WalletNode:
if self.server is None:
return
messages = await self.messages_to_resend()
for msg in messages:
self.server.push_message(msg)
async def messages_to_resend(self) -> List[OutboundMessage]:
messages: List[OutboundMessage] = []
records: List[
TransactionRecord
] = await self.wallet_state_manager.tx_store.get_not_sent()
@ -152,11 +159,19 @@ class WalletNode:
),
Delivery.BROADCAST,
)
self.server.push_message(msg)
messages.append(msg)
return messages
def set_server(self, server: ChiaServer):
self.server = server
async def _on_connect(self) -> OutboundMessageGenerator:
messages = await self.messages_to_resend()
for msg in messages:
yield msg
def _shutdown(self):
print("Shutting down")
self._shut_down = True
@ -220,7 +235,7 @@ class WalletNode:
tasks = []
for peer in to_connect:
tasks.append(
asyncio.create_task(self.server.start_client(peer, None, self.config))
asyncio.create_task(self.server.start_client(peer, self._on_connect(), self.config))
)
await asyncio.gather(*tasks)
@ -587,9 +602,9 @@ class WalletNode:
return
@api_request
async def transaction_ack(self, ack: wallet_protocol.TransactionAck):
async def transaction_ack_with_peer_name(self, ack: wallet_protocol.TransactionAck, name: str):
if ack.status:
await self.wallet_state_manager.remove_from_queue(ack.txid)
await self.wallet_state_manager.remove_from_queue(ack.txid, name)
self.log.info(f"SpendBundle has been received by the FullNode. id: {id}")
else:
self.log.info(f"SpendBundle has been rejected by the FullNode. id: {id}")

View File

@ -321,11 +321,12 @@ class WalletStateManager:
fee_amount=uint64(0),
incoming=True,
confirmed=True,
sent=True,
sent=uint32(0),
spend_bundle=None,
additions=[coin],
removals=[],
wallet_id=wallet_id,
sent_to=[],
)
await self.tx_store.add_transaction_record(tx_record)
else:
@ -346,11 +347,12 @@ class WalletStateManager:
fee_amount=uint64(0),
incoming=True,
confirmed=True,
sent=True,
sent=uint32(0),
spend_bundle=None,
additions=[coin],
removals=[],
wallet_id=wallet_id,
sent_to=[],
)
await self.tx_store.add_transaction_record(tx_record)
@ -401,22 +403,23 @@ class WalletStateManager:
fee_amount=uint64(fee_amount),
incoming=False,
confirmed=False,
sent=False,
sent=uint32(0),
spend_bundle=spend_bundle,
additions=add_list,
removals=rem_list,
wallet_id=wallet_id,
sent_to=[],
)
# Wallet node will use this queue to retry sending this transaction until full nodes receives it
await self.tx_store.add_transaction_record(tx_record)
self.state_changed("pending_transaction")
self.tx_pending_changed()
async def remove_from_queue(self, spendbundle_id: bytes32):
async def remove_from_queue(self, spendbundle_id: bytes32, name: str):
"""
Full node received our transaction, no need to keep it in queue anymore
"""
await self.tx_store.set_sent(spendbundle_id, True)
await self.tx_store.increment_sent(spendbundle_id, name)
self.state_changed("tx_sent")
async def get_send_queue(self) -> List[TransactionRecord]:
@ -1027,7 +1030,7 @@ class WalletStateManager:
return
for record in records:
await self.tx_store.set_sent(record.name(), False)
await self.tx_store.set_not_sent(record.name())
self.tx_pending_changed()

View File

@ -140,6 +140,7 @@ class WalletTransactionStore:
additions=current.additions,
removals=current.removals,
wallet_id=current.wallet_id,
sent_to=current.sent_to,
)
await self.add_transaction_record(tx)
@ -169,9 +170,43 @@ class WalletTransactionStore:
return None
async def set_sent(self, id: bytes32, sent: bool):
async def increment_sent(self, id: bytes32, name: str):
"""
Updates transaction to be sent. (Full Node has received spend_bundle and sent ack).
Updates transaction sent count (Full Node has received spend_bundle and sent ack).
"""
current: Optional[TransactionRecord] = await self.get_transaction_record(id)
if current is None:
return
# Don't increment count if it's already sent to othis peer
if name in current.sent_to:
return
sent_to = current.sent_to.copy()
sent_to.append(name)
tx: TransactionRecord = TransactionRecord(
confirmed_at_index=current.confirmed_at_index,
created_at_time=current.created_at_time,
to_puzzle_hash=current.to_puzzle_hash,
amount=current.amount,
fee_amount=current.fee_amount,
incoming=current.incoming,
confirmed=current.confirmed,
sent=uint32(current.sent + 1),
spend_bundle=current.spend_bundle,
additions=current.additions,
removals=current.removals,
wallet_id=current.wallet_id,
sent_to=sent_to,
)
await self.add_transaction_record(tx)
async def set_not_sent(self, id: bytes32):
"""
Updates transaction sent count to 0.
"""
current: Optional[TransactionRecord] = await self.get_transaction_record(id)
@ -185,11 +220,12 @@ class WalletTransactionStore:
fee_amount=current.fee_amount,
incoming=current.incoming,
confirmed=current.confirmed,
sent=sent,
sent=uint32(0),
spend_bundle=current.spend_bundle,
additions=current.additions,
removals=current.removals,
wallet_id=current.wallet_id,
sent_to=[],
)
await self.add_transaction_record(tx)
@ -216,7 +252,7 @@ class WalletTransactionStore:
"""
cursor = await self.db_connection.execute(
"SELECT * from transaction_record WHERE sent=?", (0,)
"SELECT * from transaction_record WHERE sent<? and confirmed=?", (4, 0,)
)
rows = await cursor.fetchall()
await cursor.close()

View File

@ -3,13 +3,14 @@ from secrets import token_bytes
import pytest
from src.protocols import full_node_protocol
from src.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
from src.types.peer_info import PeerInfo
from src.util.ints import uint16, uint32
from tests.setup_nodes import (
setup_node_simulator_and_two_wallets,
setup_node_simulator_and_wallet,
)
setup_three_simulators_and_two_wallets)
from src.consensus.block_rewards import calculate_base_fee, calculate_block_reward
@ -32,6 +33,13 @@ class TestWalletSimulator:
):
yield _
@pytest.fixture(scope="function")
async def three_sim_two_wallets(self):
async for _ in setup_three_simulators_and_two_wallets(
{"COINBASE_FREEZE_PERIOD": 0}
):
yield _
@pytest.mark.asyncio
async def test_wallet_coinbase(self, wallet_node):
num_blocks = 10
@ -152,3 +160,72 @@ class TestWalletSimulator:
)
assert await wallet.get_confirmed_balance() == funds
@pytest.mark.asyncio
async def test_wallet_send_to_three_peers(self, three_sim_two_wallets):
num_blocks = 10
full_nodes, wallets = three_sim_two_wallets
wallet_0, wallet_server_0 = wallets[0]
full_node_0, server_0 = full_nodes[0]
full_node_1, server_1 = full_nodes[1]
full_node_2, server_2 = full_nodes[2]
ph = await wallet_0.main_wallet.get_new_puzzlehash()
# wallet0 <-> sever0
await wallet_server_0.start_client(
PeerInfo(server_0._host, uint16(server_0._port)), None
)
for i in range(1, num_blocks):
await full_node_0.farm_new_block(FarmNewBlockProtocol(ph))
all_blocks = await full_node_0.get_current_blocks(full_node_0.get_tip())
for block in all_blocks:
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(block)
):
pass
async for _ in full_node_2.respond_block(
full_node_protocol.RespondBlock(block)
):
pass
await asyncio.sleep(2)
funds = sum(
[
calculate_base_fee(uint32(i)) + calculate_block_reward(uint32(i))
for i in range(1, num_blocks - 2)
]
)
assert await wallet_0.main_wallet.get_confirmed_balance() == funds
spend_bundle = await wallet_0.main_wallet.generate_signed_transaction(
10, token_bytes(), 0
)
await wallet_0.main_wallet.push_transaction(spend_bundle)
await asyncio.sleep(1)
bundle0 = full_node_0.mempool_manager.get_spendbundle(spend_bundle.name())
assert bundle0 is not None
# wallet0 <-> sever1
await wallet_server_0.start_client(
PeerInfo(server_1._host, uint16(server_1._port)), wallet_0._on_connect
)
await asyncio.sleep(1)
bundle1 = full_node_1.mempool_manager.get_spendbundle(spend_bundle.name())
assert bundle1 is not None
# wallet0 <-> sever2
await wallet_server_0.start_client(
PeerInfo(server_2._host, uint16(server_2._port)), wallet_0._on_connect
)
await asyncio.sleep(1)
bundle2 = full_node_2.mempool_manager.get_spendbundle(spend_bundle.name())
assert bundle2 is not None