Ms.networking2 (#284)

* Improve test speed with smaller discriminants, less blocks, less keys, smaller plots
* Add new RPC files
* Refactor RPC servers and clients
* Removed websocket server
* Fixing websocket issues
* Fix more bugs
* Migration
* Try to fix introducer memory leak
* More logging
* Start client instead of open connection
* No drain
* remove testing deps
* Support timeout
* Fix python black
* Richard fixes
* Don't always auth, change testing code, fix synced display
* Don't keep connections alive introducer
* Fix more LGTM alerts
* Fix wrong import clvm_tools
* Fix spelling mistakes
* Setup nodes fully using Service code
* Log rotation and fix test
This commit is contained in:
Mariano Sorgente 2020-06-17 08:46:51 +09:00 committed by Gene Hoffman
parent 5d582d58ab
commit 35822c8796
79 changed files with 1737 additions and 2156 deletions

View File

@ -9,19 +9,13 @@ const electron = require("electron");
const app = electron.app;
const BrowserWindow = electron.BrowserWindow;
const path = require("path");
const WebSocket = require("ws");
const ipcMain = require("electron").ipcMain;
const config = require("./config");
const dev_config = require("./dev_config");
const local_test = config.local_test;
const redux_tool = dev_config.redux_tool;
var url = require("url");
const Tail = require("tail").Tail;
const os = require("os");
// Only takes effect if local_test is false. Connects to a local introducer.
var local_introducer = false;
global.sharedObj = { local_test: local_test };
/*************************************************************
@ -37,6 +31,7 @@ const PY_MODULE = "server"; // without .py suffix
let pyProc = null;
const guessPackaged = () => {
let packed;
if (process.platform === "win32") {
const fullPath = path.join(__dirname, PY_WIN_DIST_FOLDER);
packed = require("fs").existsSync(fullPath);
@ -63,7 +58,7 @@ const getScriptPath = () => {
const createPyProc = () => {
let script = getScriptPath();
processOptions = {};
let processOptions = {};
//processOptions.detached = true;
//processOptions.stdio = "ignore";
pyProc = null;
@ -111,8 +106,8 @@ const exitPyProc = () => {
if (pyProc != null) {
if (process.platform === "win32") {
process.stdout.write("Killing daemon on windows");
var cp = require('child_process');
cp.execSync('taskkill /PID ' + pyProc.pid + ' /T /F')
var cp = require("child_process");
cp.execSync("taskkill /PID " + pyProc.pid + " /T /F");
} else {
process.stdout.write("Killing daemon on other platforms");
pyProc.kill();

View File

@ -130,10 +130,29 @@ async function track_progress(store, location) {
}
}
function refreshAllState(store) {
store.dispatch(format_message("get_wallets", {}));
let start_farmer = startService(service_farmer);
let start_harvester = startService(service_harvester);
store.dispatch(start_farmer);
store.dispatch(start_harvester);
store.dispatch(get_height_info());
store.dispatch(get_sync_status());
store.dispatch(get_connection_info());
store.dispatch(getBlockChainState());
store.dispatch(getLatestBlocks());
store.dispatch(getFullNodeConnections());
store.dispatch(getLatestChallenges());
store.dispatch(getFarmerConnections());
store.dispatch(getPlots());
store.dispatch(isServiceRunning(service_plotter));
}
export const handle_message = (store, payload) => {
store.dispatch(incomingMessage(payload));
if (payload.command === "ping") {
if (payload.origin === service_wallet_server) {
store.dispatch(get_connection_info());
store.dispatch(format_message("get_public_keys", {}));
} else if (payload.origin === service_full_node) {
store.dispatch(getBlockChainState());
@ -147,28 +166,12 @@ export const handle_message = (store, payload) => {
}
} else if (payload.command === "log_in") {
if (payload.data.success) {
store.dispatch(format_message("get_wallets", {}));
let start_farmer = startService(service_farmer);
let start_harvester = startService(service_harvester);
store.dispatch(start_farmer);
store.dispatch(start_harvester);
store.dispatch(get_height_info());
store.dispatch(get_sync_status());
store.dispatch(get_connection_info());
store.dispatch(isServiceRunning(service_plotter));
refreshAllState(store);
}
} else if (payload.command === "add_key") {
if (payload.data.success) {
store.dispatch(format_message("get_wallets", {}));
store.dispatch(format_message("get_public_keys", {}));
store.dispatch(get_height_info());
store.dispatch(get_sync_status());
store.dispatch(get_connection_info());
let start_farmer = startService(service_farmer);
let start_harvester = startService(service_harvester);
store.dispatch(start_farmer);
store.dispatch(start_harvester);
store.dispatch(isServiceRunning(service_plotter));
refreshAllState(store);
}
} else if (payload.command === "delete_key") {
if (payload.data.success) {
@ -224,8 +227,8 @@ export const handle_message = (store, payload) => {
} else if (payload.command === "create_new_wallet") {
if (payload.data.success) {
store.dispatch(format_message("get_wallets", {}));
store.dispatch(createState(true, false));
}
store.dispatch(createState(true, false));
} else if (payload.command === "cc_set_name") {
if (payload.data.success) {
const wallet_id = payload.data.wallet_id;
@ -236,19 +239,6 @@ export const handle_message = (store, payload) => {
store.dispatch(openDialog("Success!", "Offer accepted"));
}
store.dispatch(resetTrades());
} else if (payload.command === "get_wallets") {
if (payload.data.success) {
const wallets = payload.data.wallets;
for (let wallet of wallets) {
store.dispatch(get_balance_for_wallet(wallet.id));
store.dispatch(get_transactions(wallet.id));
store.dispatch(get_puzzle_hash(wallet.id));
if (wallet.type === "COLOURED_COIN") {
store.dispatch(get_colour_name(wallet.id));
store.dispatch(get_colour_info(wallet.id));
}
}
}
} else if (payload.command === "get_discrepancies_for_offer") {
if (payload.data.success) {
store.dispatch(offerParsed(payload.data.discrepancies));
@ -306,7 +296,7 @@ export const handle_message = (store, payload) => {
}
if (payload.data.success === false) {
if (payload.data.reason) {
store.dispatch(openDialog("Error?", payload.data.reason));
store.dispatch(openDialog("Error: ", payload.data.reason));
}
}
};

View File

@ -53,7 +53,6 @@ export const tradeReducer = (state = { ...initial_state }, action) => {
new_trades.push(trade);
return { ...state, trades: new_trades };
case "RESET_TRADE":
trade = [];
state = { ...initial_state };
return state;
case "OFFER_PARSING":

View File

@ -177,7 +177,7 @@ export const incomingReducer = (state = { ...initial_state }, action) => {
// console.log("wallet_id here: " + id);
wallet.puzzle_hash = puzzle_hash;
return { ...state };
} else if (command === "get_connection_info") {
} else if (command === "get_connections") {
if (data.success || data.connections) {
const connections = data.connections;
state.status["connections"] = connections;
@ -189,7 +189,6 @@ export const incomingReducer = (state = { ...initial_state }, action) => {
state.status["height"] = height;
return { ...state };
} else if (command === "get_sync_status") {
// console.log("command get_sync_status");
if (data.success) {
const syncing = data.syncing;
state.status["syncing"] = syncing;

View File

@ -134,7 +134,7 @@ export const get_sync_status = () => {
export const get_connection_info = () => {
var action = walletMessage();
action.message.command = "get_connection_info";
action.message.command = "get_connections";
action.message.data = {};
return action;
};

View File

@ -330,18 +330,9 @@ const BalanceCard = props => {
const balancebox_unit = " " + cc_unit;
const balancebox_hline =
"<tr><td colspan='2' style='text-align:center'><hr width='50%'></td></tr>";
const balance_ptotal_chia = mojo_to_colouredcoin_string(
balance_ptotal,
"mojo"
);
const balance_pending_chia = mojo_to_colouredcoin_string(
balance_pending,
"mojo"
);
const balance_change_chia = mojo_to_colouredcoin_string(
balance_change,
"mojo"
);
const balance_ptotal_chia = mojo_to_colouredcoin_string(balance_ptotal);
const balance_pending_chia = mojo_to_colouredcoin_string(balance_pending);
const balance_change_chia = mojo_to_colouredcoin_string(balance_change);
const acc_content =
balancebox_1 +
balancebox_2 +

View File

@ -2,33 +2,28 @@ import React, { Component } from "react";
import Button from "@material-ui/core/Button";
import CssBaseline from "@material-ui/core/CssBaseline";
import TextField from "@material-ui/core/TextField";
import Link from "@material-ui/core/Link";
import Grid from "@material-ui/core/Grid";
import Typography from "@material-ui/core/Typography";
import {
withTheme,
useTheme,
withStyles,
makeStyles
} from "@material-ui/styles";
import { withTheme, withStyles, makeStyles } from "@material-ui/styles";
import Container from "@material-ui/core/Container";
import ArrowBackIosIcon from "@material-ui/icons/ArrowBackIos";
import { connect, useSelector, useDispatch } from "react-redux";
import { genereate_mnemonics } from "../modules/message";
import { withRouter } from "react-router-dom";
function Copyright() {
return (
<Typography variant="body2" color="textSecondary" align="center">
{"Copyright © "}
<Link color="inherit" href="https://chia.net">
Your Website
</Link>{" "}
{new Date().getFullYear()}
{"."}
</Typography>
);
}
// function Copyright() {
// return (
// <Typography variant="body2" color="textSecondary" align="center">
// {"Copyright © "}
// <Link color="inherit" href="https://chia.net">
// Your Website
// </Link>{" "}
// {new Date().getFullYear()}
// {"."}
// </Typography>
// );
// }
const CssTextField = withStyles({
root: {
"& MuiFormLabel-root": {
@ -143,32 +138,9 @@ class MnemonicLabel extends Component {
}
}
class MnemonicGrid extends Component {
render() {
return (
<Grid item xs={2}>
<CssTextField
variant="outlined"
margin="normal"
disabled
fullWidth
color="primary"
id="email"
label={this.props.index}
name="email"
autoComplete="email"
autoFocus
defaultValue={this.props.word}
/>
</Grid>
);
}
}
const UIPart = () => {
const words = useSelector(state => state.wallet_state.mnemonic);
const classes = useStyles();
const theme = useTheme();
return (
<div className={classes.root}>
<ArrowBackIosIcon className={classes.navigator}> </ArrowBackIosIcon>
@ -209,9 +181,4 @@ const CreateMnemonics = () => {
return UIPart();
};
const mapStateToProps = state => {
return {
mnemonic: state.wallet_state.mnemonic
};
};
export default withTheme(withRouter(connect()(CreateMnemonics)));

View File

@ -268,7 +268,7 @@ const Challenges = props => {
>
<TableHead>
<TableRow>
<TableCell>Challange hash</TableCell>
<TableCell>Challenge hash</TableCell>
<TableCell align="right">Height</TableCell>
<TableCell align="right">Number of proofs</TableCell>
<TableCell align="right">Best estimate</TableCell>

View File

@ -254,7 +254,7 @@ const BalanceCard = props => {
title="Spendable Balance"
balance={balance_spendable}
tooltip={
"This is the amount of Chia that you can currently use to make transactions. It does not include pending farming rewards, pending incoming transctions, and Chia that you have just spend but is not yet in the blockchain."
"This is the amount of Chia that you can currently use to make transactions. It does not include pending farming rewards, pending incoming transctions, and Chia that you have just spent but is not yet in the blockchain."
}
/>
<Grid item xs={12}>
@ -388,7 +388,7 @@ const SendCard = props => {
}
return (
<Paper className={(classes.paper, classes.sendCard)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>
@ -478,7 +478,7 @@ const HistoryCard = props => {
var id = props.wallet_id;
const classes = useStyles();
return (
<Paper className={(classes.paper, classes.sendCard)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>
@ -588,7 +588,7 @@ const AddressCard = props => {
}
return (
<Paper className={(classes.paper, classes.sendCard)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>

View File

@ -227,7 +227,7 @@ const OfferView = () => {
}
return (
<Paper className={(classes.paper, classes.balancePaper)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>
@ -311,7 +311,7 @@ const DropView = () => {
: { visibility: "hidden" };
return (
<Paper className={(classes.paper, classes.balancePaper)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>
@ -415,7 +415,7 @@ const CreateOffer = () => {
}
return (
<Paper className={(classes.paper, classes.balancePaper)}>
<Paper className={classes.paper}>
<Grid container spacing={0}>
<Grid item xs={12}>
<div className={classes.cardTitle}>

View File

@ -15,7 +15,7 @@ module.exports = {
const updateDotExe = path.resolve(path.join(rootAtomFolder, "Update.exe"));
const exeName = path.basename(process.execPath);
const spawn = function(command, args) {
let spawnedProcess, error;
let spawnedProcess;
try {
spawnedProcess = ChildProcess.spawn(command, args, { detached: true });

View File

@ -21,6 +21,7 @@ dependencies = [
"keyring_jeepney==0.2",
"keyrings.cryptfile==1.3.4",
"cryptography==2.9.2", #Python cryptography library for TLS
"concurrent-log-handler==0.9.16", # Log to a file concurrently and rotate logs
]
upnp_dependencies = [
@ -40,7 +41,7 @@ kwargs = dict(
name="chia-blockchain",
author="Mariano Sorgente",
author_email="mariano@chia.net",
description="Chia proof of space plotting, proving, and verifying (wraps C++)",
description="Chia blockchain full node, farmer, timelord, and wallet.",
url="https://chia.net/",
license="Apache License",
python_requires=">=3.7, <4",

View File

@ -28,7 +28,7 @@ def main():
plot_config = load_config(root_path, plot_config_filename)
config = load_config(root_path, config_filename)
initialize_logging("%(name)-22s", {"log_stdout": True}, root_path)
initialize_logging("check_plots", {"log_stdout": True}, root_path)
log = logging.getLogger(__name__)
v = Verifier()

View File

@ -238,7 +238,16 @@ def chia_init(root_path: Path):
PATH_MANIFEST_LIST: List[Tuple[Path, List[str]]] = [
(Path(os.path.expanduser("~/.chia/beta-%s" % _)), MANIFEST)
for _ in ["1.0b7", "1.0b6", "1.0b5", "1.0b5.dev0", "1.0b4", "1.0b3", "1.0b2", "1.0b1"]
for _ in [
"1.0b7",
"1.0b6",
"1.0b5",
"1.0b5.dev0",
"1.0b4",
"1.0b3",
"1.0b2",
"1.0b1",
]
]
for old_path, manifest in PATH_MANIFEST_LIST:

View File

@ -205,7 +205,7 @@ async def show_async(args, parser):
print(f"Connecting to {ip}, {port}")
try:
await client.open_connection(ip, int(port))
except BaseException:
except Exception:
# TODO: catch right exception
print(f"Failed to connect to {ip}:{port}")
if args.remove_connection:
@ -221,7 +221,7 @@ async def show_async(args, parser):
)
try:
await client.close_connection(con["node_id"])
except BaseException:
except Exception:
result_txt = (
f"Failed to disconnect NodeID {args.remove_connection}"
)

View File

@ -66,6 +66,7 @@ async def async_start(args, parser):
else:
error = msg["data"]["error"]
print(f"{service} failed to start. Error: {error}")
await daemon.close()
def start(args, parser):

View File

@ -39,6 +39,7 @@ async def async_stop(args, parser):
if args.daemon:
r = await daemon.exit()
await daemon.close()
print(f"daemon: {r}")
return 0
@ -54,6 +55,7 @@ async def async_stop(args, parser):
print("stop failed")
return_val = 1
await daemon.close()
return return_val

View File

@ -0,0 +1,19 @@
from typing import Dict, Any
from src.util.ints import uint32
def find_fork_point_in_chain(hash_to_block: Dict, block_1: Any, block_2: Any) -> uint32:
""" Tries to find height where new chain (block_2) diverged from block_1 (assuming prev blocks
are all included in chain)"""
while block_2.height > 0 or block_1.height > 0:
if block_2.height > block_1.height:
block_2 = hash_to_block[block_2.prev_header_hash]
elif block_1.height > block_2.height:
block_1 = hash_to_block[block_1.prev_header_hash]
else:
if block_2.header_hash == block_1.header_hash:
return block_2.height
block_2 = hash_to_block[block_2.prev_header_hash]
block_1 = hash_to_block[block_1.prev_header_hash]
assert block_2 == block_1 # Genesis block is the same, genesis fork
return uint32(0)

View File

@ -25,7 +25,10 @@ class DaemonProxy:
async def listener():
while True:
message = await self.websocket.recv()
try:
message = await self.websocket.recv()
except websockets.exceptions.ConnectionClosedOK:
return
decoded = json.loads(message)
id = decoded["request_id"]
@ -84,6 +87,9 @@ class DaemonProxy:
response = await self._get(request)
return response
async def close(self):
await self.websocket.close()
async def exit(self):
request = self.format_request("exit", {})
return await self._get(request)
@ -113,5 +119,5 @@ async def connect_to_daemon_and_validate(root_path):
except Exception as ex:
# ConnectionRefusedError means that daemon is not yet running
if not isinstance(ex, ConnectionRefusedError):
print("Exception connecting to daemon: {ex}")
print(f"Exception connecting to daemon: {ex}")
return None

View File

@ -101,8 +101,8 @@ class WebSocketServer:
self.log.info("Daemon WebSocketServer closed")
async def stop(self):
self.websocket_server.close()
await self.exit()
self.websocket_server.close()
async def safe_handle(self, websocket, path):
async for message in websocket:
@ -110,20 +110,22 @@ class WebSocketServer:
decoded = json.loads(message)
# self.log.info(f"Message received: {decoded}")
await self.handle_message(websocket, decoded)
except (BaseException, websockets.exceptions.ConnectionClosed) as e:
if isinstance(e, websockets.exceptions.ConnectionClosed):
service_name = self.remote_address_map[websocket.remote_address[1]]
self.log.info(
f"ConnectionClosed. Closing websocket with {service_name}"
)
if service_name in self.connections:
self.connections.pop(service_name)
await websocket.close()
else:
tb = traceback.format_exc()
self.log.error(f"Error while handling message: {tb}")
error = {"success": False, "error": f"{e}"}
await websocket.send(format_response(message, error))
except (
websockets.exceptions.ConnectionClosed,
websockets.exceptions.ConnectionClosedOK,
) as e:
service_name = self.remote_address_map[websocket.remote_address[1]]
self.log.info(
f"ConnectionClosed. Closing websocket with {service_name} {e}"
)
if service_name in self.connections:
self.connections.pop(service_name)
await websocket.close()
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Error while handling message: {tb}")
error = {"success": False, "error": f"{e}"}
await websocket.send(format_response(message, error))
async def ping_task(self):
await asyncio.sleep(30)
@ -132,7 +134,7 @@ class WebSocketServer:
connection = self.connections[service_name]
self.log.info(f"About to ping: {service_name}")
await connection.ping()
except (BaseException, websockets.exceptions.ConnectionClosed) as e:
except Exception as e:
self.log.info(f"Ping error: {e}")
self.connections.pop(service_name)
self.remote_address_map.pop(remote_address)
@ -164,14 +166,17 @@ class WebSocketServer:
elif command == "is_running":
response = await self.is_running(data)
elif command == "exit":
response = await self.exit()
response = await self.stop()
elif command == "register_service":
response = await self.register_service(websocket, data)
else:
response = {"success": False, "error": f"unknown_command {command}"}
full_response = format_response(message, response)
await websocket.send(full_response)
try:
await websocket.send(full_response)
except websockets.exceptions.ConnectionClosedOK:
pass
async def ping(self):
response = {"success": True, "value": "pong"}
@ -312,7 +317,10 @@ class WebSocketServer:
destination = message["destination"]
if destination in self.connections:
socket = self.connections[destination]
await socket.send(dict_to_json_str(message))
try:
await socket.send(dict_to_json_str(message))
except websockets.exceptions.ConnectionClosedOK:
pass
return None
@ -525,7 +533,7 @@ def singleton(lockfile, text="semaphore"):
async def async_run_daemon(root_path):
chia_init(root_path)
config = load_config(root_path, "config.yaml")
initialize_logging("daemon %(name)-25s", config["logging"], root_path)
initialize_logging("daemon", config["logging"], root_path)
lockfile = singleton(daemon_launch_lock_path(root_path))
if lockfile is None:
print("daemon: already launching")

View File

@ -84,10 +84,10 @@ class Farmer:
NodeType.HARVESTER, Message("harvester_handshake", msg), Delivery.RESPOND
)
def set_global_connections(self, global_connections: PeerConnections):
def _set_global_connections(self, global_connections: PeerConnections):
self.global_connections: PeerConnections = global_connections
def set_server(self, server):
def _set_server(self, server):
self.server = server
def _set_state_changed_callback(self, callback: Callable):

View File

@ -168,9 +168,7 @@ class BlockStore:
return self.proof_of_time_heights[pot_tuple]
return None
def seen_compact_proof(
self, challenge: bytes32, iter: uint64
) -> bool:
def seen_compact_proof(self, challenge: bytes32, iter: uint64) -> bool:
pot_tuple = (challenge, iter)
if pot_tuple in self.seen_compact_proofs:
return True

View File

@ -35,6 +35,7 @@ from src.util.errors import ConsensusError, Err
from src.util.hash import std_hash
from src.util.ints import uint32, uint64
from src.util.merkle_set import MerkleSet
from src.consensus.find_fork_point import find_fork_point_in_chain
log = logging.getLogger(__name__)
@ -406,7 +407,7 @@ class Blockchain:
# If LCA changed update the unspent store
elif old_lca.header_hash != self.lca_block.header_hash:
# New LCA is lower height but not the a parent of old LCA (Reorg)
fork_h = self._find_fork_point_in_chain(old_lca, self.lca_block)
fork_h = find_fork_point_in_chain(self.headers, old_lca, self.lca_block)
# Rollback to fork
await self.coin_store.rollback_lca_to_block(fork_h)
@ -452,22 +453,6 @@ class Blockchain:
curr_new = self.headers[curr_new.prev_header_hash]
curr_old = self.headers[curr_old.prev_header_hash]
def _find_fork_point_in_chain(self, block_1: Header, block_2: Header) -> uint32:
""" Tries to find height where new chain (block_2) diverged from block_1 (assuming prev blocks
are all included in chain)"""
while block_2.height > 0 or block_1.height > 0:
if block_2.height > block_1.height:
block_2 = self.headers[block_2.prev_header_hash]
elif block_1.height > block_2.height:
block_1 = self.headers[block_1.prev_header_hash]
else:
if block_2.header_hash == block_1.header_hash:
return block_2.height
block_2 = self.headers[block_2.prev_header_hash]
block_1 = self.headers[block_1.prev_header_hash]
assert block_2 == block_1 # Genesis block is the same, genesis fork
return uint32(0)
async def _create_diffs_for_tips(self, target: Header):
""" Adds to unspent store from tips down to target"""
for tip in self.tips:
@ -715,7 +700,7 @@ class Blockchain:
return Err.DOUBLE_SPEND
# Check if removals exist and were not previously spend. (unspent_db + diff_store + this_block)
fork_h = self._find_fork_point_in_chain(self.lca_block, block.header)
fork_h = find_fork_point_in_chain(self.headers, self.lca_block, block.header)
# Get additions and removals since (after) fork_h but not including this block
additions_since_fork: Dict[bytes32, Tuple[Coin, uint32]] = {}

View File

@ -4,9 +4,8 @@ import logging
import traceback
import time
import random
from asyncio import Event
from pathlib import Path
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Type, Callable
from typing import AsyncGenerator, Dict, List, Optional, Tuple, Callable
import aiosqlite
from chiabip158 import PyBIP158
@ -97,13 +96,7 @@ class FullNode:
self.db_path = path_from_root(root_path, config["database_path"])
mkdir(self.db_path.parent)
@classmethod
async def create(cls: Type, *args, **kwargs):
_ = cls(*args, **kwargs)
await _.start()
return _
async def start(self):
async def _start(self):
# create the store (db) and full node instance
self.connection = await aiosqlite.connect(self.db_path)
self.block_store = await BlockStore.create(self.connection)
@ -128,7 +121,7 @@ class FullNode:
self.broadcast_uncompact_blocks(uncompact_interval)
)
def set_global_connections(self, global_connections: PeerConnections):
def _set_global_connections(self, global_connections: PeerConnections):
self.global_connections = global_connections
def _set_server(self, server: ChiaServer):
@ -334,7 +327,7 @@ class FullNode:
f"Tip block {tip_block.header_hash} tip height {tip_block.height}"
)
self.sync_store.set_potential_hashes_received(Event())
self.sync_store.set_potential_hashes_received(asyncio.Event())
sleep_interval = 10
total_time_slept = 0
@ -885,6 +878,8 @@ class FullNode:
self.log.info("Scanning the blockchain for uncompact blocks.")
for h in range(min_height, max_height):
if self._shut_down:
return
blocks: List[FullBlock] = await self.block_store.get_blocks_at(
[uint32(h)]
)
@ -895,27 +890,18 @@ class FullNode:
if block.proof_of_time.witness_type != 0:
challenge_msg = timelord_protocol.ChallengeStart(
block.proof_of_time.challenge_hash,
block.weight,
block.proof_of_time.challenge_hash, block.weight,
)
pos_info_msg = timelord_protocol.ProofOfSpaceInfo(
block.proof_of_time.challenge_hash,
block.proof_of_time.number_of_iterations,
)
broadcast_list.append(
(
challenge_msg,
pos_info_msg,
)
)
broadcast_list.append((challenge_msg, pos_info_msg,))
# Scan only since the first uncompact block we know about.
# No block earlier than this will be uncompact in the future,
# unless a reorg happens. The range to scan next time
# is always at least 200 blocks, to protect against reorgs.
if (
uncompact_blocks == 0
and h <= max(1, max_height - 200)
):
if uncompact_blocks == 0 and h <= max(1, max_height - 200):
new_min_height = h
uncompact_blocks += 1
@ -946,7 +932,9 @@ class FullNode:
delivery,
)
)
self.log.info(f"Broadcasted {len(broadcast_list)} uncompact blocks to timelords.")
self.log.info(
f"Broadcasted {len(broadcast_list)} uncompact blocks to timelords."
)
await asyncio.sleep(uncompact_interval)
@api_request
@ -1573,7 +1561,7 @@ class FullNode:
yield ret_msg
except asyncio.CancelledError:
self.log.error("Syncing failed, CancelledError")
except BaseException as e:
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Error with syncing: {type(e)}{tb}")
finally:
@ -1786,7 +1774,6 @@ class FullNode:
) -> OutboundMessageGenerator:
# Ignore if syncing
if self.sync_store.get_sync_mode():
cost = None
status = MempoolInclusionStatus.FAILED
error: Optional[Err] = Err.UNKNOWN
else:

View File

@ -40,6 +40,8 @@ class SyncBlocksProcessor:
for batch_start_height in range(
self.fork_height + 1, self.tip_height + 1, self.BATCH_SIZE
):
if self._shut_down:
return
total_time_slept = 0
batch_end_height = min(
batch_start_height + self.BATCH_SIZE - 1, self.tip_height

View File

@ -78,7 +78,7 @@ class Harvester:
challenge_hashes: Dict[bytes32, Tuple[bytes32, str, uint8]]
pool_pubkeys: List[PublicKey]
root_path: Path
_plot_notification_task: asyncio.Future
_plot_notification_task: Optional[asyncio.Task]
_is_shutdown: bool
executor: concurrent.futures.ThreadPoolExecutor
state_changed_callback: Optional[Callable]
@ -95,14 +95,24 @@ class Harvester:
# From quality string to (challenge_hash, filename, index)
self.challenge_hashes = {}
self._plot_notification_task = asyncio.ensure_future(self._plot_notification())
self._is_shutdown = False
self._plot_notification_task = None
self.global_connections: Optional[PeerConnections] = None
self.pool_pubkeys = []
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=10)
self.state_changed_callback = None
self.server = None
async def _start(self):
self._plot_notification_task = asyncio.create_task(self._plot_notification())
def _close(self):
self._is_shutdown = True
self.executor.shutdown(wait=True)
async def _await_closed(self):
await self._plot_notification_task
def _set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
if self.global_connections is not None:
@ -213,19 +223,12 @@ class Harvester:
self._refresh_plots()
return True
def set_global_connections(self, global_connections: Optional[PeerConnections]):
def _set_global_connections(self, global_connections: Optional[PeerConnections]):
self.global_connections = global_connections
def set_server(self, server):
def _set_server(self, server):
self.server = server
def _shutdown(self):
self._is_shutdown = True
self.executor.shutdown(wait=True)
async def _await_shutdown(self):
await self._plot_notification_task
@api_request
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake

View File

@ -1,11 +1,12 @@
import asyncio
import logging
from typing import AsyncGenerator, Dict
from typing import AsyncGenerator, Dict, Optional
from src.protocols.introducer_protocol import RespondPeers, RequestPeers
from src.server.connection import PeerConnections
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.types.sized_bytes import bytes32
from src.server.server import ChiaServer
from src.util.api_decorators import api_request
log = logging.getLogger(__name__)
@ -16,8 +17,57 @@ class Introducer:
self.vetted: Dict[bytes32, bool] = {}
self.max_peers_to_send = max_peers_to_send
self.recent_peer_threshold = recent_peer_threshold
self._shut_down = False
self.server: Optional[ChiaServer] = None
def set_global_connections(self, global_connections: PeerConnections):
async def _start(self):
self._vetting_task = asyncio.create_task(self._vetting_loop())
def _close(self):
self._shut_down = True
async def _await_closed(self):
await self._vetting_task
def _set_server(self, server: ChiaServer):
self.server = server
async def _vetting_loop(self):
while True:
if self._shut_down:
return
try:
log.info("Vetting random peers.")
rawpeers = self.global_connections.peers.get_peers(
100, True, self.recent_peer_threshold
)
for peer in rawpeers:
if self._shut_down:
return
if peer.get_hash() not in self.vetted:
try:
log.info(f"Vetting peer {peer.host} {peer.port}")
r, w = await asyncio.wait_for(
asyncio.open_connection(peer.host, int(peer.port)),
timeout=3,
)
w.close()
except Exception as e:
log.warning(f"Could not vet {peer}. {type(e)}{str(e)}")
self.vetted[peer.get_hash()] = False
continue
log.info(f"Have vetted {peer} successfully!")
self.vetted[peer.get_hash()] = True
except Exception as e:
log.error(e)
for i in range(30):
if self._shut_down:
return
await asyncio.sleep(1)
def _set_global_connections(self, global_connections: PeerConnections):
self.global_connections: PeerConnections = global_connections
@api_request
@ -26,29 +76,14 @@ class Introducer:
) -> AsyncGenerator[OutboundMessage, None]:
max_peers = self.max_peers_to_send
rawpeers = self.global_connections.peers.get_peers(
max_peers * 2, True, self.recent_peer_threshold
max_peers * 5, True, self.recent_peer_threshold
)
peers = []
for peer in rawpeers:
if peer.get_hash() not in self.vetted:
try:
r, w = await asyncio.open_connection(peer.host, int(peer.port))
w.close()
except (
ConnectionRefusedError,
TimeoutError,
OSError,
asyncio.TimeoutError,
) as e:
log.warning(f"Could not vet {peer}. {type(e)}{str(e)}")
self.vetted[peer.get_hash()] = False
continue
log.info(f"Have vetted {peer} successfully!")
self.vetted[peer.get_hash()] = True
continue
if self.vetted[peer.get_hash()]:
peers.append(peer)

46
src/rpc/farmer_rpc_api.py Normal file
View File

@ -0,0 +1,46 @@
from typing import Callable, Set, Dict, List
from src.farmer import Farmer
from src.util.ws_message import create_payload
class FarmerRpcApi:
def __init__(self, farmer: Farmer):
self.service = farmer
self.service_name = "chia_farmer"
def get_routes(self) -> Dict[str, Callable]:
return {"/get_latest_challenges": self.get_latest_challenges}
async def _state_changed(self, change: str) -> List[str]:
if change == "challenge":
data = await self.get_latest_challenges({})
return [
create_payload(
"get_latest_challenges", data, self.service_name, "wallet_ui"
)
]
return []
async def get_latest_challenges(self, request: Dict) -> Dict:
response = []
seen_challenges: Set = set()
if self.service.current_weight == 0:
return {"success": True, "latest_challenges": []}
for pospace_fin in self.service.challenges[self.service.current_weight]:
estimates = self.service.challenge_to_estimates.get(
pospace_fin.challenge_hash, []
)
if pospace_fin.challenge_hash in seen_challenges:
continue
response.append(
{
"challenge": pospace_fin.challenge_hash,
"weight": pospace_fin.weight,
"height": pospace_fin.height,
"difficulty": pospace_fin.difficulty,
"estimates": estimates,
}
)
seen_challenges.add(pospace_fin.challenge_hash)
return {"success": True, "latest_challenges": response}

View File

@ -1,13 +1,8 @@
import aiohttp
import asyncio
from typing import Dict, Optional, List
from src.util.byte_types import hexstr_to_bytes
from src.types.sized_bytes import bytes32
from src.util.ints import uint16
from typing import Dict, List
from src.rpc.rpc_client import RpcClient
class FarmerRpcClient:
class FarmerRpcClient(RpcClient):
"""
Client to Chia RPC, connects to a local farmer. Uses HTTP/JSON, and converts back from
JSON into native python objects before returning. All api calls use POST requests.
@ -16,44 +11,5 @@ class FarmerRpcClient:
to the full node.
"""
url: str
session: aiohttp.ClientSession
closing_task: Optional[asyncio.Task]
@classmethod
async def create(cls, port: uint16):
self = cls()
self.url = f"http://localhost:{str(port)}/"
self.session = aiohttp.ClientSession()
self.closing_task = None
return self
async def fetch(self, path, request_json):
async with self.session.post(self.url + path, json=request_json) as response:
response.raise_for_status()
return await response.json()
async def get_latest_challenges(self) -> List[Dict]:
return await self.fetch("get_latest_challenges", {})
async def get_connections(self) -> List[Dict]:
response = await self.fetch("get_connections", {})
for connection in response["connections"]:
connection["node_id"] = hexstr_to_bytes(connection["node_id"])
return response["connections"]
async def open_connection(self, host: str, port: int) -> Dict:
return await self.fetch("open_connection", {"host": host, "port": int(port)})
async def close_connection(self, node_id: bytes32) -> Dict:
return await self.fetch("close_connection", {"node_id": node_id.hex()})
async def stop_node(self) -> Dict:
return await self.fetch("stop_node", {})
def close(self):
self.closing_task = asyncio.create_task(self.session.close())
async def await_closed(self):
if self.closing_task is not None:
await self.closing_task

View File

@ -1,65 +0,0 @@
from typing import Callable, Set, Dict
from src.farmer import Farmer
from src.util.ints import uint16
from src.util.ws_message import create_payload
from src.rpc.abstract_rpc_server import AbstractRpcApiHandler, start_rpc_server
class FarmerRpcApiHandler(AbstractRpcApiHandler):
def __init__(self, farmer: Farmer, stop_cb: Callable):
super().__init__(farmer, stop_cb, "chia_farmer")
async def _state_changed(self, change: str):
assert self.websocket is not None
if change == "challenge":
data = await self.get_latest_challenges({})
payload = create_payload(
"get_latest_challenges", data, self.service_name, "wallet_ui"
)
else:
await super()._state_changed(change)
return
try:
await self.websocket.send_str(payload)
except (BaseException) as e:
try:
self.log.warning(f"Sending data failed. Exception {type(e)}.")
except BrokenPipeError:
pass
async def get_latest_challenges(self, request: Dict) -> Dict:
response = []
seen_challenges: Set = set()
if self.service.current_weight == 0:
return {"success": True, "latest_challenges": []}
for pospace_fin in self.service.challenges[self.service.current_weight]:
estimates = self.service.challenge_to_estimates.get(
pospace_fin.challenge_hash, []
)
if pospace_fin.challenge_hash in seen_challenges:
continue
response.append(
{
"challenge": pospace_fin.challenge_hash,
"weight": pospace_fin.weight,
"height": pospace_fin.height,
"difficulty": pospace_fin.difficulty,
"estimates": estimates,
}
)
seen_challenges.add(pospace_fin.challenge_hash)
return {"success": True, "latest_challenges": response}
async def start_farmer_rpc_server(
farmer: Farmer, stop_node_cb: Callable, rpc_port: uint16
):
handler = FarmerRpcApiHandler(farmer, stop_node_cb)
routes = {"/get_latest_challenges": handler.get_latest_challenges}
cleanup = await start_rpc_server(handler, rpc_port, routes)
return cleanup
AbstractRpcApiHandler.register(FarmerRpcApiHandler)

View File

@ -1,6 +1,4 @@
from src.full_node.full_node import FullNode
from src.util.ints import uint16
from src.rpc.abstract_rpc_server import AbstractRpcApiHandler, start_rpc_server
from typing import Callable, List, Optional, Dict
from aiohttp import web
@ -14,38 +12,43 @@ from src.consensus.pot_iterations import calculate_min_iters_from_iterations
from src.util.ws_message import create_payload
class FullNodeRpcApiHandler(AbstractRpcApiHandler):
def __init__(self, full_node: FullNode, stop_cb: Callable):
super().__init__(full_node, stop_cb, "chia_full_node")
class FullNodeRpcApi:
def __init__(self, full_node: FullNode):
self.service = full_node
self.service_name = "chia_full_node"
self.cached_blockchain_state: Optional[Dict] = None
async def _state_changed(self, change: str):
assert self.websocket is not None
def get_routes(self) -> Dict[str, Callable]:
return {
"/get_blockchain_state": self.get_blockchain_state,
"/get_block": self.get_block,
"/get_header_by_height": self.get_header_by_height,
"/get_header": self.get_header,
"/get_unfinished_block_headers": self.get_unfinished_block_headers,
"/get_network_space": self.get_network_space,
"/get_unspent_coins": self.get_unspent_coins,
"/get_heaviest_block_seen": self.get_heaviest_block_seen,
}
async def _state_changed(self, change: str) -> List[str]:
payloads = []
if change == "block":
data = await self.get_latest_block_headers({})
assert data is not None
payloads.append(
create_payload(
"get_latest_block_headers", data, self.service_name, "wallet_ui"
)
)
data = await self.get_blockchain_state({})
assert data is not None
payloads.append(
create_payload(
"get_blockchain_state", data, self.service_name, "wallet_ui"
)
)
else:
await super()._state_changed(change)
return
try:
for payload in payloads:
await self.websocket.send_str(payload)
except (BaseException) as e:
try:
self.log.warning(f"Sending data failed. Exception {type(e)}.")
except BrokenPipeError:
pass
return payloads
return []
async def get_blockchain_state(self, request: Dict):
"""
@ -357,24 +360,3 @@ class FullNodeRpcApiHandler(AbstractRpcApiHandler):
if pot_block.weight > max_tip.weight:
max_tip = pot_block.header
return {"success": True, "tip": max_tip}
async def start_full_node_rpc_server(
full_node: FullNode, stop_node_cb: Callable, rpc_port: uint16
):
handler = FullNodeRpcApiHandler(full_node, stop_node_cb)
routes = {
"/get_blockchain_state": handler.get_blockchain_state,
"/get_block": handler.get_block,
"/get_header_by_height": handler.get_header_by_height,
"/get_header": handler.get_header,
"/get_unfinished_block_headers": handler.get_unfinished_block_headers,
"/get_network_space": handler.get_network_space,
"/get_unspent_coins": handler.get_unspent_coins,
"/get_heaviest_block_seen": handler.get_heaviest_block_seen,
}
cleanup = await start_rpc_server(handler, rpc_port, routes)
return cleanup
AbstractRpcApiHandler.register(FullNodeRpcApiHandler)

View File

@ -1,16 +1,14 @@
import aiohttp
import asyncio
from typing import Dict, Optional, List
from src.util.byte_types import hexstr_to_bytes
from src.types.full_block import FullBlock
from src.types.header import Header
from src.types.sized_bytes import bytes32
from src.util.ints import uint16, uint32, uint64
from src.util.ints import uint32, uint64
from src.types.coin_record import CoinRecord
from src.rpc.rpc_client import RpcClient
class FullNodeRpcClient:
class FullNodeRpcClient(RpcClient):
"""
Client to Chia RPC, connects to a local full node. Uses HTTP/JSON, and converts back from
JSON into native python objects before returning. All api calls use POST requests.
@ -19,23 +17,6 @@ class FullNodeRpcClient:
to the full node.
"""
url: str
session: aiohttp.ClientSession
closing_task: Optional[asyncio.Task]
@classmethod
async def create(cls, port: uint16):
self = cls()
self.url = f"http://localhost:{str(port)}/"
self.session = aiohttp.ClientSession()
self.closing_task = None
return self
async def fetch(self, path, request_json):
async with self.session.post(self.url + path, json=request_json) as response:
response.raise_for_status()
return await response.json()
async def get_blockchain_state(self) -> Dict:
response = await self.fetch("get_blockchain_state", {})
response["blockchain_state"]["tips"] = [
@ -98,21 +79,6 @@ class FullNodeRpcClient:
raise
return network_space_bytes_estimate["space"]
async def get_connections(self) -> List[Dict]:
response = await self.fetch("get_connections", {})
for connection in response["connections"]:
connection["node_id"] = hexstr_to_bytes(connection["node_id"])
return response["connections"]
async def open_connection(self, host: str, port: int) -> Dict:
return await self.fetch("open_connection", {"host": host, "port": int(port)})
async def close_connection(self, node_id: bytes32) -> Dict:
return await self.fetch("close_connection", {"node_id": node_id.hex()})
async def stop_node(self) -> Dict:
return await self.fetch("stop_node", {})
async def get_unspent_coins(
self, puzzle_hash: bytes32, header_hash: Optional[bytes32] = None
) -> List:
@ -128,10 +94,3 @@ class FullNodeRpcClient:
async def get_heaviest_block_seen(self) -> Header:
response = await self.fetch("get_heaviest_block_seen", {})
return Header.from_json_dict(response["tip"])
def close(self):
self.closing_task = asyncio.create_task(self.session.close())
async def await_closed(self):
if self.closing_task is not None:
await self.closing_task

View File

@ -1,32 +1,29 @@
from typing import Callable, Dict
from blspy import PrivateKey, PublicKey
from typing import Callable, Dict, List
from src.harvester import Harvester
from src.util.ints import uint16
from src.util.ws_message import create_payload
from src.rpc.abstract_rpc_server import AbstractRpcApiHandler, start_rpc_server
from blspy import PrivateKey, PublicKey
class HarvesterRpcApiHandler(AbstractRpcApiHandler):
def __init__(self, harvester: Harvester, stop_cb: Callable):
super().__init__(harvester, stop_cb, "chia_harvester")
class HarvesterRpcApi:
def __init__(self, harvester: Harvester):
self.service = harvester
self.service_name = "chia_harvester"
async def _state_changed(self, change: str):
assert self.websocket is not None
def get_routes(self) -> Dict[str, Callable]:
return {
"/get_plots": self.get_plots,
"/refresh_plots": self.refresh_plots,
"/delete_plot": self.delete_plot,
"/add_plot": self.add_plot,
}
async def _state_changed(self, change: str) -> List[str]:
if change == "plots":
data = await self.get_plots({})
payload = create_payload("get_plots", data, self.service_name, "wallet_ui")
else:
await super()._state_changed(change)
return
try:
await self.websocket.send_str(payload)
except (BaseException) as e:
try:
self.log.warning(f"Sending data failed. Exception {type(e)}.")
except BrokenPipeError:
pass
return [payload]
return []
async def get_plots(self, request: Dict) -> Dict:
plots, failed_to_open, not_found = self.service._get_plots()
@ -55,20 +52,3 @@ class HarvesterRpcApiHandler(AbstractRpcApiHandler):
plot_sk = PrivateKey.from_bytes(bytes.fromhex(request["plot_sk"]))
success = self.service._add_plot(filename, plot_sk, pool_pk)
return {"success": success}
async def start_harvester_rpc_server(
harvester: Harvester, stop_node_cb: Callable, rpc_port: uint16
):
handler = HarvesterRpcApiHandler(harvester, stop_node_cb)
routes = {
"/get_plots": handler.get_plots,
"/refresh_plots": handler.refresh_plots,
"/delete_plot": handler.delete_plot,
"/add_plot": handler.add_plot,
}
cleanup = await start_rpc_server(handler, rpc_port, routes)
return cleanup
AbstractRpcApiHandler.register(HarvesterRpcApiHandler)

View File

@ -1,14 +1,9 @@
import aiohttp
import asyncio
from blspy import PrivateKey, PublicKey
from typing import Dict, Optional, List
from src.util.byte_types import hexstr_to_bytes
from src.types.sized_bytes import bytes32
from src.util.ints import uint16
from typing import Optional, List, Dict
from src.rpc.rpc_client import RpcClient
class HarvesterRpcClient:
class HarvesterRpcClient(RpcClient):
"""
Client to Chia RPC, connects to a local harvester. Uses HTTP/JSON, and converts back from
JSON into native python objects before returning. All api calls use POST requests.
@ -17,23 +12,6 @@ class HarvesterRpcClient:
to the full node.
"""
url: str
session: aiohttp.ClientSession
closing_task: Optional[asyncio.Task]
@classmethod
async def create(cls, port: uint16):
self = cls()
self.url = f"http://localhost:{str(port)}/"
self.session = aiohttp.ClientSession()
self.closing_task = None
return self
async def fetch(self, path, request_json):
async with self.session.post(self.url + path, json=request_json) as response:
response.raise_for_status()
return await response.json()
async def get_plots(self) -> List[Dict]:
return await self.fetch("get_plots", {})
@ -57,25 +35,3 @@ class HarvesterRpcClient:
return await self.fetch(
"add_plot", {"filename": filename, "plot_sk": plot_sk_str}
)
async def get_connections(self) -> List[Dict]:
response = await self.fetch("get_connections", {})
for connection in response["connections"]:
connection["node_id"] = hexstr_to_bytes(connection["node_id"])
return response["connections"]
async def open_connection(self, host: str, port: int) -> Dict:
return await self.fetch("open_connection", {"host": host, "port": int(port)})
async def close_connection(self, node_id: bytes32) -> Dict:
return await self.fetch("close_connection", {"node_id": node_id.hex()})
async def stop_node(self) -> Dict:
return await self.fetch("stop_node", {})
def close(self):
self.closing_task = asyncio.create_task(self.session.close())
async def await_closed(self):
if self.closing_task is not None:
await self.closing_task

56
src/rpc/rpc_client.py Normal file
View File

@ -0,0 +1,56 @@
import aiohttp
import asyncio
from typing import Dict, Optional, List
from src.util.byte_types import hexstr_to_bytes
from src.types.sized_bytes import bytes32
from src.util.ints import uint16
class RpcClient:
"""
Client to Chia RPC, connects to a local service. Uses HTTP/JSON, and converts back from
JSON into native python objects before returning. All api calls use POST requests.
Note that this is not the same as the peer protocol, or wallet protocol (which run Chia's
protocol on top of TCP), it's a separate protocol on top of HTTP thats provides easy access
to the full node.
"""
url: str
session: aiohttp.ClientSession
closing_task: Optional[asyncio.Task]
@classmethod
async def create(cls, port: uint16):
self = cls()
self.url = f"http://localhost:{str(port)}/"
self.session = aiohttp.ClientSession()
self.closing_task = None
return self
async def fetch(self, path, request_json):
async with self.session.post(self.url + path, json=request_json) as response:
response.raise_for_status()
return await response.json()
async def get_connections(self) -> List[Dict]:
response = await self.fetch("get_connections", {})
for connection in response["connections"]:
connection["node_id"] = hexstr_to_bytes(connection["node_id"])
return response["connections"]
async def open_connection(self, host: str, port: int) -> Dict:
return await self.fetch("open_connection", {"host": host, "port": int(port)})
async def close_connection(self, node_id: bytes32) -> Dict:
return await self.fetch("close_connection", {"node_id": node_id.hex()})
async def stop_node(self) -> Dict:
return await self.fetch("stop_node", {})
def close(self):
self.closing_task = asyncio.create_task(self.session.close())
async def await_closed(self):
if self.closing_task is not None:
await self.closing_task

View File

@ -1,5 +1,4 @@
from typing import Callable, Dict, Any
from abc import ABC, abstractmethod
from typing import Callable, Dict, Any, List
import aiohttp
import logging
@ -8,21 +7,21 @@ import json
import traceback
from src.types.peer_info import PeerInfo
from src.util.ints import uint16
from src.util.byte_types import hexstr_to_bytes
from src.util.json_util import obj_to_response
from src.util.ws_message import create_payload, format_response, pong
from src.util.ints import uint16
log = logging.getLogger(__name__)
class AbstractRpcApiHandler(ABC):
class RpcServer:
"""
Implementation of RPC server.
"""
def __init__(self, service: Any, stop_cb: Callable, service_name: str):
self.service = service
def __init__(self, rpc_api: Any, service_name: str, stop_cb: Callable):
self.rpc_api = rpc_api
self.stop_cb: Callable = stop_cb
self.log = log
self.shut_down = False
@ -34,34 +33,28 @@ class AbstractRpcApiHandler(ABC):
if self.websocket is not None:
await self.websocket.close()
@classmethod
def __subclasshook__(cls, C):
if cls is AbstractRpcApiHandler:
if any("_state_changed" in B.__dict__ for B in C.__mro__):
return True
return NotImplemented
@abstractmethod
async def _state_changed(self, change: str):
async def _state_changed(self, *args):
change = args[0]
assert self.websocket is not None
payloads: List[str] = await self.rpc_api._state_changed(*args)
if change == "add_connection" or change == "close_connection":
data = await self.get_connections({})
payload = create_payload(
"get_connections", data, self.service_name, "wallet_ui"
)
try:
await self.websocket.send_str(payload)
except (BaseException) as e:
payloads.append(payload)
for payload in payloads:
try:
await self.websocket.send_str(payload)
except Exception as e:
self.log.warning(f"Sending data failed. Exception {type(e)}.")
except BrokenPipeError:
pass
def state_changed(self, change: str):
def state_changed(self, *args):
if self.websocket is None:
return
asyncio.create_task(self._state_changed(change))
asyncio.create_task(self._state_changed(*args))
def _wrap_http_handler(self, f) -> Callable:
async def inner(request) -> aiohttp.web.Response:
@ -74,9 +67,9 @@ class AbstractRpcApiHandler(ABC):
return inner
async def get_connections(self, request: Dict) -> Dict:
if self.service.global_connections is None:
if self.rpc_api.service.global_connections is None:
return {"success": False}
connections = self.service.global_connections.get_connections()
connections = self.rpc_api.service.global_connections.get_connections()
con_info = [
{
"type": con.connection_type,
@ -100,25 +93,25 @@ class AbstractRpcApiHandler(ABC):
port = request["port"]
target_node: PeerInfo = PeerInfo(host, uint16(int(port)))
if getattr(self.service, "server", None) is None or not (
await self.service.server.start_client(target_node, None)
if getattr(self.rpc_api.service, "server", None) is None or not (
await self.rpc_api.service.server.start_client(target_node, None)
):
raise aiohttp.web.HTTPInternalServerError()
return {"success": True}
async def close_connection(self, request: Dict):
node_id = hexstr_to_bytes(request["node_id"])
if self.service.global_connections is None:
if self.rpc_api.service.global_connections is None:
raise aiohttp.web.HTTPInternalServerError()
connections_to_close = [
c
for c in self.service.global_connections.get_connections()
for c in self.rpc_api.service.global_connections.get_connections()
if c.node_id == node_id
]
if len(connections_to_close) == 0:
raise aiohttp.web.HTTPNotFound()
for connection in connections_to_close:
self.service.global_connections.close(connection)
self.rpc_api.service.global_connections.close(connection)
return {"success": True}
async def stop_node(self, request):
@ -145,6 +138,9 @@ class AbstractRpcApiHandler(ABC):
return pong()
f = getattr(self, command, None)
if f is not None:
return await f(data)
f = getattr(self.rpc_api, command, None)
if f is not None:
return await f(data)
else:
@ -156,9 +152,10 @@ class AbstractRpcApiHandler(ABC):
message = json.loads(payload)
response = await self.ws_api(message)
if response is not None:
# log.info(f"Sending {message} {response}")
await websocket.send_str(format_response(message, response))
except BaseException as e:
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Error while handling message: {tb}")
error = {"success": False, "error": f"{e}"}
@ -210,49 +207,54 @@ class AbstractRpcApiHandler(ABC):
await self.connection(ws)
self.websocket = None
await session.close()
except BaseException as e:
except Exception as e:
self.log.warning(f"Exception: {e}")
if session is not None:
await session.close()
await asyncio.sleep(1)
async def start_rpc_server(
handler: AbstractRpcApiHandler, rpc_port: uint16, http_routes: Dict[str, Callable]
):
async def start_rpc_server(rpc_api: Any, rpc_port: uint16, stop_cb: Callable):
"""
Starts an HTTP server with the following RPC methods, to be used by local clients to
query the node.
"""
app = aiohttp.web.Application()
handler.service._set_state_changed_callback(handler.state_changed)
rpc_server = RpcServer(rpc_api, rpc_api.service_name, stop_cb)
rpc_server.rpc_api.service._set_state_changed_callback(rpc_server.state_changed)
http_routes: Dict[str, Callable] = rpc_api.get_routes()
routes = [
aiohttp.web.post(route, handler._wrap_http_handler(func))
aiohttp.web.post(route, rpc_server._wrap_http_handler(func))
for (route, func) in http_routes.items()
]
routes += [
aiohttp.web.post(
"/get_connections", handler._wrap_http_handler(handler.get_connections)
"/get_connections",
rpc_server._wrap_http_handler(rpc_server.get_connections),
),
aiohttp.web.post(
"/open_connection", handler._wrap_http_handler(handler.open_connection)
"/open_connection",
rpc_server._wrap_http_handler(rpc_server.open_connection),
),
aiohttp.web.post(
"/close_connection", handler._wrap_http_handler(handler.close_connection)
"/close_connection",
rpc_server._wrap_http_handler(rpc_server.close_connection),
),
aiohttp.web.post(
"/stop_node", rpc_server._wrap_http_handler(rpc_server.stop_node)
),
aiohttp.web.post("/stop_node", handler._wrap_http_handler(handler.stop_node)),
]
app.add_routes(routes)
daemon_connection = asyncio.create_task(handler.connect_to_daemon())
daemon_connection = asyncio.create_task(rpc_server.connect_to_daemon())
runner = aiohttp.web.AppRunner(app, access_log=None)
await runner.setup()
site = aiohttp.web.TCPSite(runner, "localhost", int(rpc_port))
await site.start()
async def cleanup():
await handler.stop()
await rpc_server.stop()
await runner.cleanup()
await daemon_connection

528
src/rpc/wallet_rpc_api.py Normal file
View File

@ -0,0 +1,528 @@
import asyncio
import logging
import time
from pathlib import Path
from blspy import ExtendedPrivateKey, PrivateKey
from secrets import token_bytes
from typing import List, Optional, Tuple, Dict, Callable
from src.util.byte_types import hexstr_to_bytes
from src.util.keychain import (
seed_from_mnemonic,
generate_mnemonic,
bytes_to_mnemonic,
)
from src.util.path import path_from_root
from src.util.ws_message import create_payload
from src.cmds.init import check_keys
from src.server.outbound_message import NodeType, OutboundMessage, Message, Delivery
from src.simulator.simulator_protocol import FarmNewBlockProtocol
from src.util.ints import uint64, uint32
from src.wallet.util.wallet_types import WalletType
from src.wallet.rl_wallet.rl_wallet import RLWallet
from src.wallet.cc_wallet.cc_wallet import CCWallet
from src.wallet.wallet_info import WalletInfo
from src.wallet.wallet_node import WalletNode
from src.types.mempool_inclusion_status import MempoolInclusionStatus
# Timeout for response from wallet/full node for sending a transaction
TIMEOUT = 30
log = logging.getLogger(__name__)
class WalletRpcApi:
def __init__(self, wallet_node: WalletNode):
self.service = wallet_node
self.service_name = "chia-wallet"
def get_routes(self) -> Dict[str, Callable]:
return {
"/get_wallet_balance": self.get_wallet_balance,
"/send_transaction": self.send_transaction,
"/get_next_puzzle_hash": self.get_next_puzzle_hash,
"/get_transactions": self.get_transactions,
"/farm_block": self.farm_block,
"/get_sync_status": self.get_sync_status,
"/get_height_info": self.get_height_info,
"/create_new_wallet": self.create_new_wallet,
"/get_wallets": self.get_wallets,
"/rl_set_admin_info": self.rl_set_admin_info,
"/rl_set_user_info": self.rl_set_user_info,
"/cc_set_name": self.cc_set_name,
"/cc_get_name": self.cc_get_name,
"/cc_spend": self.cc_spend,
"/cc_get_colour": self.cc_get_colour,
"/create_offer_for_ids": self.create_offer_for_ids,
"/get_discrepancies_for_offer": self.get_discrepancies_for_offer,
"/respond_to_offer": self.respond_to_offer,
"/get_wallet_summaries": self.get_wallet_summaries,
"/get_public_keys": self.get_public_keys,
"/generate_mnemonic": self.generate_mnemonic,
"/log_in": self.log_in,
"/add_key": self.add_key,
"/delete_key": self.delete_key,
"/delete_all_keys": self.delete_all_keys,
"/get_private_key": self.get_private_key,
}
async def _state_changed(self, *args) -> List[str]:
if len(args) < 2:
return []
change = args[0]
wallet_id = args[1]
data = {
"state": change,
}
if wallet_id is not None:
data["wallet_id"] = wallet_id
return [create_payload("state_changed", data, "chia-wallet", "wallet_ui")]
async def get_next_puzzle_hash(self, request: Dict) -> Dict:
"""
Returns a new puzzlehash
"""
if self.service is None:
return {"success": False}
wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id]
if wallet.wallet_info.type == WalletType.STANDARD_WALLET:
puzzle_hash = (await wallet.get_new_puzzlehash()).hex()
elif wallet.wallet_info.type == WalletType.COLOURED_COIN:
puzzle_hash = await wallet.get_new_inner_hash()
response = {
"wallet_id": wallet_id,
"puzzle_hash": puzzle_hash,
}
return response
async def send_transaction(self, request):
wallet_id = int(request["wallet_id"])
wallet = self.service.wallet_state_manager.wallets[wallet_id]
try:
tx = await wallet.generate_signed_transaction_dict(request)
except Exception as e:
data = {
"status": "FAILED",
"reason": f"Failed to generate signed transaction {e}",
}
return data
if tx is None:
data = {
"status": "FAILED",
"reason": "Failed to generate signed transaction",
}
return data
try:
await wallet.push_transaction(tx)
except Exception as e:
data = {
"status": "FAILED",
"reason": f"Failed to push transaction {e}",
}
return data
sent = False
start = time.time()
while time.time() - start < TIMEOUT:
sent_to: List[
Tuple[str, MempoolInclusionStatus, Optional[str]]
] = await self.service.wallet_state_manager.get_transaction_status(
tx.name()
)
if len(sent_to) == 0:
await asyncio.sleep(0.1)
continue
status, err = sent_to[0][1], sent_to[0][2]
if status == MempoolInclusionStatus.SUCCESS:
data = {"status": "SUCCESS"}
sent = True
break
elif status == MempoolInclusionStatus.PENDING:
assert err is not None
data = {"status": "PENDING", "reason": err}
sent = True
break
elif status == MempoolInclusionStatus.FAILED:
assert err is not None
data = {"status": "FAILED", "reason": err}
sent = True
break
if not sent:
data = {
"status": "FAILED",
"reason": "Timed out. Transaction may or may not have been sent.",
}
return data
async def get_transactions(self, request):
wallet_id = int(request["wallet_id"])
transactions = await self.service.wallet_state_manager.get_all_transactions(
wallet_id
)
response = {"success": True, "txs": transactions, "wallet_id": wallet_id}
return response
async def farm_block(self, request):
puzzle_hash = bytes.fromhex(request["puzzle_hash"])
request = FarmNewBlockProtocol(puzzle_hash)
msg = OutboundMessage(
NodeType.FULL_NODE, Message("farm_new_block", request), Delivery.BROADCAST,
)
self.service.server.push_message(msg)
return {"success": True}
async def get_wallet_balance(self, request: Dict):
wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id]
balance = await wallet.get_confirmed_balance()
pending_balance = await wallet.get_unconfirmed_balance()
spendable_balance = await wallet.get_spendable_balance()
pending_change = await wallet.get_pending_change_balance()
if wallet.wallet_info.type == WalletType.COLOURED_COIN:
frozen_balance = 0
else:
frozen_balance = await wallet.get_frozen_amount()
response = {
"wallet_id": wallet_id,
"success": True,
"confirmed_wallet_balance": balance,
"unconfirmed_wallet_balance": pending_balance,
"spendable_balance": spendable_balance,
"frozen_balance": frozen_balance,
"pending_change": pending_change,
}
return response
async def get_sync_status(self, request: Dict):
syncing = self.service.wallet_state_manager.sync_mode
return {"success": True, "syncing": syncing}
async def get_height_info(self, request: Dict):
lca = self.service.wallet_state_manager.lca
height = self.service.wallet_state_manager.block_records[lca].height
response = {"success": True, "height": height}
return response
async def create_new_wallet(self, request):
config, wallet_state_manager, main_wallet = self.get_wallet_config()
if request["wallet_type"] == "cc_wallet":
if request["mode"] == "new":
try:
cc_wallet: CCWallet = await CCWallet.create_new_cc(
wallet_state_manager, main_wallet, request["amount"]
)
return {"success": True, "type": cc_wallet.wallet_info.type.name}
except Exception as e:
log.error("FAILED {e}")
return {"success": False, "reason": str(e)}
elif request["mode"] == "existing":
try:
cc_wallet = await CCWallet.create_wallet_for_cc(
wallet_state_manager, main_wallet, request["colour"]
)
return {"success": True, "type": cc_wallet.wallet_info.type.name}
except Exception as e:
log.error("FAILED2 {e}")
return {"success": False, "reason": str(e)}
def get_wallet_config(self):
return (
self.service.config,
self.service.wallet_state_manager,
self.service.wallet_state_manager.main_wallet,
)
async def get_wallets(self, request: Dict):
wallets: List[
WalletInfo
] = await self.service.wallet_state_manager.get_all_wallets()
response = {"wallets": wallets, "success": True}
return response
async def rl_set_admin_info(self, request):
wallet_id = int(request["wallet_id"])
wallet: RLWallet = self.service.wallet_state_manager.wallets[wallet_id]
user_pubkey = request["user_pubkey"]
limit = uint64(int(request["limit"]))
interval = uint64(int(request["interval"]))
amount = uint64(int(request["amount"]))
success = await wallet.admin_create_coin(interval, limit, user_pubkey, amount)
response = {"success": success}
return response
async def rl_set_user_info(self, request):
wallet_id = int(request["wallet_id"])
wallet: RLWallet = self.service.wallet_state_manager.wallets[wallet_id]
admin_pubkey = request["admin_pubkey"]
limit = uint64(int(request["limit"]))
interval = uint64(int(request["interval"]))
origin_id = request["origin_id"]
success = await wallet.set_user_info(interval, limit, origin_id, admin_pubkey)
response = {"success": success}
return response
async def cc_set_name(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id]
await wallet.set_name(str(request["name"]))
response = {"wallet_id": wallet_id, "success": True}
return response
async def cc_get_name(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id]
name: str = await wallet.get_name()
response = {"wallet_id": wallet_id, "name": name}
return response
async def cc_spend(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id]
puzzle_hash = hexstr_to_bytes(request["innerpuzhash"])
try:
tx = await wallet.cc_spend(request["amount"], puzzle_hash)
except Exception as e:
data = {
"status": "FAILED",
"reason": f"{e}",
}
return data
if tx is None:
data = {
"status": "FAILED",
"reason": "Failed to generate signed transaction",
}
return data
sent = False
start = time.time()
while time.time() - start < TIMEOUT:
sent_to: List[
Tuple[str, MempoolInclusionStatus, Optional[str]]
] = await self.service.wallet_state_manager.get_transaction_status(
tx.name()
)
if len(sent_to) == 0:
await asyncio.sleep(0.1)
continue
status, err = sent_to[0][1], sent_to[0][2]
if status == MempoolInclusionStatus.SUCCESS:
data = {"status": "SUCCESS"}
sent = True
break
elif status == MempoolInclusionStatus.PENDING:
assert err is not None
data = {"status": "PENDING", "reason": err}
sent = True
break
elif status == MempoolInclusionStatus.FAILED:
assert err is not None
data = {"status": "FAILED", "reason": err}
sent = True
break
if not sent:
data = {
"status": "FAILED",
"reason": "Timed out. Transaction may or may not have been sent.",
}
return data
async def cc_get_colour(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.service.wallet_state_manager.wallets[wallet_id]
colour: str = await wallet.get_colour()
response = {"colour": colour, "wallet_id": wallet_id}
return response
async def get_wallet_summaries(self, request: Dict):
response = {}
for wallet_id in self.service.wallet_state_manager.wallets:
wallet = self.service.wallet_state_manager.wallets[wallet_id]
balance = await wallet.get_confirmed_balance()
type = wallet.wallet_info.type
if type == WalletType.COLOURED_COIN:
name = wallet.cc_info.my_colour_name
colour = await wallet.get_colour()
response[wallet_id] = {
"type": type,
"balance": balance,
"name": name,
"colour": colour,
}
else:
response[wallet_id] = {"type": type, "balance": balance}
return response
async def get_discrepancies_for_offer(self, request):
file_name = request["filename"]
file_path = Path(file_name)
(
success,
discrepancies,
error,
) = await self.service.trade_manager.get_discrepancies_for_offer(file_path)
if success:
response = {"success": True, "discrepancies": discrepancies}
else:
response = {"success": False, "error": error}
return response
async def create_offer_for_ids(self, request):
offer = request["ids"]
file_name = request["filename"]
(
success,
spend_bundle,
error,
) = await self.service.trade_manager.create_offer_for_ids(offer)
if success:
self.service.trade_manager.write_offer_to_disk(
Path(file_name), spend_bundle
)
response = {"success": success}
else:
response = {"success": success, "reason": error}
return response
async def respond_to_offer(self, request):
file_path = Path(request["filename"])
success, reason = await self.service.trade_manager.respond_to_offer(file_path)
if success:
response = {"success": success}
else:
response = {"success": success, "reason": reason}
return response
async def get_public_keys(self, request: Dict):
fingerprints = [
(esk.get_public_key().get_fingerprint(), seed is not None)
for (esk, seed) in self.service.keychain.get_all_private_keys()
]
response = {"success": True, "public_key_fingerprints": fingerprints}
return response
async def get_private_key(self, request):
fingerprint = request["fingerprint"]
for esk, seed in self.service.keychain.get_all_private_keys():
if esk.get_public_key().get_fingerprint() == fingerprint:
s = bytes_to_mnemonic(seed) if seed is not None else None
return {
"success": True,
"private_key": {
"fingerprint": fingerprint,
"esk": bytes(esk).hex(),
"seed": s,
},
}
return {"success": False, "private_key": {"fingerprint": fingerprint}}
async def log_in(self, request):
await self.stop_wallet()
fingerprint = request["fingerprint"]
await self.service._start(fingerprint)
return {"success": True}
async def add_key(self, request):
if "mnemonic" in request:
# Adding a key from 24 word mnemonic
mnemonic = request["mnemonic"]
seed = seed_from_mnemonic(mnemonic)
self.service.keychain.add_private_key_seed(seed)
esk = ExtendedPrivateKey.from_seed(seed)
elif "hexkey" in request:
# Adding a key from hex private key string. Two cases: extended private key (HD)
# which is 77 bytes, and int private key which is 32 bytes.
if len(request["hexkey"]) != 154 and len(request["hexkey"]) != 64:
return {"success": False}
if len(request["hexkey"]) == 64:
sk = PrivateKey.from_bytes(bytes.fromhex(request["hexkey"]))
self.service.keychain.add_private_key_not_extended(sk)
key_bytes = bytes(sk)
new_extended_bytes = bytearray(
bytes(ExtendedPrivateKey.from_seed(token_bytes(32)))
)
final_extended_bytes = bytes(
new_extended_bytes[: -len(key_bytes)] + key_bytes
)
esk = ExtendedPrivateKey.from_bytes(final_extended_bytes)
else:
esk = ExtendedPrivateKey.from_bytes(bytes.fromhex(request["hexkey"]))
self.service.keychain.add_private_key(esk)
else:
return {"success": False}
fingerprint = esk.get_public_key().get_fingerprint()
await self.stop_wallet()
# Makes sure the new key is added to config properly
check_keys(self.service.root_path)
# Starts the wallet with the new key selected
await self.service._start(fingerprint)
return {"success": True}
async def delete_key(self, request):
await self.stop_wallet()
fingerprint = request["fingerprint"]
self.service.keychain.delete_key_by_fingerprint(fingerprint)
return {"success": True}
async def clean_all_state(self):
self.service.keychain.delete_all_keys()
path = path_from_root(
self.service.root_path, self.service.config["database_path"]
)
if path.exists():
path.unlink()
async def stop_wallet(self):
if self.service is not None:
self.service._close()
await self.service._await_closed()
async def delete_all_keys(self, request: Dict):
await self.stop_wallet()
await self.clean_all_state()
response = {"success": True}
return response
async def generate_mnemonic(self, request: Dict):
mnemonic = generate_mnemonic()
response = {"success": True, "mnemonic": mnemonic}
return response

View File

@ -39,7 +39,10 @@ class Connection:
self.reader = sr
self.writer = sw
socket = self.writer.get_extra_info("socket")
self.local_host = socket.getsockname()[0]
if socket is not None:
self.local_host = socket.getsockname()[0]
else:
self.local_host = "localhost"
self.local_port = server_port
self.peer_host = self.writer.get_extra_info("peername")[0]
self.peer_port = self.writer.get_extra_info("peername")[1]

View File

@ -107,9 +107,20 @@ async def initialize_pipeline(
map_aiter(expand_outbound_messages, responses_aiter, 100)
)
async def send():
try:
await connection.send(message)
except Exception as e:
connection.log.warning(
f"Cannot write to {connection}, already closed. Error {e}."
)
global_connections.close(connection, True)
# This will run forever. Sends each message through the TCP connection, using the
# length encoding and CBOR serialization
async for connection, message in expanded_messages_aiter:
if connection is None:
continue
if message is None:
# Does not ban the peer, this is just a graceful close of connection.
global_connections.close(connection, True)
@ -122,13 +133,7 @@ async def initialize_pipeline(
connection.log.info(
f"-> {message.function} to peer {connection.get_peername()}"
)
try:
await connection.send(message)
except (RuntimeError, TimeoutError, OSError,) as e:
connection.log.warning(
f"Cannot write to {connection}, already closed. Error {e}."
)
global_connections.close(connection, True)
asyncio.create_task(send())
async def stream_reader_writer_to_connection(

View File

@ -1,7 +1,7 @@
import asyncio
def start_reconnect_task(server, peer_info, log):
def start_reconnect_task(server, peer_info, log, auth):
"""
Start a background task that checks connection and reconnects periodically to a peer.
"""
@ -16,8 +16,7 @@ def start_reconnect_task(server, peer_info, log):
if peer_retry:
log.info(f"Reconnecting to peer {peer_info}")
if not await server.start_client(peer_info, None, auth=True):
await asyncio.sleep(1)
await asyncio.sleep(1)
await server.start_client(peer_info, None, auth=auth)
await asyncio.sleep(3)
return asyncio.create_task(connection_check())

View File

@ -80,7 +80,11 @@ class ChiaServer:
self._outbound_aiter: push_aiter = push_aiter()
# Taks list to keep references to tasks, so they don'y get GCd
self._tasks: List[asyncio.Task] = [self._initialize_ping_task()]
self._tasks: List[asyncio.Task] = []
if local_type != NodeType.INTRODUCER:
# Introducers should not keep connections alive, they should close them
self._tasks.append(self._initialize_ping_task())
if name:
self.log = logging.getLogger(name)
else:
@ -89,8 +93,8 @@ class ChiaServer:
# Our unique random node id that we will send to other peers, regenerated on launch
node_id = create_node_id()
if hasattr(api, "set_global_connections"):
api.set_global_connections(self.global_connections)
if hasattr(api, "_set_global_connections"):
api._set_global_connections(self.global_connections)
# Tasks for entire server pipeline
self._pipeline_task: asyncio.Future = asyncio.ensure_future(

View File

@ -4,12 +4,12 @@ from src.types.peer_info import PeerInfo
from src.util.keychain import Keychain
from src.util.config import load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.rpc.farmer_rpc_server import start_farmer_rpc_server
from src.rpc.farmer_rpc_api import FarmerRpcApi
from src.server.start_service import run_service
# See: https://bugs.python.org/issue29288
u''.encode('idna')
u"".encode("idna")
def service_kwargs_for_farmer(root_path):
@ -33,8 +33,9 @@ def service_kwargs_for_farmer(root_path):
service_name=service_name,
server_listen_ports=[config["port"]],
connect_peers=connect_peers,
auth_connect_peers=False,
on_connect_callback=api._on_connect,
rpc_start_callback_port=(start_farmer_rpc_server, config["rpc_port"]),
rpc_info=(FarmerRpcApi, config["rpc_port"]),
)
return kwargs

View File

@ -1,8 +1,7 @@
import logging
from multiprocessing import freeze_support
from src.full_node.full_node import FullNode
from src.rpc.full_node_rpc_server import start_full_node_rpc_server
from src.rpc.full_node_rpc_api import FullNodeRpcApi
from src.server.outbound_message import NodeType
from src.server.start_service import run_service
from src.util.config import load_config_cli
@ -12,9 +11,7 @@ from src.server.upnp import upnp_remap_port
from src.types.peer_info import PeerInfo
# See: https://bugs.python.org/issue29288
u''.encode('idna')
log = logging.getLogger(__name__)
u"".encode("idna")
def service_kwargs_for_full_node(root_path):
@ -27,7 +24,7 @@ def service_kwargs_for_full_node(root_path):
peer_info = PeerInfo(introducer["host"], introducer["port"])
async def start_callback():
await api.start()
await api._start()
if config["enable_upnp"]:
upnp_remap_port(config["port"])
@ -48,7 +45,7 @@ def service_kwargs_for_full_node(root_path):
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
rpc_start_callback_port=(start_full_node_rpc_server, config["rpc_port"]),
rpc_info=(FullNodeRpcApi, config["rpc_port"]),
periodic_introducer_poll=(
peer_info,
config["introducer_connect_interval"],

View File

@ -3,12 +3,12 @@ from src.server.outbound_message import NodeType
from src.types.peer_info import PeerInfo
from src.util.config import load_config, load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.rpc.harvester_rpc_server import start_harvester_rpc_server
from src.rpc.harvester_rpc_api import HarvesterRpcApi
from src.server.start_service import run_service
# See: https://bugs.python.org/issue29288
u''.encode('idna')
u"".encode("idna")
def service_kwargs_for_harvester(root_path=DEFAULT_ROOT_PATH):
@ -26,6 +26,15 @@ def service_kwargs_for_harvester(root_path=DEFAULT_ROOT_PATH):
api = Harvester(config, plot_config, root_path)
async def start_callback():
await api._start()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
kwargs = dict(
root_path=root_path,
api=api,
@ -34,7 +43,11 @@ def service_kwargs_for_harvester(root_path=DEFAULT_ROOT_PATH):
service_name=service_name,
server_listen_ports=[config["port"]],
connect_peers=connect_peers,
rpc_start_callback_port=(start_harvester_rpc_server, config["rpc_port"]),
auth_connect_peers=True,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
rpc_info=(HarvesterRpcApi, config["rpc_port"]),
)
return kwargs

View File

@ -6,7 +6,7 @@ from src.util.default_root import DEFAULT_ROOT_PATH
from src.server.start_service import run_service
# See: https://bugs.python.org/issue29288
u''.encode('idna')
u"".encode("idna")
def service_kwargs_for_introducer(root_path=DEFAULT_ROOT_PATH):
@ -16,6 +16,15 @@ def service_kwargs_for_introducer(root_path=DEFAULT_ROOT_PATH):
config["max_peers_to_send"], config["recent_peer_threshold"]
)
async def start_callback():
await introducer._start()
def stop_callback():
introducer._close()
async def await_closed_callback():
await introducer._await_closed()
kwargs = dict(
root_path=root_path,
api=introducer,
@ -23,6 +32,9 @@ def service_kwargs_for_introducer(root_path=DEFAULT_ROOT_PATH):
advertised_port=config["port"],
service_name=service_name,
server_listen_ports=[config["port"]],
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
)
return kwargs

View File

@ -15,8 +15,10 @@ from src.server.outbound_message import Delivery, Message, NodeType, OutboundMes
from src.server.server import ChiaServer, start_server
from src.types.peer_info import PeerInfo
from src.util.logging import initialize_logging
from src.util.config import load_config_cli, load_config
from src.util.config import load_config
from src.util.setproctitle import setproctitle
from src.rpc.rpc_server import start_rpc_server
from src.server.connection import OnConnectFunc
from .reconnect_task import start_reconnect_task
@ -70,8 +72,9 @@ class Service:
service_name: str,
server_listen_ports: List[int] = [],
connect_peers: List[PeerInfo] = [],
on_connect_callback: Optional[OutboundMessage] = None,
rpc_start_callback_port: Optional[Tuple[Callable, int]] = None,
auth_connect_peers: bool = True,
on_connect_callback: Optional[OnConnectFunc] = None,
rpc_info: Optional[Tuple[type, int]] = None,
start_callback: Optional[Callable] = None,
stop_callback: Optional[Callable] = None,
await_closed_callback: Optional[Callable] = None,
@ -88,14 +91,13 @@ class Service:
proctitle_name = f"chia_{service_name}"
setproctitle(proctitle_name)
self._log = logging.getLogger(service_name)
config = load_config(root_path, "config.yaml", service_name)
initialize_logging(service_name, config["logging"], root_path)
config = load_config_cli(root_path, "config.yaml", service_name)
initialize_logging(f"{service_name:<30s}", config["logging"], root_path)
self._rpc_start_callback_port = rpc_start_callback_port
self._rpc_info = rpc_info
self._server = ChiaServer(
config["port"],
advertised_port,
api,
node_type,
ping_interval,
@ -109,6 +111,7 @@ class Service:
f(self._server)
self._connect_peers = connect_peers
self._auth_connect_peers = auth_connect_peers
self._server_listen_ports = server_listen_ports
self._api = api
@ -120,6 +123,7 @@ class Service:
self._start_callback = start_callback
self._stop_callback = stop_callback
self._await_closed_callback = await_closed_callback
self._advertised_port = advertised_port
def start(self):
if self._task is not None:
@ -136,7 +140,6 @@ class Service:
introducer_connect_interval,
target_peer_count,
) = self._periodic_introducer_poll
self._introducer_poll_task = create_periodic_introducer_poll_task(
self._server,
peer_info,
@ -146,14 +149,16 @@ class Service:
)
self._rpc_task = None
if self._rpc_start_callback_port:
rpc_f, rpc_port = self._rpc_start_callback_port
self._rpc_task = asyncio.ensure_future(
rpc_f(self._api, self.stop, rpc_port)
if self._rpc_info:
rpc_api, rpc_port = self._rpc_info
self._rpc_task = asyncio.create_task(
start_rpc_server(rpc_api(self._api), rpc_port, self.stop)
)
self._reconnect_tasks = [
start_reconnect_task(self._server, _, self._log)
start_reconnect_task(
self._server, _, self._log, self._auth_connect_peers
)
for _ in self._connect_peers
]
self._server_sockets = [
@ -171,10 +176,12 @@ class Service:
await _.wait_closed()
await self._server.await_closed()
if self._stop_callback:
self._stop_callback()
if self._await_closed_callback:
await self._await_closed_callback()
self._task = asyncio.ensure_future(_run())
self._task = asyncio.create_task(_run())
async def run(self):
self.start()
@ -193,15 +200,13 @@ class Service:
self._api._shut_down = True
if self._introducer_poll_task:
self._introducer_poll_task.cancel()
if self._stop_callback:
self._stop_callback()
async def wait_closed(self):
await self._task
if self._rpc_task:
await self._rpc_task
await (await self._rpc_task)()
self._log.info("Closed RPC server.")
self._log.info("%s fully closed", self._node_type)
self._log.info(f"Service at port {self._advertised_port} fully closed")
async def async_run_service(*args, **kwargs):
@ -212,11 +217,4 @@ async def async_run_service(*args, **kwargs):
def run_service(*args, **kwargs):
if uvloop is not None:
uvloop.install()
# TODO: use asyncio.run instead
# for now, we use `run_until_complete` as `asyncio.run` blocks on RPC server not exiting
if 1:
return asyncio.get_event_loop().run_until_complete(
async_run_service(*args, **kwargs)
)
else:
return asyncio.run(async_run_service(*args, **kwargs))
return asyncio.run(async_run_service(*args, **kwargs))

View File

@ -1,119 +1,53 @@
import asyncio
import signal
import logging
from src.timelord import Timelord
from src.server.outbound_message import NodeType
from src.types.peer_info import PeerInfo
from src.util.config import load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.consensus.constants import constants
from src.server.start_service import run_service
# See: https://bugs.python.org/issue29288
u''.encode('idna')
try:
import uvloop
except ImportError:
uvloop = None
from src.server.outbound_message import NodeType
from src.server.server import ChiaServer
from src.timelord import Timelord
from src.types.peer_info import PeerInfo
from src.util.config import load_config_cli, load_config
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.logging import initialize_logging
from src.util.setproctitle import setproctitle
u"".encode("idna")
def start_timelord_bg_task(server, peer_info, log):
"""
Start a background task that checks connection and reconnects periodically to the full_node.
"""
def service_kwargs_for_timelord(root_path):
service_name = "timelord"
config = load_config_cli(root_path, "config.yaml", service_name)
async def connection_check():
while True:
if server is not None:
full_node_retry = True
connect_peers = [
PeerInfo(config["full_node_peer"]["host"], config["full_node_peer"]["port"])
]
for connection in server.global_connections.get_connections():
if connection.get_peer_info() == peer_info:
full_node_retry = False
api = Timelord(config, config)
if full_node_retry:
log.info(f"Reconnecting to full_node {peer_info}")
if not await server.start_client(peer_info, None, auth=False):
await asyncio.sleep(1)
await asyncio.sleep(30)
async def start_callback():
await api._start()
return asyncio.create_task(connection_check())
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
async def async_main():
root_path = DEFAULT_ROOT_PATH
net_config = load_config(root_path, "config.yaml")
config = load_config_cli(root_path, "config.yaml", "timelord")
initialize_logging("Timelord %(name)-23s", config["logging"], root_path)
log = logging.getLogger(__name__)
setproctitle("chia_timelord")
timelord = Timelord(config, constants)
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
config["port"],
timelord,
NodeType.TIMELORD,
ping_interval,
network_id,
DEFAULT_ROOT_PATH,
config,
kwargs = dict(
root_path=root_path,
api=api,
node_type=NodeType.TIMELORD,
advertised_port=config["port"],
service_name=service_name,
server_listen_ports=[config["port"]],
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
connect_peers=connect_peers,
auth_connect_peers=False,
)
timelord.set_server(server)
coro = asyncio.start_server(
timelord._handle_client,
config["vdf_server"]["host"],
config["vdf_server"]["port"],
loop=asyncio.get_running_loop(),
)
def stop_all():
server.close_all()
timelord._shutdown()
try:
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, stop_all)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, stop_all)
except NotImplementedError:
log.info("signal handlers unsupported")
await asyncio.sleep(10) # Allows full node to startup
peer_info = PeerInfo(
config["full_node_peer"]["host"], config["full_node_peer"]["port"]
)
bg_task = start_timelord_bg_task(server, peer_info, log)
vdf_server = asyncio.ensure_future(coro)
sanitizer_mode = config["sanitizer_mode"]
if not sanitizer_mode:
await timelord._manage_discriminant_queue()
else:
await timelord._manage_discriminant_queue_sanitizer()
log.info("Closed discriminant queue.")
log.info("Shutdown timelord.")
await server.await_closed()
vdf_server.cancel()
bg_task.cancel()
log.info("Timelord fully closed.")
return kwargs
def main():
if uvloop is not None:
uvloop.install()
asyncio.run(async_main())
kwargs = service_kwargs_for_timelord(DEFAULT_ROOT_PATH)
return run_service(**kwargs)
if __name__ == "__main__":

View File

@ -1,47 +1,75 @@
import asyncio
import logging
import traceback
from multiprocessing import freeze_support
from src.wallet.wallet_node import WalletNode
from src.rpc.wallet_rpc_api import WalletRpcApi
from src.server.outbound_message import NodeType
from src.server.start_service import run_service
from src.util.config import load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.keychain import Keychain
from src.simulator.simulator_constants import test_constants
from src.types.peer_info import PeerInfo
# See: https://bugs.python.org/issue29288
u''.encode('idna')
try:
import uvloop
except ImportError:
uvloop = None
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.setproctitle import setproctitle
from src.wallet.websocket_server import WebSocketServer
u"".encode("idna")
log = logging.getLogger(__name__)
async def start_websocket_server():
"""
Starts WalletNode, WebSocketServer, and ChiaServer
"""
setproctitle("chia-wallet")
def service_kwargs_for_wallet(root_path):
service_name = "wallet"
config = load_config_cli(root_path, "config.yaml", service_name)
keychain = Keychain(testing=False)
websocket_server = WebSocketServer(keychain, DEFAULT_ROOT_PATH)
await websocket_server.start()
log.info("Wallet fully closed")
if config["testing"] is True:
config["database_path"] = "test_db_wallet.db"
api = WalletNode(
config, keychain, root_path, override_constants=test_constants,
)
else:
api = WalletNode(config, keychain, root_path)
introducer = config["introducer_peer"]
peer_info = PeerInfo(introducer["host"], introducer["port"])
connect_peers = [
PeerInfo(config["full_node_peer"]["host"], config["full_node_peer"]["port"])
]
async def start_callback():
await api._start()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
kwargs = dict(
root_path=root_path,
api=api,
node_type=NodeType.WALLET,
advertised_port=config["port"],
service_name=service_name,
server_listen_ports=[config["port"]],
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
rpc_info=(WalletRpcApi, config["rpc_port"]),
connect_peers=connect_peers,
auth_connect_peers=False,
periodic_introducer_poll=(
peer_info,
config["introducer_connect_interval"],
config["target_peer_count"],
),
)
return kwargs
def main():
if uvloop is not None:
uvloop.install()
asyncio.run(start_websocket_server())
kwargs = service_kwargs_for_wallet(DEFAULT_ROOT_PATH)
return run_service(**kwargs)
if __name__ == "__main__":
try:
main()
except Exception:
tb = traceback.format_exc()
log.error(f"Error in wallet. {tb}")
raise
freeze_support()
main()

View File

@ -1,100 +1,70 @@
import asyncio
import logging
import logging.config
import signal
from multiprocessing import freeze_support
from src.rpc.full_node_rpc_api import FullNodeRpcApi
from src.server.outbound_message import NodeType
from src.server.start_service import run_service
from src.util.config import load_config_cli
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.path import mkdir, path_from_root
from src.simulator.full_node_simulator import FullNodeSimulator
from src.simulator.simulator_constants import test_constants
try:
import uvloop
except ImportError:
uvloop = None
from src.types.peer_info import PeerInfo
from src.rpc.full_node_rpc_server import start_full_node_rpc_server
from src.server.server import ChiaServer, start_server
from src.server.connection import NodeType
from src.util.logging import initialize_logging
from src.util.config import load_config_cli, load_config
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.setproctitle import setproctitle
from src.util.path import mkdir, path_from_root
# See: https://bugs.python.org/issue29288
u"".encode("idna")
async def main():
root_path = DEFAULT_ROOT_PATH
net_config = load_config(root_path, "config.yaml")
config = load_config_cli(root_path, "config.yaml", "full_node")
setproctitle("chia_full_node_simulator")
initialize_logging("FullNode %(name)-23s", config["logging"], root_path)
log = logging.getLogger(__name__)
server_closed = False
def service_kwargs_for_full_node(root_path):
service_name = "full_node_simulator"
config = load_config_cli(root_path, "config.yaml", service_name)
db_path = path_from_root(root_path, config["simulator_database_path"])
mkdir(db_path.parent)
config["database_path"] = config["simulator_database_path"]
full_node = await FullNodeSimulator.create(
config, root_path=root_path, override_constants=test_constants,
api = FullNodeSimulator(
config, root_path=root_path, override_constants=test_constants
)
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
introducer = config["introducer_peer"]
peer_info = PeerInfo(introducer["host"], introducer["port"])
# Starts the full node server (which full nodes can connect to)
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
config["port"],
full_node,
NodeType.FULL_NODE,
ping_interval,
network_id,
DEFAULT_ROOT_PATH,
config,
async def start_callback():
await api._start()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
kwargs = dict(
root_path=root_path,
api=api,
node_type=NodeType.FULL_NODE,
advertised_port=config["port"],
service_name=service_name,
server_listen_ports=[config["port"]],
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
rpc_info=(FullNodeRpcApi, config["rpc_port"]),
periodic_introducer_poll=(
peer_info,
config["introducer_connect_interval"],
config["target_peer_count"],
),
)
full_node._set_server(server)
server_socket = await start_server(server, full_node._on_connect)
rpc_cleanup = None
def stop_all():
nonlocal server_closed
if not server_closed:
# Called by the UI, when node is closed, or when a signal is sent
log.info("Closing all connections, and server...")
server.close_all()
server_socket.close()
server_closed = True
# Starts the RPC server
rpc_cleanup = await start_full_node_rpc_server(
full_node, stop_all, config["rpc_port"]
)
try:
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, stop_all)
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, stop_all)
except NotImplementedError:
log.info("signal handlers unsupported")
# Awaits for server and all connections to close
await server_socket.wait_closed()
await server.await_closed()
log.info("Closed all node servers.")
# Stops the full node and closes DBs
await full_node._await_closed()
# Waits for the rpc server to close
if rpc_cleanup is not None:
await rpc_cleanup()
log.info("Closed RPC server.")
await asyncio.get_running_loop().shutdown_asyncgens()
log.info("Node fully closed.")
return kwargs
if uvloop is not None:
uvloop.install()
asyncio.run(main())
def main():
kwargs = service_kwargs_for_full_node(DEFAULT_ROOT_PATH)
return run_service(**kwargs)
if __name__ == "__main__":
freeze_support()
main()

View File

@ -2,11 +2,11 @@ import asyncio
import io
import logging
import time
from asyncio import Lock, StreamReader, StreamWriter
from typing import Dict, List, Optional, Tuple
from chiavdf import create_discriminant
from src.consensus.constants import constants as consensus_constants
from src.protocols import timelord_protocol
from src.server.outbound_message import Delivery, Message, NodeType, OutboundMessage
from src.server.server import ChiaServer
@ -20,8 +20,11 @@ log = logging.getLogger(__name__)
class Timelord:
def __init__(self, config: Dict, constants: Dict):
self.constants = constants
def __init__(self, config: Dict, override_constants: Dict = {}):
self.constants = consensus_constants.copy()
for key, value in override_constants.items():
self.constants[key] = value
self.config: Dict = config
self.ips_estimate = {
k: v
@ -32,8 +35,10 @@ class Timelord:
)
)
}
self.lock: Lock = Lock()
self.active_discriminants: Dict[bytes32, Tuple[StreamWriter, uint64, str]] = {}
self.lock: asyncio.Lock = asyncio.Lock()
self.active_discriminants: Dict[
bytes32, Tuple[asyncio.StreamWriter, uint64, str]
] = {}
self.best_weight_three_proofs: int = -1
self.active_discriminants_start_time: Dict = {}
self.pending_iters: Dict = {}
@ -46,18 +51,22 @@ class Timelord:
self.discriminant_queue: List[Tuple[bytes32, uint128]] = []
self.max_connection_time = self.config["max_connection_time"]
self.potential_free_clients: List = []
self.free_clients: List[Tuple[str, StreamReader, StreamWriter]] = []
self.free_clients: List[
Tuple[str, asyncio.StreamReader, asyncio.StreamWriter]
] = []
self.server: Optional[ChiaServer] = None
self.vdf_server = None
self._is_shutdown = False
self.sanitizer_mode = self.config["sanitizer_mode"]
log.info(f"Am I sanitizing? {self.sanitizer_mode}")
self.last_time_seen_discriminant: Dict = {}
self.max_known_weights: List[uint128] = []
def set_server(self, server: ChiaServer):
def _set_server(self, server: ChiaServer):
self.server = server
async def _handle_client(self, reader: StreamReader, writer: StreamWriter):
async def _handle_client(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter
):
async with self.lock:
client_ip = writer.get_extra_info("peername")[0]
log.info(f"New timelord connection from client: {client_ip}.")
@ -69,13 +78,35 @@ class Timelord:
self.potential_free_clients.remove((ip, end_time))
break
def _shutdown(self):
async def _start(self):
if self.sanitizer_mode:
log.info("Starting timelord in sanitizer mode")
self.disc_queue = asyncio.create_task(
self._manage_discriminant_queue_sanitizer()
)
else:
log.info("Starting timelord in normal mode")
self.disc_queue = asyncio.create_task(self._manage_discriminant_queue())
self.vdf_server = await asyncio.start_server(
self._handle_client,
self.config["vdf_server"]["host"],
self.config["vdf_server"]["port"],
)
def _close(self):
self._is_shutdown = True
assert self.vdf_server is not None
self.vdf_server.close()
async def _await_closed(self):
assert self.disc_queue is not None
await self.disc_queue
async def _stop_worst_process(self, worst_weight_active):
# This is already inside a lock, no need to lock again.
log.info(f"Stopping one process at weight {worst_weight_active}")
stop_writer: Optional[StreamWriter] = None
stop_writer: Optional[asyncio.StreamWriter] = None
stop_discriminant: Optional[bytes32] = None
low_weights = {
@ -289,8 +320,8 @@ class Timelord:
msg = ""
try:
msg = data.decode()
except Exception:
pass
except Exception as e:
log.error(f"Exception while decoding data {e}")
if msg == "STOP":
log.info(f"Stopped client running on ip {ip}.")
@ -489,22 +520,16 @@ class Timelord:
with_iters = [
(d, w)
for d, w in self.discriminant_queue
if d in self.pending_iters
and len(self.pending_iters[d]) != 0
if d in self.pending_iters and len(self.pending_iters[d]) != 0
]
if (
len(with_iters) > 0
and len(self.free_clients) > 0
):
if len(with_iters) > 0 and len(self.free_clients) > 0:
disc, weight = with_iters[0]
log.info(f"Creating compact weso proof: weight {weight}.")
ip, sr, sw = self.free_clients[0]
self.free_clients = self.free_clients[1:]
self.discriminant_queue.remove((disc, weight))
asyncio.create_task(
self._do_process_communication(
disc, weight, ip, sr, sw
)
self._do_process_communication(disc, weight, ip, sr, sw)
)
if len(self.proofs_to_write) > 0:
for msg in self.proofs_to_write:
@ -526,7 +551,9 @@ class Timelord:
)
return
if challenge_start.weight <= self.best_weight_three_proofs:
log.info("Not starting challenge, already three proofs at that weight")
log.info(
"Not starting challenge, already three proofs at that weight"
)
return
self.seen_discriminants.append(challenge_start.challenge_hash)
self.discriminant_queue.append(
@ -575,7 +602,9 @@ class Timelord:
if proof_of_space_info.challenge_hash in disc_dict:
challenge_weight = disc_dict[proof_of_space_info.challenge_hash]
if challenge_weight >= min(self.max_known_weights):
log.info("Not storing iter, waiting for more block confirmations.")
log.info(
"Not storing iter, waiting for more block confirmations."
)
return
else:
log.info("Not storing iter, challenge inactive.")

View File

@ -5,14 +5,13 @@ import pathlib
import pkg_resources
from src.util.logging import initialize_logging
from src.util.config import load_config
from asyncio import Lock
from typing import List
from src.util.default_root import DEFAULT_ROOT_PATH
from src.util.setproctitle import setproctitle
active_processes: List = []
stopped = False
lock = Lock()
lock = asyncio.Lock()
log = logging.getLogger(__name__)
@ -23,7 +22,10 @@ async def kill_processes():
async with lock:
stopped = True
for process in active_processes:
process.kill()
try:
process.kill()
except ProcessLookupError:
pass
def find_vdf_client():
@ -76,7 +78,7 @@ def main():
root_path = DEFAULT_ROOT_PATH
setproctitle("chia_timelord_launcher")
config = load_config(root_path, "config.yaml", "timelord_launcher")
initialize_logging("Launcher %(name)-23s", config["logging"], root_path)
initialize_logging("TLauncher", config["logging"], root_path)
def signal_received():
asyncio.create_task(kill_processes())

View File

@ -36,6 +36,14 @@ class ClassGroup(tuple):
super(ClassGroup, self).__init__()
self._discriminant = None
def __eq__(self, obj):
return (
isinstance(obj, ClassGroup)
and obj[0] == self[0]
and obj[1] == self[1]
and obj[2] == self[2]
)
def identity(self):
return self.identity_for_discriminant(self.discriminant())

View File

@ -35,9 +35,9 @@ def config_path_for_filename(root_path: Path, filename: Union[str, Path]) -> Pat
def save_config(root_path: Path, filename: Union[str, Path], config_data: Any):
path = config_path_for_filename(root_path, filename)
with open(path.with_suffix('.' + str(os.getpid())), "w") as f:
with open(path.with_suffix("." + str(os.getpid())), "w") as f:
yaml.safe_dump(config_data, f)
shutil.move(path.with_suffix('.' + str(os.getpid())), path)
shutil.move(path.with_suffix("." + str(os.getpid())), path)
def load_config(

View File

@ -193,6 +193,7 @@ wallet:
# If we are restoring from private key and don't know the height.
starting_height: 0
num_sync_batches: 50
initial_num_public_keys: 100
full_node_peer:
host: 127.0.0.1

View File

@ -264,7 +264,7 @@ class Keychain:
keyring.delete_password(
self._get_service(), self._get_private_key_seed_user(index)
)
except BaseException:
except Exception:
delete_exception = True
# Stop when there are no more keys to delete
@ -283,7 +283,7 @@ class Keychain:
keyring.delete_password(
self._get_service(), self._get_private_key_user(index)
)
except BaseException:
except Exception:
delete_exception = True
# Stop when there are no more keys to delete

View File

@ -5,19 +5,21 @@ from pathlib import Path
from typing import Dict
from src.util.path import mkdir, path_from_root
from logging.handlers import RotatingFileHandler
from concurrent_log_handler import ConcurrentRotatingFileHandler
def initialize_logging(prefix: str, logging_config: Dict, root_path: Path):
def initialize_logging(service_name: str, logging_config: Dict, root_path: Path):
log_path = path_from_root(
root_path, logging_config.get("log_filename", "log/debug.log")
)
mkdir(str(log_path.parent))
file_name_length = 33 - len(service_name)
if logging_config["log_stdout"]:
handler = colorlog.StreamHandler()
handler.setFormatter(
colorlog.ColoredFormatter(
f"{prefix}: %(log_color)s%(levelname)-8s%(reset)s %(asctime)s.%(msecs)03d %(message)s",
f"%(asctime)s.%(msecs)03d {service_name} %(name)-{file_name_length}s: "
f"%(log_color)s%(levelname)-8s%(reset)s %(message)s",
datefmt="%H:%M:%S",
reset=True,
)
@ -26,15 +28,16 @@ def initialize_logging(prefix: str, logging_config: Dict, root_path: Path):
logger = colorlog.getLogger()
logger.addHandler(handler)
else:
logging.basicConfig(
filename=log_path,
filemode="a",
format=f"{prefix}: %(levelname)-8s %(asctime)s.%(msecs)03d %(message)s",
datefmt="%H:%M:%S",
)
logger = logging.getLogger()
handler = RotatingFileHandler(log_path, maxBytes=20000000, backupCount=7)
handler = ConcurrentRotatingFileHandler(
log_path, "a", maxBytes=20 * 1024 * 1024, backupCount=7
)
handler.setFormatter(
logging.Formatter(
fmt=f"%(asctime)s.%(msecs)03d {service_name} %(name)-{file_name_length}s: %(levelname)-8s %(message)s",
datefmt="%H:%M:%S",
)
)
logger.addHandler(handler)
if "log_level" in logging_config:

View File

@ -1,7 +1,6 @@
from typing import Optional, List, Dict, Tuple
import clvm
from clvm import EvalError
from clvm.EvalError import EvalError
from clvm.casts import int_from_bytes
from src.types.condition_var_pair import ConditionVarPair
@ -125,7 +124,7 @@ def get_name_puzzle_conditions(
cost_sum += cost_run
if error:
return error, [], uint64(cost_sum)
except clvm.EvalError:
except EvalError:
return Err.INVALID_COIN_SOLUTION, [], uint64(cost_sum)
if conditions_dict is None:
conditions_dict = {}

View File

@ -300,7 +300,7 @@ class TruncatedNode:
p.append(TRUNCATED + self.hash)
class SetError(BaseException):
class SetError(Exception):
pass

View File

@ -4,7 +4,7 @@ SERVICES_FOR_GROUP = {
"harvester": "chia_harvester".split(),
"farmer": "chia_harvester chia_farmer chia_full_node chia-wallet".split(),
"timelord": "chia_timelord chia_timelord_launcher chia_full_node".split(),
"wallet-server": "chia-wallet".split(),
"wallet-server": "chia-wallet chia_full_node".split(),
"introducer": "chia_introducer".split(),
"simulator": "chia_full_node_simulator".split(),
"plotter": "chia-create-plots".split(),

View File

@ -4,9 +4,8 @@ from __future__ import annotations
import dataclasses
import io
import pprint
import json
from enum import Enum
from typing import Any, BinaryIO, List, Type, get_type_hints, Union, Dict
from typing import Any, BinaryIO, List, Type, get_type_hints, Dict
from src.util.byte_types import hexstr_to_bytes
from src.types.program import Program
from src.util.hash import std_hash
@ -23,14 +22,13 @@ from blspy import (
)
from src.types.sized_bytes import bytes32
from src.util.ints import uint32, uint8, uint64, int64, uint128, int512
from src.util.ints import uint32, uint64, int64, uint128, int512
from src.util.type_checking import (
is_type_List,
is_type_Tuple,
is_type_SpecificOptional,
strictdataclass,
)
from src.wallet.util.wallet_types import WalletType
pp = pprint.PrettyPrinter(indent=1, width=120, compact=True)

View File

@ -3,8 +3,6 @@ import time
import clvm
from typing import Dict, Optional, List, Any, Set
from clvm_tools import binutils
from clvm.EvalError import EvalError
from src.types.BLSSignature import BLSSignature
from src.types.coin import Coin
from src.types.coin_solution import CoinSolution
@ -36,7 +34,7 @@ from src.wallet.wallet_coin_record import WalletCoinRecord
from src.wallet.wallet_info import WalletInfo
from src.wallet.derivation_record import DerivationRecord
from src.wallet.cc_wallet import cc_wallet_puzzles
from clvm import run_program
from clvm_tools import binutils
# TODO: write tests based on wallet tests
# TODO: {Matt} compatibility based on deriving innerpuzzle from derivation record
@ -285,9 +283,9 @@ class CCWallet:
"""
cost_sum = 0
try:
cost_run, sexp = run_program(block_program, [])
cost_run, sexp = clvm.run_program(block_program, [])
cost_sum += cost_run
except EvalError:
except clvm.EvalError.EvalError:
return False
for name_solution in sexp.as_iter():
@ -308,7 +306,7 @@ class CCWallet:
cost_sum += cost_run
if error:
return False
except clvm.EvalError:
except clvm.EvalError.EvalError:
return False
if conditions_dict is None:

View File

@ -160,7 +160,7 @@ class Wallet:
self.wallet_info.id
)
)
sum = 0
sum_value = 0
used_coins: Set = set()
# Use older coins first
@ -174,13 +174,13 @@ class Wallet:
self.wallet_info.id
)
for coinrecord in unspent:
if sum >= amount and len(used_coins) > 0:
if sum_value >= amount and len(used_coins) > 0:
break
if coinrecord.coin.name() in unconfirmed_removals:
continue
if coinrecord.coin in exclude:
continue
sum += coinrecord.coin.amount
sum_value += coinrecord.coin.amount
used_coins.add(coinrecord.coin)
self.log.info(
f"Selected coin: {coinrecord.coin.name()} at height {coinrecord.confirmed_block_index}!"
@ -188,36 +188,26 @@ class Wallet:
# This happens when we couldn't use one of the coins because it's already used
# but unconfirmed, and we are waiting for the change. (unconfirmed_additions)
unconfirmed_additions = None
if sum < amount:
if sum_value < amount:
raise ValueError(
"Can't make this transaction at the moment. Waiting for the change from the previous transaction."
)
unconfirmed_additions = await self.wallet_state_manager.unconfirmed_additions_for_wallet(
self.wallet_info.id
)
for coin in unconfirmed_additions.values():
if sum > amount:
break
if coin.name() in unconfirmed_removals:
continue
# TODO(straya): remove this
# unconfirmed_additions = await self.wallet_state_manager.unconfirmed_additions_for_wallet(
# self.wallet_info.id
# )
# for coin in unconfirmed_additions.values():
# if sum_value > amount:
# break
# if coin.name() in unconfirmed_removals:
# continue
sum += coin.amount
used_coins.add(coin)
self.log.info(f"Selected used coin: {coin.name()}")
# sum_value += coin.amount
# used_coins.add(coin)
# self.log.info(f"Selected used coin: {coin.name()}")
if sum >= amount:
self.log.info(f"Successfully selected coins: {used_coins}")
return used_coins
else:
# This shouldn't happen because of: if amount > self.get_unconfirmed_balance_spendable():
self.log.error(
f"Wasn't able to select coins for amount: {amount}"
f"unspent: {unspent}"
f"unconfirmed_removals: {unconfirmed_removals}"
f"unconfirmed_additions: {unconfirmed_additions}"
)
return None
self.log.info(f"Successfully selected coins: {used_coins}")
return used_coins
async def generate_unsigned_transaction(
self,

View File

@ -1,7 +1,7 @@
import asyncio
import json
import time
from typing import Dict, Optional, Tuple, List, AsyncGenerator
from typing import Dict, Optional, Tuple, List, AsyncGenerator, Callable
import concurrent
from pathlib import Path
import random
@ -38,8 +38,8 @@ from src.full_node.blockchain import ReceiveBlockResult
from src.types.mempool_inclusion_status import MempoolInclusionStatus
from src.util.errors import Err
from src.util.path import path_from_root, mkdir
from src.server.reconnect_task import start_reconnect_task
from src.util.keychain import Keychain
from src.wallet.trade_manager import TradeManager
class WalletNode:
@ -76,24 +76,19 @@ class WalletNode:
short_sync_threshold: int
_shut_down: bool
root_path: Path
local_test: bool
state_changed_callback: Optional[Callable]
tasks: List[asyncio.Future]
@staticmethod
async def create(
def __init__(
self,
config: Dict,
private_key: ExtendedPrivateKey,
keychain: Keychain,
root_path: Path,
name: str = None,
override_constants: Dict = {},
local_test: bool = False,
):
self = WalletNode()
self.config = config
self.constants = consensus_constants.copy()
self.root_path = root_path
self.local_test = local_test
for key, value in override_constants.items():
self.constants[key] = value
if name:
@ -101,20 +96,10 @@ class WalletNode:
else:
self.log = logging.getLogger(__name__)
db_path_key_suffix = str(private_key.get_public_key().get_fingerprint())
path = path_from_root(
self.root_path, f"{config['database_path']}-{db_path_key_suffix}"
)
mkdir(path.parent)
self.wallet_state_manager = await WalletStateManager.create(
private_key, config, path, self.constants
)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
# Normal operation data
self.cached_blocks = {}
self.future_block_hashes = {}
self.keychain = keychain
# Sync data
self._shut_down = False
@ -124,12 +109,48 @@ class WalletNode:
self.short_sync_threshold = 15
self.potential_blocks_received = {}
self.potential_header_hashes = {}
self.state_changed_callback = None
self.server = None
self.tasks = []
async def _start(self, public_key_fingerprint: Optional[int] = None):
self._shut_down = False
private_keys = self.keychain.get_all_private_keys()
if len(private_keys) == 0:
raise RuntimeError("No keys")
return self
private_key: Optional[ExtendedPrivateKey] = None
if public_key_fingerprint is not None:
for sk, _ in private_keys:
if sk.get_public_key().get_fingerprint() == public_key_fingerprint:
private_key = sk
break
else:
private_key = private_keys[0][0]
if private_key is None:
raise RuntimeError("Invalid fingerprint {public_key_fingerprint}")
db_path_key_suffix = str(private_key.get_public_key().get_fingerprint())
path = path_from_root(
self.root_path, f"{self.config['database_path']}-{db_path_key_suffix}"
)
mkdir(path.parent)
self.wallet_state_manager = await WalletStateManager.create(
private_key, self.config, path, self.constants
)
self.trade_manager = await TradeManager.create(self.wallet_state_manager)
if self.state_changed_callback is not None:
self.wallet_state_manager.set_callback(self.state_changed_callback)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
def _set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
if self.global_connections is not None:
self.global_connections.set_state_changed_callback(callback)
self.wallet_state_manager.set_callback(self.state_changed_callback)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
def _pending_tx_handler(self):
asyncio.ensure_future(self._resend_queue())
@ -188,10 +209,10 @@ class WalletNode:
return messages
def set_global_connections(self, global_connections: PeerConnections):
def _set_global_connections(self, global_connections: PeerConnections):
self.global_connections = global_connections
def set_server(self, server: ChiaServer):
def _set_server(self, server: ChiaServer):
self.server = server
async def _on_connect(self) -> AsyncGenerator[OutboundMessage, None]:
@ -200,52 +221,16 @@ class WalletNode:
for msg in messages:
yield msg
def _shutdown(self):
print("Shutting down")
def _close(self):
self._shut_down = True
for task in self.tasks:
task.cancel()
self.wsm_close_task = asyncio.create_task(
self.wallet_state_manager.close_all_stores()
)
for connection in self.global_connections.get_connections():
connection.close()
def _start_bg_tasks(self):
"""
Start a background task connecting periodically to the introducer and
requesting the peer list.
"""
introducer = self.config["introducer_peer"]
introducer_peerinfo = PeerInfo(introducer["host"], introducer["port"])
async def introducer_client():
async def on_connect() -> OutboundMessageGenerator:
msg = Message("request_peers", introducer_protocol.RequestPeers())
yield OutboundMessage(NodeType.INTRODUCER, msg, Delivery.RESPOND)
while not self._shut_down:
for connection in self.global_connections.get_connections():
# If we are still connected to introducer, disconnect
if connection.connection_type == NodeType.INTRODUCER:
self.global_connections.close(connection)
if self._num_needed_peers():
if not await self.server.start_client(
introducer_peerinfo, on_connect
):
await asyncio.sleep(5)
continue
await asyncio.sleep(5)
if self._num_needed_peers() == self.config["target_peer_count"]:
# Try again if we have 0 peers
continue
await asyncio.sleep(self.config["introducer_connect_interval"])
if "full_node_peer" in self.config:
peer_info = PeerInfo(
self.config["full_node_peer"]["host"],
self.config["full_node_peer"]["port"],
)
task = start_reconnect_task(self.server, peer_info, self.log)
self.tasks.append(task)
if self.local_test is False:
self.tasks.append(asyncio.create_task(introducer_client()))
async def _await_closed(self):
await self.wsm_close_task
def _num_needed_peers(self) -> int:
assert self.server is not None
@ -328,8 +313,8 @@ class WalletNode:
Message("request_all_header_hashes_after", request_header_hashes),
Delivery.RESPOND,
)
timeout = 100
sleep_interval = 10
timeout = 50
sleep_interval = 3
sleep_interval_short = 1
start_wait = time.time()
while time.time() - start_wait < timeout:
@ -602,6 +587,8 @@ class WalletNode:
else:
# Not added to chain yet. Try again soon.
await asyncio.sleep(sleep_interval_short)
if self._shut_down:
return
total_time_slept += sleep_interval_short
if hh in self.wallet_state_manager.block_records:
break
@ -648,7 +635,6 @@ class WalletNode:
self.log.info(
f"Added orphan {block_record.header_hash} at height {block_record.height}"
)
pass
elif res == ReceiveBlockResult.ADDED_TO_HEAD:
self.log.info(
f"Updated LCA to {block_record.header_hash} at height {block_record.height}"
@ -691,7 +677,7 @@ class WalletNode:
f"SpendBundle has been received (and is pending) by the FullNode. {ack}"
)
else:
self.log.info(f"SpendBundle has been rejected by the FullNode. {ack}")
self.log.warning(f"SpendBundle has been rejected by the FullNode. {ack}")
if ack.error is not None:
await self.wallet_state_manager.remove_from_queue(
ack.txid, name, ack.status, Err[ack.error]
@ -761,7 +747,7 @@ class WalletNode:
self.wallet_state_manager.set_sync_mode(True)
async for ret_msg in self._sync():
yield ret_msg
except (BaseException, asyncio.CancelledError) as e:
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Error with syncing. {type(e)} {tb}")
self.wallet_state_manager.set_sync_mode(False)

View File

@ -158,7 +158,6 @@ class WalletPuzzleStore:
"""
Sets a derivation path to used so we don't use it again.
"""
pass
cursor = await self.db_connection.execute(
"UPDATE derivation_paths SET used=1 WHERE derivation_index<=?", (index,),
)

View File

@ -41,6 +41,7 @@ from src.wallet.wallet import Wallet
from src.types.program import Program
from src.wallet.derivation_record import DerivationRecord
from src.wallet.util.wallet_types import WalletType
from src.consensus.find_fork_point import find_fork_point_in_chain
class WalletStateManager:
@ -137,7 +138,7 @@ class WalletStateManager:
async with self.puzzle_store.lock:
index = await self.puzzle_store.get_last_derivation_path()
if index is None or index < 100:
if index is None or index < self.config["initial_num_public_keys"]:
await self.create_more_puzzle_hashes(from_zero=True)
if len(self.block_records) > 0:
@ -213,7 +214,7 @@ class WalletStateManager:
# This handles the case where the database is empty
unused = uint32(0)
to_generate = 100
to_generate = self.config["initial_num_public_keys"]
for wallet_id in targets:
target_wallet = self.wallets[wallet_id]
@ -560,7 +561,6 @@ class WalletStateManager:
assert block.removals is not None
await wallet.coin_added(coin, index, header_hash, block.removals)
self.log.info(f"Doing state changed for wallet id {wallet_id}")
self.state_changed("coin_added", wallet_id)
async def add_pending_transaction(self, tx_record: TransactionRecord):
@ -728,8 +728,8 @@ class WalletStateManager:
# Not genesis, updated LCA
if block.weight > self.block_records[self.lca].weight:
fork_h = self._find_fork_point_in_chain(
self.block_records[self.lca], block
fork_h = find_fork_point_in_chain(
self.block_records, self.block_records[self.lca], block
)
await self.reorg_rollback(fork_h)
@ -997,24 +997,6 @@ class WalletStateManager:
return False
return True
def _find_fork_point_in_chain(
self, block_1: BlockRecord, block_2: BlockRecord
) -> uint32:
""" Tries to find height where new chain (block_2) diverged from block_1 (assuming prev blocks
are all included in chain)"""
while block_2.height > 0 or block_1.height > 0:
if block_2.height > block_1.height:
block_2 = self.block_records[block_2.prev_header_hash]
elif block_1.height > block_2.height:
block_1 = self.block_records[block_1.prev_header_hash]
else:
if block_2.header_hash == block_1.header_hash:
return block_2.height
block_2 = self.block_records[block_2.prev_header_hash]
block_1 = self.block_records[block_1.prev_header_hash]
assert block_2 == block_1 # Genesis block is the same, genesis fork
return uint32(0)
def validate_select_proofs(
self,
all_proof_hashes: List[Tuple[bytes32, Optional[Tuple[uint64, uint64]]]],
@ -1183,8 +1165,8 @@ class WalletStateManager:
tx_filter = PyBIP158([b for b in transactions_filter])
# Find fork point
fork_h: uint32 = self._find_fork_point_in_chain(
self.block_records[self.lca], new_block
fork_h: uint32 = find_fork_point_in_chain(
self.block_records, self.block_records[self.lca], new_block
)
# Get all unspent coins

View File

@ -1,781 +0,0 @@
import asyncio
import json
import logging
import signal
import time
import traceback
from pathlib import Path
from blspy import ExtendedPrivateKey, PrivateKey
from secrets import token_bytes
from typing import List, Optional, Tuple
import aiohttp
from src.util.byte_types import hexstr_to_bytes
from src.util.keychain import (
Keychain,
seed_from_mnemonic,
bytes_to_mnemonic,
generate_mnemonic,
)
from src.util.path import path_from_root
from src.util.ws_message import create_payload, format_response, pong
from src.wallet.trade_manager import TradeManager
try:
import uvloop
except ImportError:
uvloop = None
from src.cmds.init import check_keys
from src.server.outbound_message import NodeType, OutboundMessage, Message, Delivery
from src.server.server import ChiaServer
from src.simulator.simulator_constants import test_constants
from src.simulator.simulator_protocol import FarmNewBlockProtocol
from src.util.config import load_config_cli, load_config
from src.util.ints import uint64
from src.util.logging import initialize_logging
from src.wallet.util.wallet_types import WalletType
from src.wallet.rl_wallet.rl_wallet import RLWallet
from src.wallet.cc_wallet.cc_wallet import CCWallet
from src.wallet.wallet_info import WalletInfo
from src.wallet.wallet_node import WalletNode
from src.types.mempool_inclusion_status import MempoolInclusionStatus
# Timeout for response from wallet/full node for sending a transaction
TIMEOUT = 30
log = logging.getLogger(__name__)
class WebSocketServer:
def __init__(self, keychain: Keychain, root_path: Path):
self.config = load_config_cli(root_path, "config.yaml", "wallet")
initialize_logging("Wallet %(name)-25s", self.config["logging"], root_path)
self.log = log
self.keychain = keychain
self.websocket = None
self.root_path = root_path
self.wallet_node: Optional[WalletNode] = None
self.trade_manager: Optional[TradeManager] = None
self.shut_down = False
if self.config["testing"] is True:
self.config["database_path"] = "test_db_wallet.db"
async def start(self):
self.log.info("Starting Websocket Server")
def master_close_cb():
asyncio.ensure_future(self.stop())
try:
asyncio.get_running_loop().add_signal_handler(
signal.SIGINT, master_close_cb
)
asyncio.get_running_loop().add_signal_handler(
signal.SIGTERM, master_close_cb
)
except NotImplementedError:
self.log.info("Not implemented")
await self.start_wallet()
await self.connect_to_daemon()
self.log.info("webSocketServer closed")
async def start_wallet(self, public_key_fingerprint: Optional[int] = None) -> bool:
private_keys = self.keychain.get_all_private_keys()
if len(private_keys) == 0:
self.log.info("No keys")
return False
if public_key_fingerprint is not None:
for sk, _ in private_keys:
if sk.get_public_key().get_fingerprint() == public_key_fingerprint:
private_key = sk
break
else:
private_key = private_keys[0][0]
if private_key is None:
self.log.info("No keys")
return False
if self.config["testing"] is True:
log.info("Websocket server in testing mode")
self.wallet_node = await WalletNode.create(
self.config,
private_key,
self.root_path,
override_constants=test_constants,
local_test=True,
)
else:
self.wallet_node = await WalletNode.create(
self.config, private_key, self.root_path
)
if self.wallet_node is None:
return False
self.trade_manager = await TradeManager.create(
self.wallet_node.wallet_state_manager
)
self.wallet_node.wallet_state_manager.set_callback(self.state_changed_callback)
net_config = load_config(self.root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
self.config["port"],
self.wallet_node,
NodeType.WALLET,
ping_interval,
network_id,
self.root_path,
self.config,
)
self.wallet_node.set_server(server)
self.wallet_node._start_bg_tasks()
return True
async def connection(self, ws):
data = {"service": "chia-wallet"}
payload = create_payload("register_service", data, "chia-wallet", "daemon")
await ws.send_str(payload)
while True:
msg = await ws.receive()
if msg.type == aiohttp.WSMsgType.TEXT:
message = msg.data.strip()
# self.log.info(f"received message: {message}")
await self.safe_handle(ws, message)
elif msg.type == aiohttp.WSMsgType.BINARY:
pass
# self.log.warning("Received binary data")
elif msg.type == aiohttp.WSMsgType.PING:
await ws.pong()
elif msg.type == aiohttp.WSMsgType.PONG:
self.log.info("Pong received")
else:
if msg.type == aiohttp.WSMsgType.CLOSE:
print("Closing")
await ws.close()
elif msg.type == aiohttp.WSMsgType.ERROR:
print("Error during receive %s" % ws.exception())
elif msg.type == aiohttp.WSMsgType.CLOSED:
pass
break
await ws.close()
async def connect_to_daemon(self):
while True:
session = None
try:
if self.shut_down:
break
session = aiohttp.ClientSession()
async with session.ws_connect(
"ws://127.0.0.1:55400", autoclose=False, autoping=True
) as ws:
self.websocket = ws
await self.connection(ws)
self.log.info("Connection closed")
self.websocket = None
await session.close()
except BaseException as e:
self.log.error(f"Exception: {e}")
if session is not None:
await session.close()
await asyncio.sleep(1)
async def stop(self):
self.shut_down = True
if self.wallet_node is not None:
self.wallet_node.server.close_all()
self.wallet_node._shutdown()
await self.wallet_node.wallet_state_manager.close_all_stores()
self.log.info("closing websocket")
if self.websocket is not None:
self.log.info("closing websocket 2")
await self.websocket.close()
self.log.info("closied websocket")
async def get_next_puzzle_hash(self, request):
"""
Returns a new puzzlehash
"""
wallet_id = int(request["wallet_id"])
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
if wallet.wallet_info.type == WalletType.STANDARD_WALLET:
puzzle_hash = (await wallet.get_new_puzzlehash()).hex()
elif wallet.wallet_info.type == WalletType.COLOURED_COIN:
puzzle_hash = await wallet.get_new_inner_hash()
response = {
"wallet_id": wallet_id,
"puzzle_hash": puzzle_hash,
}
return response
async def send_transaction(self, request):
wallet_id = int(request["wallet_id"])
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
try:
tx = await wallet.generate_signed_transaction_dict(request)
except BaseException as e:
data = {
"status": "FAILED",
"reason": f"Failed to generate signed transaction {e}",
}
return data
if tx is None:
data = {
"status": "FAILED",
"reason": "Failed to generate signed transaction",
}
return data
try:
await wallet.push_transaction(tx)
except BaseException as e:
data = {
"status": "FAILED",
"reason": f"Failed to push transaction {e}",
}
return data
self.log.error(tx)
sent = False
start = time.time()
while time.time() - start < TIMEOUT:
sent_to: List[
Tuple[str, MempoolInclusionStatus, Optional[str]]
] = await self.wallet_node.wallet_state_manager.get_transaction_status(
tx.name()
)
if len(sent_to) == 0:
await asyncio.sleep(0.1)
continue
status, err = sent_to[0][1], sent_to[0][2]
if status == MempoolInclusionStatus.SUCCESS:
data = {"status": "SUCCESS"}
sent = True
break
elif status == MempoolInclusionStatus.PENDING:
assert err is not None
data = {"status": "PENDING", "reason": err}
sent = True
break
elif status == MempoolInclusionStatus.FAILED:
assert err is not None
data = {"status": "FAILED", "reason": err}
sent = True
break
if not sent:
data = {
"status": "FAILED",
"reason": "Timed out. Transaction may or may not have been sent.",
}
return data
async def get_transactions(self, request):
wallet_id = int(request["wallet_id"])
transactions = await self.wallet_node.wallet_state_manager.get_all_transactions(
wallet_id
)
response = {"success": True, "txs": transactions, "wallet_id": wallet_id}
return response
async def farm_block(self, request):
puzzle_hash = bytes.fromhex(request["puzzle_hash"])
request = FarmNewBlockProtocol(puzzle_hash)
msg = OutboundMessage(
NodeType.FULL_NODE, Message("farm_new_block", request), Delivery.BROADCAST,
)
self.wallet_node.server.push_message(msg)
return {"success": True}
async def get_wallet_balance(self, request):
wallet_id = int(request["wallet_id"])
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
balance = await wallet.get_confirmed_balance()
pending_balance = await wallet.get_unconfirmed_balance()
spendable_balance = await wallet.get_spendable_balance()
pending_change = await wallet.get_pending_change_balance()
if wallet.wallet_info.type == WalletType.COLOURED_COIN:
frozen_balance = 0
else:
frozen_balance = await wallet.get_frozen_amount()
response = {
"wallet_id": wallet_id,
"success": True,
"confirmed_wallet_balance": balance,
"unconfirmed_wallet_balance": pending_balance,
"spendable_balance": spendable_balance,
"frozen_balance": frozen_balance,
"pending_change": pending_change,
}
return response
async def get_sync_status(self):
syncing = self.wallet_node.wallet_state_manager.sync_mode
response = {"syncing": syncing}
return response
async def get_height_info(self):
lca = self.wallet_node.wallet_state_manager.lca
height = self.wallet_node.wallet_state_manager.block_records[lca].height
response = {"height": height}
return response
async def get_connection_info(self):
connections = (
self.wallet_node.server.global_connections.get_full_node_peerinfos()
)
response = {"connections": connections}
return response
async def create_new_wallet(self, request):
config, wallet_state_manager, main_wallet = self.get_wallet_config()
if request["wallet_type"] == "cc_wallet":
if request["mode"] == "new":
cc_wallet: CCWallet = await CCWallet.create_new_cc(
wallet_state_manager, main_wallet, request["amount"]
)
response = {"success": True, "type": cc_wallet.wallet_info.type.name}
return response
elif request["mode"] == "existing":
cc_wallet = await CCWallet.create_wallet_for_cc(
wallet_state_manager, main_wallet, request["colour"]
)
response = {"success": True, "type": cc_wallet.wallet_info.type.name}
return response
response = {"success": False}
return response
def get_wallet_config(self):
return (
self.wallet_node.config,
self.wallet_node.wallet_state_manager,
self.wallet_node.wallet_state_manager.main_wallet,
)
async def get_wallets(self):
wallets: List[
WalletInfo
] = await self.wallet_node.wallet_state_manager.get_all_wallets()
response = {"wallets": wallets, "success": True}
return response
async def rl_set_admin_info(self, request):
wallet_id = int(request["wallet_id"])
wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
user_pubkey = request["user_pubkey"]
limit = uint64(int(request["limit"]))
interval = uint64(int(request["interval"]))
amount = uint64(int(request["amount"]))
success = await wallet.admin_create_coin(interval, limit, user_pubkey, amount)
response = {"success": success}
return response
async def rl_set_user_info(self, request):
wallet_id = int(request["wallet_id"])
wallet: RLWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
admin_pubkey = request["admin_pubkey"]
limit = uint64(int(request["limit"]))
interval = uint64(int(request["interval"]))
origin_id = request["origin_id"]
success = await wallet.set_user_info(interval, limit, origin_id, admin_pubkey)
response = {"success": success}
return response
async def cc_set_name(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
await wallet.set_name(str(request["name"]))
response = {"wallet_id": wallet_id, "success": True}
return response
async def cc_get_name(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
name: str = await wallet.get_name()
response = {"wallet_id": wallet_id, "name": name}
return response
async def cc_spend(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
puzzle_hash = hexstr_to_bytes(request["innerpuzhash"])
try:
tx = await wallet.cc_spend(request["amount"], puzzle_hash)
except BaseException as e:
data = {
"status": "FAILED",
"reason": f"{e}",
}
return data
if tx is None:
data = {
"status": "FAILED",
"reason": "Failed to generate signed transaction",
}
return data
self.log.error(tx)
sent = False
start = time.time()
while time.time() - start < TIMEOUT:
sent_to: List[
Tuple[str, MempoolInclusionStatus, Optional[str]]
] = await self.wallet_node.wallet_state_manager.get_transaction_status(
tx.name()
)
if len(sent_to) == 0:
await asyncio.sleep(0.1)
continue
status, err = sent_to[0][1], sent_to[0][2]
if status == MempoolInclusionStatus.SUCCESS:
data = {"status": "SUCCESS"}
sent = True
break
elif status == MempoolInclusionStatus.PENDING:
assert err is not None
data = {"status": "PENDING", "reason": err}
sent = True
break
elif status == MempoolInclusionStatus.FAILED:
assert err is not None
data = {"status": "FAILED", "reason": err}
sent = True
break
if not sent:
data = {
"status": "FAILED",
"reason": "Timed out. Transaction may or may not have been sent.",
}
return data
async def cc_get_colour(self, request):
wallet_id = int(request["wallet_id"])
wallet: CCWallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
colour: str = await wallet.get_colour()
response = {"colour": colour, "wallet_id": wallet_id}
return response
async def get_wallet_summaries(self):
response = {}
for wallet_id in self.wallet_node.wallet_state_manager.wallets:
wallet = self.wallet_node.wallet_state_manager.wallets[wallet_id]
balance = await wallet.get_confirmed_balance()
type = wallet.wallet_info.type
if type == WalletType.COLOURED_COIN:
name = wallet.cc_info.my_colour_name
colour = await wallet.get_colour()
response[wallet_id] = {
"type": type,
"balance": balance,
"name": name,
"colour": colour,
}
else:
response[wallet_id] = {"type": type, "balance": balance}
return response
async def get_discrepancies_for_offer(self, request):
file_name = request["filename"]
file_path = Path(file_name)
(
success,
discrepancies,
error,
) = await self.trade_manager.get_discrepancies_for_offer(file_path)
if success:
response = {"success": True, "discrepancies": discrepancies}
else:
response = {"success": False, "error": error}
return response
async def create_offer_for_ids(self, request):
offer = request["ids"]
file_name = request["filename"]
success, spend_bundle, error = await self.trade_manager.create_offer_for_ids(
offer
)
if success:
self.trade_manager.write_offer_to_disk(Path(file_name), spend_bundle)
response = {"success": success}
else:
response = {"success": success, "reason": error}
return response
async def respond_to_offer(self, request):
file_path = Path(request["filename"])
success, reason = await self.trade_manager.respond_to_offer(file_path)
if success:
response = {"success": success}
else:
response = {"success": success, "reason": reason}
return response
async def get_public_keys(self):
fingerprints = [
(esk.get_public_key().get_fingerprint(), seed is not None)
for (esk, seed) in self.keychain.get_all_private_keys()
]
response = {"success": True, "public_key_fingerprints": fingerprints}
return response
async def get_private_key(self, request):
fingerprint = request["fingerprint"]
for esk, seed in self.keychain.get_all_private_keys():
if esk.get_public_key().get_fingerprint() == fingerprint:
s = bytes_to_mnemonic(seed) if seed is not None else None
self.log.warning(f"{s}, {esk}")
return {
"success": True,
"private_key": {
"fingerprint": fingerprint,
"esk": bytes(esk).hex(),
"seed": s,
},
}
return {"success": False, "private_key": {"fingerprint": fingerprint}}
async def log_in(self, request):
await self.stop_wallet()
fingerprint = request["fingerprint"]
started = await self.start_wallet(fingerprint)
response = {"success": started}
return response
async def add_key(self, request):
if "mnemonic" in request:
# Adding a key from 24 word mnemonic
mnemonic = request["mnemonic"]
seed = seed_from_mnemonic(mnemonic)
self.keychain.add_private_key_seed(seed)
esk = ExtendedPrivateKey.from_seed(seed)
elif "hexkey" in request:
# Adding a key from hex private key string. Two cases: extended private key (HD)
# which is 77 bytes, and int private key which is 32 bytes.
if len(request["hexkey"]) != 154 and len(request["hexkey"]) != 64:
return {"success": False}
if len(request["hexkey"]) == 64:
sk = PrivateKey.from_bytes(bytes.fromhex(request["hexkey"]))
self.keychain.add_private_key_not_extended(sk)
key_bytes = bytes(sk)
new_extended_bytes = bytearray(
bytes(ExtendedPrivateKey.from_seed(token_bytes(32)))
)
final_extended_bytes = bytes(
new_extended_bytes[: -len(key_bytes)] + key_bytes
)
esk = ExtendedPrivateKey.from_bytes(final_extended_bytes)
else:
esk = ExtendedPrivateKey.from_bytes(bytes.fromhex(request["hexkey"]))
self.keychain.add_private_key(esk)
else:
return {"success": False}
fingerprint = esk.get_public_key().get_fingerprint()
await self.stop_wallet()
# Makes sure the new key is added to config properly
check_keys(self.root_path)
# Starts the wallet with the new key selected
started = await self.start_wallet(fingerprint)
response = {"success": started}
return response
async def delete_key(self, request):
await self.stop_wallet()
fingerprint = request["fingerprint"]
self.log.warning(f"Removing one key {fingerprint}")
self.log.warning(f"{self.keychain.get_all_public_keys()}")
self.keychain.delete_key_by_fingerprint(fingerprint)
self.log.warning(f"{self.keychain.get_all_public_keys()}")
response = {"success": True}
return response
async def clean_all_state(self):
self.keychain.delete_all_keys()
path = path_from_root(self.root_path, self.config["database_path"])
if path.exists():
path.unlink()
async def stop_wallet(self):
if self.wallet_node is not None:
if self.wallet_node.server is not None:
self.wallet_node.server.close_all()
self.wallet_node._shutdown()
await self.wallet_node.wallet_state_manager.close_all_stores()
self.wallet_node = None
async def delete_all_keys(self):
await self.stop_wallet()
await self.clean_all_state()
response = {"success": True}
return response
async def generate_mnemonic(self):
mnemonic = generate_mnemonic()
response = {"success": True, "mnemonic": mnemonic}
return response
async def safe_handle(self, websocket, payload):
message = None
try:
message = json.loads(payload)
response = await self.handle_message(message)
if response is not None:
# self.log.info(f"message: {message}")
# self.log.info(f"response: {response}")
# self.log.info(f"payload: {format_response(message, response)}")
await websocket.send_str(format_response(message, response))
except BaseException as e:
tb = traceback.format_exc()
self.log.error(f"Error while handling message: {tb}")
error = {"success": False, "error": f"{e}"}
if message is None:
return
await websocket.send_str(format_response(message, error))
async def handle_message(self, message):
"""
This function gets called when new message is received via websocket.
"""
command = message["command"]
if message["ack"]:
return None
data = None
if "data" in message:
data = message["data"]
if command == "ping":
return pong()
elif command == "get_wallet_balance":
return await self.get_wallet_balance(data)
elif command == "send_transaction":
return await self.send_transaction(data)
elif command == "get_next_puzzle_hash":
return await self.get_next_puzzle_hash(data)
elif command == "get_transactions":
return await self.get_transactions(data)
elif command == "farm_block":
return await self.farm_block(data)
elif command == "get_sync_status":
return await self.get_sync_status()
elif command == "get_height_info":
return await self.get_height_info()
elif command == "get_connection_info":
return await self.get_connection_info()
elif command == "create_new_wallet":
return await self.create_new_wallet(data)
elif command == "get_wallets":
return await self.get_wallets()
elif command == "rl_set_admin_info":
return await self.rl_set_admin_info(data)
elif command == "rl_set_user_info":
return await self.rl_set_user_info(data)
elif command == "cc_set_name":
return await self.cc_set_name(data)
elif command == "cc_get_name":
return await self.cc_get_name(data)
elif command == "cc_spend":
return await self.cc_spend(data)
elif command == "cc_get_colour":
return await self.cc_get_colour(data)
elif command == "create_offer_for_ids":
return await self.create_offer_for_ids(data)
elif command == "get_discrepancies_for_offer":
return await self.get_discrepancies_for_offer(data)
elif command == "respond_to_offer":
return await self.respond_to_offer(data)
elif command == "get_wallet_summaries":
return await self.get_wallet_summaries()
elif command == "get_public_keys":
return await self.get_public_keys()
elif command == "get_private_key":
return await self.get_private_key(data)
elif command == "generate_mnemonic":
return await self.generate_mnemonic()
elif command == "log_in":
return await self.log_in(data)
elif command == "add_key":
return await self.add_key(data)
elif command == "delete_key":
return await self.delete_key(data)
elif command == "delete_all_keys":
return await self.delete_all_keys()
else:
response = {"error": f"unknown_command {command}"}
return response
async def notify_ui_that_state_changed(self, state: str, wallet_id):
data = {
"state": state,
}
# self.log.info(f"Wallet notify id is: {wallet_id}")
if wallet_id is not None:
data["wallet_id"] = wallet_id
if self.websocket is not None:
try:
await self.websocket.send_str(
create_payload("state_changed", data, "chia-wallet", "wallet_ui")
)
except (BaseException) as e:
try:
self.log.warning(f"Sending data failed. Exception {type(e)}.")
except BrokenPipeError:
pass
def state_changed_callback(self, state: str, wallet_id: int = None):
if self.websocket is None:
return
asyncio.create_task(self.notify_ui_that_state_changed(state, wallet_id))

View File

@ -71,10 +71,10 @@ class BlockTools:
# No real plots supplied, so we will use the small test plots
self.use_any_pos = True
self.plot_config: Dict = {"plots": {}}
# Can't go much lower than 19, since plots start having no solutions
k: uint8 = uint8(19)
# Can't go much lower than 18, since plots start having no solutions
k: uint8 = uint8(18)
# Uses many plots for testing, in order to guarantee proofs of space at every height
num_plots = 40
num_plots = 30
# Use the empty string as the seed for the private key
self.keychain = Keychain("testing", True)
@ -115,7 +115,7 @@ class BlockTools:
k,
b"genesis",
plot_seeds[pn],
2 * 1024,
128,
)
done_filenames.add(filename)
self.plot_config["plots"][str(plot_dir / filename)] = {

View File

@ -424,7 +424,7 @@ class TestWalletSimulator:
@pytest.mark.asyncio
async def test_cc_trade_with_multiple_colours(self, two_wallet_nodes):
num_blocks = 10
num_blocks = 5
full_nodes, wallets = two_wallet_nodes
full_node_1, server_1 = full_nodes[0]
wallet_node, server_2 = wallets[0]

View File

@ -19,19 +19,20 @@ from src.consensus.coinbase import create_coinbase_coin_and_signature
from src.types.sized_bytes import bytes32
from src.full_node.block_store import BlockStore
from src.full_node.coin_store import CoinStore
from src.consensus.find_fork_point import find_fork_point_in_chain
bt = BlockTools()
test_constants: Dict[str, Any] = consensus_constants.copy()
test_constants.update(
{
"DIFFICULTY_STARTING": 5,
"DISCRIMINANT_SIZE_BITS": 16,
"DIFFICULTY_STARTING": 1,
"DISCRIMINANT_SIZE_BITS": 8,
"BLOCK_TIME_TARGET": 10,
"MIN_BLOCK_TIME": 2,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
"DIFFICULTY_DELAY": 3, # EPOCH / WARP_FACTOR
"MIN_ITERS_STARTING": 50 * 2,
"DIFFICULTY_EPOCH": 6, # The number of blocks per epoch
"DIFFICULTY_DELAY": 2, # EPOCH / WARP_FACTOR
"MIN_ITERS_STARTING": 50 * 1,
}
)
test_constants["GENESIS_BLOCK"] = bytes(
@ -493,7 +494,7 @@ class TestBlockValidation:
@pytest.mark.asyncio
async def test_difficulty_change(self):
num_blocks = 30
num_blocks = 14
# Make it 5x faster than target time
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 2)
db_path = Path("blockchain_test.db")
@ -508,19 +509,18 @@ class TestBlockValidation:
assert result == ReceiveBlockResult.ADDED_TO_HEAD
assert error_code is None
diff_25 = b.get_next_difficulty(blocks[24].header)
diff_26 = b.get_next_difficulty(blocks[25].header)
diff_27 = b.get_next_difficulty(blocks[26].header)
diff_12 = b.get_next_difficulty(blocks[11].header)
diff_13 = b.get_next_difficulty(blocks[12].header)
diff_14 = b.get_next_difficulty(blocks[13].header)
assert diff_26 == diff_25
assert diff_27 > diff_26
assert (diff_27 / diff_26) <= test_constants["DIFFICULTY_FACTOR"]
assert diff_13 == diff_12
assert diff_14 > diff_13
assert (diff_14 / diff_13) <= test_constants["DIFFICULTY_FACTOR"]
assert (b.get_next_min_iters(blocks[1])) == test_constants["MIN_ITERS_STARTING"]
assert (b.get_next_min_iters(blocks[24])) == (b.get_next_min_iters(blocks[23]))
assert (b.get_next_min_iters(blocks[25])) == (b.get_next_min_iters(blocks[24]))
assert (b.get_next_min_iters(blocks[26])) > (b.get_next_min_iters(blocks[25]))
assert (b.get_next_min_iters(blocks[27])) == (b.get_next_min_iters(blocks[26]))
assert (b.get_next_min_iters(blocks[12])) == (b.get_next_min_iters(blocks[11]))
assert (b.get_next_min_iters(blocks[13])) > (b.get_next_min_iters(blocks[12]))
assert (b.get_next_min_iters(blocks[14])) == (b.get_next_min_iters(blocks[13]))
await connection.close()
b.shut_down()
@ -529,7 +529,7 @@ class TestBlockValidation:
class TestReorgs:
@pytest.mark.asyncio
async def test_basic_reorg(self):
blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9)
blocks = bt.get_consecutive_blocks(test_constants, 15, [], 9)
db_path = Path("blockchain_test.db")
if db_path.exists():
db_path.unlink()
@ -540,22 +540,22 @@ class TestReorgs:
for i in range(1, len(blocks)):
await b.receive_block(blocks[i])
assert b.get_current_tips()[0].height == 100
assert b.get_current_tips()[0].height == 15
blocks_reorg_chain = bt.get_consecutive_blocks(
test_constants, 30, blocks[:90], 9, b"2"
test_constants, 7, blocks[:10], 9, b"2"
)
for i in range(1, len(blocks_reorg_chain)):
reorg_block = blocks_reorg_chain[i]
result, removed, error_code = await b.receive_block(reorg_block)
if reorg_block.height < 90:
if reorg_block.height < 10:
assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK
elif reorg_block.height < 99:
elif reorg_block.height < 14:
assert result == ReceiveBlockResult.ADDED_AS_ORPHAN
elif reorg_block.height >= 100:
elif reorg_block.height >= 15:
assert result == ReceiveBlockResult.ADDED_TO_HEAD
assert error_code is None
assert b.get_current_tips()[0].height == 119
assert b.get_current_tips()[0].height == 16
await connection.close()
b.shut_down()
@ -656,12 +656,18 @@ class TestReorgs:
for i in range(1, len(blocks_2)):
await b.receive_block(blocks_2[i])
assert b._find_fork_point_in_chain(blocks[10].header, blocks_2[10].header) == 4
assert (
find_fork_point_in_chain(b.headers, blocks[10].header, blocks_2[10].header)
== 4
)
for i in range(1, len(blocks_3)):
await b.receive_block(blocks_3[i])
assert b._find_fork_point_in_chain(blocks[10].header, blocks_3[10].header) == 2
assert (
find_fork_point_in_chain(b.headers, blocks[10].header, blocks_3[10].header)
== 2
)
assert b.lca_block.data == blocks[2].header.data
@ -669,10 +675,15 @@ class TestReorgs:
await b.receive_block(blocks_reorg[i])
assert (
b._find_fork_point_in_chain(blocks[10].header, blocks_reorg[10].header) == 8
find_fork_point_in_chain(
b.headers, blocks[10].header, blocks_reorg[10].header
)
== 8
)
assert (
b._find_fork_point_in_chain(blocks_2[10].header, blocks_reorg[10].header)
find_fork_point_in_chain(
b.headers, blocks_2[10].header, blocks_reorg[10].header
)
== 4
)
assert b.lca_block.data == blocks[4].header.data

View File

@ -129,7 +129,9 @@ class TestCoinStore:
@pytest.mark.asyncio
async def test_basic_reorg(self):
blocks = bt.get_consecutive_blocks(test_constants, 100, [], 9)
initial_block_count = 20
reorg_length = 15
blocks = bt.get_consecutive_blocks(test_constants, initial_block_count, [], 9)
db_path = Path("blockchain_test.db")
if db_path.exists():
db_path.unlink()
@ -141,7 +143,7 @@ class TestCoinStore:
for i in range(1, len(blocks)):
await b.receive_block(blocks[i])
assert b.get_current_tips()[0].height == 100
assert b.get_current_tips()[0].height == initial_block_count
for c, block in enumerate(blocks):
unspent = await coin_store.get_coin_record(
@ -158,17 +160,21 @@ class TestCoinStore:
assert unspent_fee.name == block.header.data.fees_coin.name()
blocks_reorg_chain = bt.get_consecutive_blocks(
test_constants, 30, blocks[:90], 9, b"1"
test_constants,
reorg_length,
blocks[: initial_block_count - 10],
9,
b"1",
)
for i in range(1, len(blocks_reorg_chain)):
reorg_block = blocks_reorg_chain[i]
result, removed, error_code = await b.receive_block(reorg_block)
if reorg_block.height < 90:
if reorg_block.height < initial_block_count - 10:
assert result == ReceiveBlockResult.ALREADY_HAVE_BLOCK
elif reorg_block.height < 99:
elif reorg_block.height < initial_block_count - 1:
assert result == ReceiveBlockResult.ADDED_AS_ORPHAN
elif reorg_block.height >= 100:
elif reorg_block.height >= initial_block_count:
assert result == ReceiveBlockResult.ADDED_TO_HEAD
unspent = await coin_store.get_coin_record(
reorg_block.header.data.coinbase.name(), reorg_block.header
@ -178,7 +184,10 @@ class TestCoinStore:
assert unspent.spent == 0
assert unspent.spent_block_index == 0
assert error_code is None
assert b.get_current_tips()[0].height == 119
assert (
b.get_current_tips()[0].height
== initial_block_count - 10 + reorg_length - 1
)
except Exception as e:
await connection.close()
Path("blockchain_test.db").unlink()

View File

@ -483,7 +483,7 @@ class TestFullNodeProtocol:
blocks_new = bt.get_consecutive_blocks(
test_constants,
40,
10,
blocks_list[:],
4,
reward_puzzlehash=coinbase_puzzlehash,
@ -505,11 +505,11 @@ class TestFullNodeProtocol:
candidates.append(blocks_new_2[-1])
unf_block_not_child = FullBlock(
blocks_new[30].proof_of_space,
blocks_new[-7].proof_of_space,
None,
blocks_new[30].header,
blocks_new[30].transactions_generator,
blocks_new[30].transactions_filter,
blocks_new[-7].header,
blocks_new[-7].transactions_generator,
blocks_new[-7].transactions_filter,
)
unf_block_req_bad = fnp.RespondUnfinishedBlock(unf_block_not_child)
@ -541,18 +541,19 @@ class TestFullNodeProtocol:
# Slow block should delay prop
start = time.time()
propagation_messages = [
x async for x in full_node_1.respond_unfinished_block(get_cand(40))
x async for x in full_node_1.respond_unfinished_block(get_cand(20))
]
assert len(propagation_messages) == 2
assert isinstance(
propagation_messages[0].message.data, timelord_protocol.ProofOfSpaceInfo
)
assert isinstance(propagation_messages[1].message.data, fnp.NewUnfinishedBlock)
assert time.time() - start > 3
# TODO: fix
# assert time.time() - start > 3
# Already seen
assert (
len([x async for x in full_node_1.respond_unfinished_block(get_cand(40))])
len([x async for x in full_node_1.respond_unfinished_block(get_cand(20))])
== 0
)
@ -561,6 +562,7 @@ class TestFullNodeProtocol:
len([x async for x in full_node_1.respond_unfinished_block(get_cand(49))])
== 0
)
# Fastest equal height should propagate
start = time.time()
assert (
@ -870,12 +872,6 @@ class TestWalletProtocol:
@pytest.mark.asyncio
async def test_request_all_proof_hashes(self, two_nodes):
full_node_1, full_node_2, server_1, server_2 = two_nodes
num_blocks = test_constants["DIFFICULTY_EPOCH"] * 2
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
for block in blocks:
async for _ in full_node_1.respond_block(fnp.RespondBlock(block)):
pass
blocks_list = await get_block_path(full_node_1)
msgs = [
@ -885,7 +881,7 @@ class TestWalletProtocol:
)
]
hashes = msgs[0].message.data.hashes
assert len(hashes) >= num_blocks - 1
assert len(hashes) >= len(blocks_list) - 2
for i in range(len(hashes)):
if (
i % test_constants["DIFFICULTY_EPOCH"]
@ -909,11 +905,6 @@ class TestWalletProtocol:
@pytest.mark.asyncio
async def test_request_all_header_hashes_after(self, two_nodes):
full_node_1, full_node_2, server_1, server_2 = two_nodes
num_blocks = 18
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
for block in blocks[:10]:
async for _ in full_node_1.respond_block(fnp.RespondBlock(block)):
pass
blocks_list = await get_block_path(full_node_1)
msgs = [

View File

@ -23,7 +23,7 @@ class TestFullSync:
@pytest.mark.asyncio
async def test_basic_sync(self, two_nodes):
num_blocks = 100
num_blocks = 40
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
full_node_1, full_node_2, server_1, server_2 = two_nodes

View File

@ -38,8 +38,8 @@ class TestMempool:
yield _
@pytest.fixture(scope="function")
async def two_nodes_standard_freeze(self):
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 200}):
async def two_nodes_small_freeze(self):
async for _ in setup_two_nodes({"COINBASE_FREEZE_PERIOD": 30}):
yield _
@pytest.mark.asyncio
@ -77,7 +77,7 @@ class TestMempool:
assert sb is spend_bundle
@pytest.mark.asyncio
async def test_coinbase_freeze(self, two_nodes_standard_freeze):
async def test_coinbase_freeze(self, two_nodes_small_freeze):
num_blocks = 2
wallet_a = WalletTool()
coinbase_puzzlehash = wallet_a.get_new_puzzlehash()
@ -87,7 +87,7 @@ class TestMempool:
blocks = bt.get_consecutive_blocks(
test_constants, num_blocks, [], 10, b"", coinbase_puzzlehash
)
full_node_1, full_node_2, server_1, server_2 = two_nodes_standard_freeze
full_node_1, full_node_2, server_1, server_2 = two_nodes_small_freeze
block = blocks[1]
async for _ in full_node_1.respond_block(
@ -112,10 +112,10 @@ class TestMempool:
assert sb is None
blocks = bt.get_consecutive_blocks(
test_constants, 200, [], 10, b"", coinbase_puzzlehash
test_constants, 30, [], 10, b"", coinbase_puzzlehash
)
for i in range(1, 201):
for i in range(1, 31):
async for _ in full_node_1.respond_block(
full_node_protocol.RespondBlock(blocks[i])
):

View File

@ -39,7 +39,7 @@ class TestNodeLoad:
await asyncio.sleep(2) # Allow connections to get made
num_unfinished_blocks = 1000
num_unfinished_blocks = 500
start_unf = time.time()
for i in range(num_unfinished_blocks):
msg = Message(
@ -56,7 +56,7 @@ class TestNodeLoad:
OutboundMessage(NodeType.FULL_NODE, block_msg, Delivery.BROADCAST)
)
while time.time() - start_unf < 100:
while time.time() - start_unf < 50:
if (
max([h.height for h in full_node_2.blockchain.get_current_tips()])
== num_blocks - 1
@ -71,7 +71,7 @@ class TestNodeLoad:
@pytest.mark.asyncio
async def test_blocks_load(self, two_nodes):
num_blocks = 100
num_blocks = 50
full_node_1, full_node_2, server_1, server_2 = two_nodes
blocks = bt.get_consecutive_blocks(test_constants, num_blocks, [], 10)
@ -92,4 +92,4 @@ class TestNodeLoad:
OutboundMessage(NodeType.FULL_NODE, msg, Delivery.BROADCAST)
)
print(f"Time taken to process {num_blocks} is {time.time() - start_unf}")
assert time.time() - start_unf < 200
assert time.time() - start_unf < 100

View File

@ -1,13 +1,15 @@
import asyncio
import pytest
from src.rpc.farmer_rpc_api import FarmerRpcApi
from src.rpc.harvester_rpc_api import HarvesterRpcApi
from blspy import PrivateKey
from chiapos import DiskPlotter
from src.types.proof_of_space import ProofOfSpace
from src.rpc.farmer_rpc_server import start_farmer_rpc_server
from src.rpc.harvester_rpc_server import start_harvester_rpc_server
from src.rpc.farmer_rpc_client import FarmerRpcClient
from src.rpc.harvester_rpc_client import HarvesterRpcClient
from src.rpc.rpc_server import start_rpc_server
from src.util.ints import uint16
from tests.setup_nodes import setup_full_system, test_constants
from tests.block_tools import get_plot_dir
@ -37,9 +39,14 @@ class TestRpc:
def stop_node_cb_2():
pass
rpc_cleanup = await start_farmer_rpc_server(farmer, stop_node_cb, test_rpc_port)
rpc_cleanup_2 = await start_harvester_rpc_server(
harvester, stop_node_cb_2, test_rpc_port_2
farmer_rpc_api = FarmerRpcApi(farmer)
harvester_rpc_api = HarvesterRpcApi(harvester)
rpc_cleanup = await start_rpc_server(
farmer_rpc_api, test_rpc_port, stop_node_cb
)
rpc_cleanup_2 = await start_rpc_server(
harvester_rpc_api, test_rpc_port_2, stop_node_cb_2
)
try:
@ -71,7 +78,7 @@ class TestRpc:
18,
b"genesis",
plot_seed,
2 * 1024,
128,
)
await client_2.add_plot(str(plot_dir / filename), plot_sk)
@ -91,7 +98,7 @@ class TestRpc:
18,
b"genesis",
plot_seed,
2 * 1024,
128,
)
await client_2.add_plot(str(plot_dir / filename), plot_sk, pool_pk)
assert len((await client_2.get_plots())["plots"]) == num_plots + 1

View File

@ -2,7 +2,8 @@ import asyncio
import pytest
from src.rpc.full_node_rpc_server import start_full_node_rpc_server
from src.rpc.full_node_rpc_api import FullNodeRpcApi
from src.rpc.rpc_server import start_rpc_server
from src.protocols import full_node_protocol
from src.rpc.full_node_rpc_client import FullNodeRpcClient
from src.util.ints import uint16
@ -42,8 +43,10 @@ class TestRpc:
full_node_1._close()
server_1.close_all()
rpc_cleanup = await start_full_node_rpc_server(
full_node_1, stop_node_cb, test_rpc_port
full_node_rpc_api = FullNodeRpcApi(full_node_1)
rpc_cleanup = await start_rpc_server(
full_node_rpc_api, test_rpc_port, stop_node_cb
)
try:

View File

@ -1,4 +1,5 @@
import asyncio
import signal
from typing import Any, Dict, Tuple, List
from src.full_node.full_node import FullNode
@ -17,6 +18,8 @@ from src.timelord import Timelord
from src.server.connection import PeerInfo
from src.server.start_service import create_periodic_introducer_poll_task
from src.util.ints import uint16, uint32
from src.server.start_service import Service
from src.rpc.harvester_rpc_api import HarvesterRpcApi
bt = BlockTools()
@ -25,7 +28,7 @@ root_path = bt.root_path
test_constants: Dict[str, Any] = {
"DIFFICULTY_STARTING": 1,
"DISCRIMINANT_SIZE_BITS": 16,
"DISCRIMINANT_SIZE_BITS": 8,
"BLOCK_TIME_TARGET": 10,
"MIN_BLOCK_TIME": 2,
"DIFFICULTY_EPOCH": 12, # The number of blocks per epoch
@ -34,7 +37,7 @@ test_constants: Dict[str, Any] = {
"PROPAGATION_DELAY_THRESHOLD": 20,
"TX_PER_SEC": 1,
"MEMPOOL_BLOCK_BUFFER": 10,
"MIN_ITERS_STARTING": 50 * 2,
"MIN_ITERS_STARTING": 50 * 1,
}
test_constants["GENESIS_BLOCK"] = bytes(
bt.create_genesis_block(test_constants, bytes([0] * 32), b"0")
@ -50,8 +53,7 @@ async def _teardown_nodes(node_aiters: List) -> None:
pass
async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={}):
# SETUP
async def setup_full_node(db_name, port, introducer_port=None, simulator=False, dic={}):
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
@ -60,328 +62,330 @@ async def setup_full_node_simulator(db_name, port, introducer_port=None, dic={})
if db_path.exists():
db_path.unlink()
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
config = load_config(root_path, "config.yaml", "full_node")
config["database_path"] = str(db_path)
if introducer_port is not None:
config["introducer_peer"]["host"] = "127.0.0.1"
config["introducer_peer"]["port"] = introducer_port
full_node_1 = await FullNodeSimulator.create(
config=config,
name=f"full_node_{port}",
root_path=root_path,
override_constants=test_constants_copy,
)
assert ping_interval is not None
assert network_id is not None
server_1 = ChiaServer(
port,
full_node_1,
NodeType.FULL_NODE,
ping_interval,
network_id,
bt.root_path,
config,
"full-node-simulator-server",
)
_ = await start_server(server_1, full_node_1._on_connect)
full_node_1._set_server(server_1)
yield (full_node_1, server_1)
# TEARDOWN
_.close()
server_1.close_all()
full_node_1._close()
await server_1.await_closed()
await full_node_1._await_closed()
db_path.unlink()
async def setup_full_node(db_name, port, introducer_port=None, dic={}):
# SETUP
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
db_path = root_path / f"{db_name}"
if db_path.exists():
db_path.unlink()
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
config = load_config(root_path, "config.yaml", "full_node")
config = load_config(bt.root_path, "config.yaml", "full_node")
config["database_path"] = db_name
config["send_uncompact_interval"] = 30
periodic_introducer_poll = None
if introducer_port is not None:
config["introducer_peer"]["host"] = "127.0.0.1"
config["introducer_peer"]["port"] = introducer_port
full_node_1 = await FullNode.create(
periodic_introducer_poll = (
PeerInfo("127.0.0.1", introducer_port),
30,
config["target_peer_count"],
)
FullNodeApi = FullNodeSimulator if simulator else FullNode
api = FullNodeApi(
config=config,
root_path=root_path,
name=f"full_node_{port}",
override_constants=test_constants_copy,
)
assert ping_interval is not None
assert network_id is not None
server_1 = ChiaServer(
port,
full_node_1,
NodeType.FULL_NODE,
ping_interval,
network_id,
root_path,
config,
f"full_node_server_{port}",
)
_ = await start_server(server_1, full_node_1._on_connect)
full_node_1._set_server(server_1)
if introducer_port is not None:
peer_info = PeerInfo("127.0.0.1", introducer_port)
create_periodic_introducer_poll_task(
server_1,
peer_info,
full_node_1.global_connections,
config["introducer_connect_interval"],
config["target_peer_count"],
)
yield (full_node_1, server_1)
# TEARDOWN
_.close()
server_1.close_all()
full_node_1._close()
await server_1.await_closed()
await full_node_1._await_closed()
db_path = root_path / f"{db_name}"
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.FULL_NODE,
advertised_port=port,
service_name="full_node",
server_listen_ports=[port],
auth_connect_peers=False,
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
periodic_introducer_poll=periodic_introducer_poll,
)
run_task = asyncio.create_task(service.run())
await started.wait()
yield api, api.server
service.stop()
await run_task
if db_path.exists():
db_path.unlink()
async def setup_wallet_node(
port, introducer_port=None, key_seed=b"setup_wallet_node", dic={}
port,
full_node_port=None,
introducer_port=None,
key_seed=b"setup_wallet_node",
dic={},
):
config = load_config(root_path, "config.yaml", "wallet")
if "starting_height" in dic:
config["starting_height"] = dic["starting_height"]
config["initial_num_public_keys"] = 5
keychain = Keychain(key_seed.hex(), True)
keychain.add_private_key_seed(key_seed)
private_key = keychain.get_all_private_keys()[0][0]
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
db_path = root_path / f"test-wallet-db-{port}.db"
db_path_key_suffix = str(
keychain.get_all_public_keys()[0].get_public_key().get_fingerprint()
)
db_name = f"test-wallet-db-{port}"
db_path = root_path / f"test-wallet-db-{port}-{db_path_key_suffix}"
if db_path.exists():
db_path.unlink()
config["database_path"] = str(db_path)
config["database_path"] = str(db_name)
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
wallet = await WalletNode.create(
api = WalletNode(
config,
private_key,
keychain,
root_path,
override_constants=test_constants_copy,
name="wallet1",
)
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
port,
wallet,
NodeType.WALLET,
ping_interval,
network_id,
root_path,
config,
"wallet-server",
periodic_introducer_poll = None
if introducer_port is not None:
periodic_introducer_poll = (
PeerInfo("127.0.0.1", introducer_port),
30,
config["target_peer_count"],
)
connect_peers: List[PeerInfo] = []
if full_node_port is not None:
connect_peers = [PeerInfo("127.0.0.1", full_node_port)]
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.WALLET,
advertised_port=port,
service_name="wallet",
server_listen_ports=[port],
connect_peers=connect_peers,
auth_connect_peers=False,
on_connect_callback=api._on_connect,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
periodic_introducer_poll=periodic_introducer_poll,
)
wallet.set_server(server)
yield (wallet, server)
run_task = asyncio.create_task(service.run())
await started.wait()
server.close_all()
await wallet.wallet_state_manager.clear_all_stores()
await wallet.wallet_state_manager.close_all_stores()
wallet.wallet_state_manager.unlink_db()
await server.await_closed()
yield api, api.server
# await asyncio.sleep(1) # Sleep to ÷
service.stop()
await run_task
if db_path.exists():
db_path.unlink()
keychain.delete_all_keys()
async def setup_harvester(port, dic={}):
async def setup_harvester(port, farmer_port, dic={}):
config = load_config(bt.root_path, "config.yaml", "harvester")
harvester = Harvester(config, bt.plot_config, bt.root_path)
api = Harvester(config, bt.plot_config, bt.root_path)
net_config = load_config(bt.root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
port,
harvester,
NodeType.HARVESTER,
ping_interval,
network_id,
bt.root_path,
config,
f"harvester_server_{port}",
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.HARVESTER,
advertised_port=port,
service_name="harvester",
server_listen_ports=[port],
connect_peers=[PeerInfo("127.0.0.1", farmer_port)],
auth_connect_peers=True,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
)
harvester.set_server(server)
yield (harvester, server)
run_task = asyncio.create_task(service.run())
await started.wait()
server.close_all()
harvester._shutdown()
await server.await_closed()
await harvester._await_shutdown()
yield api, api.server
service.stop()
await run_task
async def setup_farmer(port, dic={}):
print("root path", root_path)
config = load_config(root_path, "config.yaml", "farmer")
async def setup_farmer(port, full_node_port, dic={}):
config = load_config(bt.root_path, "config.yaml", "farmer")
config_pool = load_config(root_path, "config.yaml", "pool")
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
config["xch_target_puzzle_hash"] = bt.fee_target.hex()
config["pool_public_keys"] = [
bytes(epk.get_public_key()).hex() for epk in bt.keychain.get_all_public_keys()
]
config_pool["xch_target_puzzle_hash"] = bt.fee_target.hex()
farmer = Farmer(config, config_pool, bt.keychain, test_constants_copy)
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
port,
farmer,
NodeType.FARMER,
ping_interval,
network_id,
root_path,
config,
f"farmer_server_{port}",
api = Farmer(config, config_pool, bt.keychain, test_constants_copy)
started = asyncio.Event()
async def start_callback():
nonlocal started
started.set()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.FARMER,
advertised_port=port,
service_name="farmer",
server_listen_ports=[port],
on_connect_callback=api._on_connect,
connect_peers=[PeerInfo("127.0.0.1", full_node_port)],
auth_connect_peers=False,
start_callback=start_callback,
)
farmer.set_server(server)
_ = await start_server(server, farmer._on_connect)
yield (farmer, server)
run_task = asyncio.create_task(service.run())
await started.wait()
_.close()
server.close_all()
await server.await_closed()
yield api, api.server
service.stop()
await run_task
async def setup_introducer(port, dic={}):
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
config = load_config(bt.root_path, "config.yaml", "introducer")
api = Introducer(config["max_peers_to_send"], config["recent_peer_threshold"])
config = load_config(root_path, "config.yaml", "introducer")
started = asyncio.Event()
introducer = Introducer(
config["max_peers_to_send"], config["recent_peer_threshold"]
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.INTRODUCER,
advertised_port=port,
service_name="introducer",
server_listen_ports=[port],
auth_connect_peers=False,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
)
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
port,
introducer,
NodeType.INTRODUCER,
ping_interval,
network_id,
bt.root_path,
config,
f"introducer_server_{port}",
)
_ = await start_server(server)
yield (introducer, server)
run_task = asyncio.create_task(service.run())
await started.wait()
_.close()
server.close_all()
await server.await_closed()
yield api, api.server
service.stop()
await run_task
async def setup_vdf_clients(port):
vdf_task = asyncio.create_task(spawn_process("127.0.0.1", port, 1))
def stop():
asyncio.create_task(kill_processes())
asyncio.get_running_loop().add_signal_handler(signal.SIGTERM, stop)
asyncio.get_running_loop().add_signal_handler(signal.SIGINT, stop)
yield vdf_task
try:
await kill_processes()
except Exception:
pass
await kill_processes()
async def setup_timelord(port, sanitizer, dic={}):
config = load_config(root_path, "config.yaml", "timelord")
async def setup_timelord(port, full_node_port, sanitizer, dic={}):
config = load_config(bt.root_path, "config.yaml", "timelord")
test_constants_copy = test_constants.copy()
for k in dic.keys():
test_constants_copy[k] = dic[k]
config["sanitizer_mode"] = sanitizer
timelord = Timelord(config, test_constants_copy)
net_config = load_config(root_path, "config.yaml")
ping_interval = net_config.get("ping_interval")
network_id = net_config.get("network_id")
assert ping_interval is not None
assert network_id is not None
server = ChiaServer(
port,
timelord,
NodeType.TIMELORD,
ping_interval,
network_id,
bt.root_path,
config,
f"timelord_server_{port}",
)
vdf_server_port = config["vdf_server"]["port"]
if sanitizer:
vdf_server_port = 7999
config["vdf_server"]["port"] = 7999
coro = asyncio.start_server(
timelord._handle_client,
config["vdf_server"]["host"],
vdf_server_port,
loop=asyncio.get_running_loop(),
api = Timelord(config, test_constants_copy)
started = asyncio.Event()
async def start_callback():
await api._start()
nonlocal started
started.set()
def stop_callback():
api._close()
async def await_closed_callback():
await api._await_closed()
service = Service(
root_path=root_path,
api=api,
node_type=NodeType.TIMELORD,
advertised_port=port,
service_name="timelord",
server_listen_ports=[port],
connect_peers=[PeerInfo("127.0.0.1", full_node_port)],
auth_connect_peers=False,
start_callback=start_callback,
stop_callback=stop_callback,
await_closed_callback=await_closed_callback,
)
vdf_server = asyncio.ensure_future(coro)
run_task = asyncio.create_task(service.run())
await started.wait()
timelord.set_server(server)
yield api, api.server
if not sanitizer:
timelord_task = asyncio.create_task(timelord._manage_discriminant_queue())
else:
timelord_task = asyncio.create_task(timelord._manage_discriminant_queue_sanitizer())
yield (timelord, server)
vdf_server.cancel()
server.close_all()
timelord._shutdown()
await timelord_task
await server.await_closed()
service.stop()
await run_task
async def setup_two_nodes(dic={}):
@ -390,8 +394,8 @@ async def setup_two_nodes(dic={}):
Setup and teardown of two full nodes, with blockchains and separate DBs.
"""
node_iters = [
setup_full_node("blockchain_test.db", 21234, dic=dic),
setup_full_node("blockchain_test_2.db", 21235, dic=dic),
setup_full_node("blockchain_test.db", 21234, simulator=False, dic=dic),
setup_full_node("blockchain_test_2.db", 21235, simulator=False, dic=dic),
]
fn1, s1 = await node_iters[0].__anext__()
@ -404,8 +408,8 @@ async def setup_two_nodes(dic={}):
async def setup_node_and_wallet(dic={}):
node_iters = [
setup_full_node_simulator("blockchain_test.db", 21234, dic=dic),
setup_wallet_node(21235, dic=dic),
setup_full_node("blockchain_test.db", 21234, simulator=False, dic=dic),
setup_wallet_node(21235, None, dic=dic),
]
full_node, s1 = await node_iters[0].__anext__()
@ -416,22 +420,6 @@ async def setup_node_and_wallet(dic={}):
await _teardown_nodes(node_iters)
async def setup_node_and_two_wallets(dic={}):
node_iters = [
setup_full_node("blockchain_test.db", 21234, dic=dic),
setup_wallet_node(21235, key_seed=b"a", dic=dic),
setup_wallet_node(21236, key_seed=b"b", dic=dic),
]
full_node, s1 = await node_iters[0].__anext__()
wallet, s2 = await node_iters[1].__anext__()
wallet_2, s3 = await node_iters[2].__anext__()
yield (full_node, wallet, wallet_2, s1, s2, s3)
await _teardown_nodes(node_iters)
async def setup_simulators_and_wallets(
simulator_count: int, wallet_count: int, dic: Dict
):
@ -440,16 +428,16 @@ async def setup_simulators_and_wallets(
node_iters = []
for index in range(0, simulator_count):
db_name = f"blockchain_test{index}.db"
port = 50000 + index
sim = setup_full_node_simulator(db_name, port, dic=dic)
db_name = f"blockchain_test_{port}.db"
sim = setup_full_node(db_name, port, simulator=True, dic=dic)
simulators.append(await sim.__anext__())
node_iters.append(sim)
for index in range(0, wallet_count):
seed = bytes(uint32(index))
port = 55000 + index
wlt = setup_wallet_node(port, key_seed=seed, dic=dic)
wlt = setup_wallet_node(port, None, key_seed=seed, dic=dic)
wallets.append(await wlt.__anext__())
node_iters.append(wlt)
@ -461,19 +449,20 @@ async def setup_simulators_and_wallets(
async def setup_full_system(dic={}):
node_iters = [
setup_introducer(21233),
setup_harvester(21234, dic),
setup_farmer(21235, dic),
setup_timelord(21236, False, dic),
setup_harvester(21234, 21235, dic),
setup_farmer(21235, 21237, dic),
setup_timelord(21236, 21237, False, dic),
setup_vdf_clients(8000),
setup_full_node("blockchain_test.db", 21237, 21233, dic),
setup_full_node("blockchain_test_2.db", 21238, 21233, dic),
setup_timelord(21239, True, dic),
setup_full_node("blockchain_test.db", 21237, 21233, False, dic),
setup_full_node("blockchain_test_2.db", 21238, 21233, False, dic),
setup_timelord(21239, 21238, True, dic),
setup_vdf_clients(7999),
]
introducer, introducer_server = await node_iters[0].__anext__()
harvester, harvester_server = await node_iters[1].__anext__()
farmer, farmer_server = await node_iters[2].__anext__()
await asyncio.sleep(2)
timelord, timelord_server = await node_iters[3].__anext__()
vdf = await node_iters[4].__anext__()
node1, node1_server = await node_iters[5].__anext__()
@ -481,18 +470,16 @@ async def setup_full_system(dic={}):
sanitizer, sanitizer_server = await node_iters[7].__anext__()
vdf_sanitizer = await node_iters[8].__anext__()
await harvester_server.start_client(
PeerInfo("127.0.0.1", uint16(farmer_server._port)), auth=True
yield (
node1,
node2,
harvester,
farmer,
introducer,
timelord,
vdf,
sanitizer,
vdf_sanitizer,
)
await farmer_server.start_client(PeerInfo("127.0.0.1", uint16(node1_server._port)))
await timelord_server.start_client(
PeerInfo("127.0.0.1", uint16(node1_server._port))
)
await sanitizer_server.start_client(
PeerInfo("127.0.0.1", uint16(node2_server._port))
)
yield (node1, node2, harvester, farmer, introducer, timelord, vdf, sanitizer, vdf_sanitizer)
await _teardown_nodes(node_iters)

View File

@ -1,11 +1,12 @@
import asyncio
import pytest
import time
from typing import Dict, Any
from typing import Dict, Any, List
from tests.setup_nodes import setup_full_system
from tests.block_tools import BlockTools
from src.consensus.constants import constants as consensus_constants
from src.util.ints import uint32
from src.types.full_block import FullBlock
bt = BlockTools()
test_constants: Dict[str, Any] = consensus_constants.copy()
@ -33,16 +34,16 @@ class TestSimulation:
node1, node2, _, _, _, _, _, _, _ = simulation
start = time.time()
# Use node2 to test node communication, since only node1 extends the chain.
while time.time() - start < 500:
while time.time() - start < 100:
if max([h.height for h in node2.blockchain.get_current_tips()]) > 10:
break
await asyncio.sleep(1)
if max([h.height for h in node2.blockchain.get_current_tips()]) <= 10:
raise Exception("Failed: could not get 10 blocks.")
raise Exception("Failed: could not get 10 blocks.")
# Wait additional 2 minutes to get a compact block.
while time.time() - start < 620:
while time.time() - start < 120:
max_height = node1.blockchain.lca_block.height
for h in range(1, max_height):
blocks_1: List[FullBlock] = await node1.block_store.get_blocks_at(
@ -54,10 +55,12 @@ class TestSimulation:
has_compact_1 = False
has_compact_2 = False
for block in blocks_1:
assert block.proof_of_time is not None
if block.proof_of_time.witness_type == 0:
has_compact_1 = True
break
for block in blocks_2:
assert block.proof_of_time is not None
if block.proof_of_time.witness_type == 0:
has_compact_2 = True
break

View File

@ -384,8 +384,3 @@ class TestWalletSync:
assert len(records) == 1
assert not records[0].spent
assert not records[0].coinbase
@pytest.mark.asyncio
async def test_random_order_wallet_node(self, wallet_node):
# Call respond_removals and respond_additions in random orders
pass