Merge commit '12fcaf0e982590efcf26a954c67cd760035d9e6e' into checkpoint/main_from_release_1.6.0_12fcaf0e982590efcf26a954c67cd760035d9e6e

Amine Khaldi, 2022-08-29 17:45:49 +01:00
commit eeb2277955
GPG Key ID: B1C074FFC904E2D9 (no known key found for this signature in database)
85 changed files with 13470 additions and 107 deletions

.gitignore vendored (2 changes)
View File

@@ -90,6 +90,8 @@ win_code_sign_cert.p12
# chia-blockchain wheel build folder
build/
# data layer
**/dl_server_files*
# Temporary `n` (node version manager) directory
.n/

@@ -1 +1 @@
Subproject commit 5e52b7f53478d4f39b4a0f0203f41e31dd53aee9
Subproject commit e44fa0e8ccb2bd44cac01d8c9df385076f97dc06

View File

@@ -4,6 +4,7 @@ import click
from chia import __version__
from chia.cmds.configure import configure_cmd
from chia.cmds.farm import farm_cmd
from chia.cmds.data import data_cmd
from chia.cmds.init import init_cmd
from chia.cmds.keys import keys_cmd
from chia.cmds.netspace import netspace_cmd
@@ -141,6 +142,7 @@ cli.add_command(farm_cmd)
cli.add_command(plotters_cmd)
cli.add_command(db_cmd)
cli.add_command(peer_cmd)
cli.add_command(data_cmd)
cli.add_command(passphrase_cmd)

View File

@@ -22,6 +22,7 @@ def configure(
crawler_minimum_version_count: Optional[int],
seeder_domain_name: str,
seeder_nameserver: str,
enable_data_server: str = "",
):
with lock_and_load_config(root_path, "config.yaml") as config:
change_made = False
@@ -92,6 +93,13 @@ def configure(
config["full_node"]["target_peer_count"] = int(set_peer_count)
print("Target peer count updated")
change_made = True
if enable_data_server:
config["data_layer"]["run_server"] = str2bool(enable_data_server)
if str2bool(enable_data_server):
print("Data Server enabled.")
else:
print("Data Server disabled.")
change_made = True
if testnet:
if testnet == "true" or testnet == "t":
print("Setting Testnet")
@@ -126,6 +134,7 @@ def configure(
config["ui"]["selected_network"] = testnet
config["introducer"]["selected_network"] = testnet
config["wallet"]["selected_network"] = testnet
config["data_layer"]["selected_network"] = testnet
if "seeder" in config:
config["seeder"]["port"] = int(testnet_port)
@@ -163,6 +172,7 @@ def configure(
config["ui"]["selected_network"] = net
config["introducer"]["selected_network"] = net
config["wallet"]["selected_network"] = net
config["data_layer"]["selected_network"] = net
if "seeder" in config:
config["seeder"]["port"] = int(mainnet_port)
@@ -260,6 +270,12 @@ def configure(
help="configures the seeder nameserver setting. Ex: `example.com.`",
type=str,
)
@click.option(
"--enable-data-server",
"--data-server",
help="Enable or disable data propagation server for your data layer",
type=click.Choice(["true", "t", "false", "f"]),
)
@click.pass_context
def configure_cmd(
ctx,
@@ -277,6 +293,7 @@ def configure_cmd(
crawler_minimum_version_count,
seeder_domain_name,
seeder_nameserver,
enable_data_server,
):
configure(
ctx.obj["root_path"],
@@ -294,4 +311,5 @@ def configure_cmd(
crawler_minimum_version_count,
seeder_domain_name,
seeder_nameserver,
enable_data_server,
)
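Usage note: with this change the data propagation server can be toggled from the CLI, e.g. "chia configure --enable-data-server true" (accepted values: true, t, false, f); this flips data_layer.run_server in config.yaml as shown in the configure() hunk above.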

chia/cmds/data.py Normal file (369 lines)
View File

@@ -0,0 +1,369 @@
import json
import logging
from pathlib import Path
from typing import Any, Coroutine, Dict, List, Optional, TypeVar
import click
from typing_extensions import Protocol
_T = TypeVar("_T")
class IdentityFunction(Protocol):
def __call__(self, __x: _T) -> _T:
...
logger = logging.getLogger(__name__)
# TODO: this is more general and should be part of refactoring the overall CLI code duplication
def run(coro: Coroutine[Any, Any, Optional[Dict[str, Any]]]) -> None:
import asyncio
response = asyncio.run(coro)
success = response is not None and response.get("success", False)
logger.info(f"data layer cli call response:{success}")
# todo make sure all cli methods follow this pattern, uncomment
# if not success:
# raise click.ClickException(message=f"query unsuccessful, response: {response}")
@click.group("data", short_help="Manage your data")
def data_cmd() -> None:
pass
# TODO: maybe use more helpful `type=`s to get click to handle error reporting of
# malformed inputs.
def create_changelist_option() -> IdentityFunction:
return click.option(
"-d",
"--changelist",
"changelist_string",
help="str representing the changelist",
type=str,
required=True,
)
def create_key_option() -> IdentityFunction:
return click.option(
"-h",
"--key",
"key_string",
help="str representing the key",
type=str,
required=True,
)
def create_data_store_id_option() -> "IdentityFunction":
return click.option(
"-store",
"--id",
help="The hexadecimal store id.",
type=str,
required=True,
)
def create_data_store_name_option() -> "IdentityFunction":
return click.option(
"-n",
"--table_name",
"table_name",
help="The name of the table.",
type=str,
required=True,
)
def create_rpc_port_option() -> "IdentityFunction":
return click.option(
"-dp",
"--data-rpc-port",
help="Set the port where the data layer is hosting the RPC interface. See rpc_port under wallet in config.yaml",
type=int,
default=None,
show_default=True,
)
def create_fee_option() -> "IdentityFunction":
return click.option(
"-m",
"--fee",
help="Set the fees for the transaction, in XCH",
type=str,
default=None,
show_default=True,
required=False,
)
@data_cmd.command("create_data_store", short_help="Create a new data store")
@create_rpc_port_option()
@create_fee_option()
def create_data_store(
data_rpc_port: int,
fee: Optional[str],
) -> None:
from chia.cmds.data_funcs import create_data_store_cmd
run(create_data_store_cmd(data_rpc_port, fee))
@data_cmd.command("get_value", short_help="Get the value for a given key and store")
@create_data_store_id_option()
@create_key_option()
@create_rpc_port_option()
def get_value(
id: str,
key_string: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_value_cmd
run(get_value_cmd(data_rpc_port, id, key_string))
@data_cmd.command("update_data_store", short_help="Update a store by providing the changelist operations")
@create_data_store_id_option()
@create_changelist_option()
@create_rpc_port_option()
@create_fee_option()
def update_data_store(
id: str,
changelist_string: str,
data_rpc_port: int,
fee: str,
) -> None:
from chia.cmds.data_funcs import update_data_store_cmd
run(update_data_store_cmd(rpc_port=data_rpc_port, store_id=id, changelist=json.loads(changelist_string), fee=fee))
@data_cmd.command("get_keys", short_help="Get all keys for a given store")
@create_data_store_id_option()
@create_rpc_port_option()
def get_keys(
id: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_keys_cmd
run(get_keys_cmd(data_rpc_port, id))
@data_cmd.command("get_keys_values", short_help="Get all keys and values for a given store")
@create_data_store_id_option()
@create_rpc_port_option()
def get_keys_values(
id: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_keys_values_cmd
run(get_keys_values_cmd(data_rpc_port, id))
@data_cmd.command("get_root", short_help="Get the published root hash value for a given store")
@create_data_store_id_option()
@create_rpc_port_option()
def get_root(
id: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_root_cmd
run(get_root_cmd(rpc_port=data_rpc_port, store_id=id))
@data_cmd.command("subscribe", short_help="Subscribe to a store")
@create_data_store_id_option()
@click.option(
"-u",
"--url",
"urls",
help="Manually provide a list of servers urls for downloading the data",
type=str,
multiple=True,
)
@create_rpc_port_option()
def subscribe(
id: str,
urls: List[str],
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import subscribe_cmd
run(subscribe_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls))
@data_cmd.command("remove_subscription", short_help="Remove server urls that are added via subscribing to urls")
@create_data_store_id_option()
@click.option("-u", "--url", "urls", help="Server urls to remove", type=str, multiple=True)
@create_rpc_port_option()
def remove_subscription(
id: str,
urls: List[str],
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import remove_subscriptions_cmd
run(remove_subscriptions_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls))
@data_cmd.command("unsubscribe", short_help="Completely untrack a store")
@create_data_store_id_option()
@create_rpc_port_option()
def unsubscribe(
id: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import unsubscribe_cmd
run(unsubscribe_cmd(rpc_port=data_rpc_port, store_id=id))
@data_cmd.command(
"get_kv_diff", short_help="Get the inserted and deleted keys and values between an initial and a final hash"
)
@create_data_store_id_option()
@click.option("-hash_1", "--hash_1", help="Initial hash", type=str)
@click.option("-hash_2", "--hash_2", help="Final hash", type=str)
@create_rpc_port_option()
def get_kv_diff(
id: str,
hash_1: str,
hash_2: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_kv_diff_cmd
run(get_kv_diff_cmd(rpc_port=data_rpc_port, store_id=id, hash_1=hash_1, hash_2=hash_2))
@data_cmd.command("get_root_history", short_help="Get all changes of a singleton")
@create_data_store_id_option()
@create_rpc_port_option()
def get_root_history(
id: str,
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import get_root_history_cmd
run(get_root_history_cmd(rpc_port=data_rpc_port, store_id=id))
@data_cmd.command("add_missing_files", short_help="Manually reconstruct server files from the data layer database")
@click.option(
"-i",
"--ids",
help="List of stores to reconstruct. If not specified, all stores will be reconstructed",
type=str,
required=False,
)
@click.option(
"-o/-n", "--override/--no-override", help="Specify if already existing files need to be overwritten by this command"
)
@click.option(
"-f", "--foldername", type=str, help="If specified, use a non-default folder to write the files", required=False
)
@create_rpc_port_option()
def add_missing_files(ids: Optional[str], override: bool, foldername: Optional[str], data_rpc_port: int) -> None:
from chia.cmds.data_funcs import add_missing_files_cmd
run(
add_missing_files_cmd(
rpc_port=data_rpc_port,
ids=None if ids is None else json.loads(ids),
override=override,
foldername=None if foldername is None else Path(foldername),
)
)
@data_cmd.command("add_mirror", short_help="Publish mirror urls on chain")
@click.option("-i", "--id", help="Store id", type=str, required=True)
@click.option("-a", "--amount", help="Amount for this mirror", type=int, required=True)
@click.option(
"-u",
"--url",
"urls",
help="URL to publish on the new coin, multiple accepted and will be published to a single coin.",
type=str,
multiple=True,
)
@create_fee_option()
@create_rpc_port_option()
def add_mirror(id: str, amount: int, urls: List[str], fee: Optional[str], data_rpc_port: int) -> None:
from chia.cmds.data_funcs import add_mirror_cmd
run(
add_mirror_cmd(
rpc_port=data_rpc_port,
store_id=id,
urls=urls,
amount=amount,
fee=fee,
)
)
@data_cmd.command("delete_mirror", short_help="Delete an owned mirror by its coin id")
@click.option("-i", "--id", help="Coin id", type=str, required=True)
@create_fee_option()
@create_rpc_port_option()
def delete_mirror(id: str, fee: Optional[str], data_rpc_port: int) -> None:
from chia.cmds.data_funcs import delete_mirror_cmd
run(
delete_mirror_cmd(
rpc_port=data_rpc_port,
coin_id=id,
fee=fee,
)
)
@data_cmd.command("get_mirrors", short_help="Get a list of all mirrors for a given store")
@click.option("-i", "--id", help="Store id", type=str, required=True)
@create_rpc_port_option()
def get_mirrors(id: str, data_rpc_port: int) -> None:
from chia.cmds.data_funcs import get_mirrors_cmd
run(
get_mirrors_cmd(
rpc_port=data_rpc_port,
store_id=id,
)
)
@data_cmd.command("get_subscriptions", short_help="Get subscribed stores, including the owned stores")
@create_rpc_port_option()
def get_subscriptions(data_rpc_port: int) -> None:
from chia.cmds.data_funcs import get_subscriptions_cmd
run(
get_subscriptions_cmd(
rpc_port=data_rpc_port,
)
)
@data_cmd.command("get_owned_stores", short_help="Get owned stores")
@create_rpc_port_option()
def get_owned_stores(data_rpc_port: int) -> None:
from chia.cmds.data_funcs import get_owned_stores_cmd
run(
get_owned_stores_cmd(
rpc_port=data_rpc_port,
)
)
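For illustration, the changelist passed to update_data_store via -d/--changelist is a JSON list of insert/delete operations; the sketch below infers the shape from the json.loads call above and the insert/delete actions used elsewhere in this commit (store id, keys, and values are made-up placeholders):

import json

# A delete followed by an insert of the same key acts as an upsert.
changelist = [
    {"action": "delete", "key": "0001"},
    {"action": "insert", "key": "0001", "value": "cafef00d"},
]
# Passed on the command line roughly as:
#   chia data update_data_store --id <hex store id> -d '<json>' -m 0.00001
changelist_string = json.dumps(changelist)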

chia/cmds/data_funcs.py Normal file (309 lines)
View File

@@ -0,0 +1,309 @@
from decimal import Decimal
from pathlib import Path
from types import TracebackType
from typing import Dict, List, Optional, Tuple, Type
import aiohttp
from chia.cmds.units import units
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16, uint64
# TODO: there seems to be a large amount of repetition in these to dedupe
class get_client:
_port: Optional[int]
_client: Optional[DataLayerRpcClient] = None
def __init__(self, rpc_port: Optional[int]):
self._port = rpc_port
async def __aenter__(self) -> Tuple[DataLayerRpcClient, int]:
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
self_hostname = config["self_hostname"]
if self._port is None:
self._port = config["data_layer"]["rpc_port"]
self._client = await DataLayerRpcClient.create(self_hostname, uint16(self._port), DEFAULT_ROOT_PATH, config)
assert self._client is not None
return self._client, int(self._port)
async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
if self._client is None:
return
self._client.close()
await self._client.await_closed()
async def create_data_store_cmd(rpc_port: Optional[int], fee: Optional[str]) -> None:
final_fee = None
if fee is not None:
final_fee = uint64(int(Decimal(fee) * units["chia"]))
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.create_data_store(fee=final_fee)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def get_value_cmd(rpc_port: Optional[int], store_id: str, key: str) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
key_bytes = hexstr_to_bytes(key)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_value(store_id=store_id_bytes, key=key_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def update_data_store_cmd(
rpc_port: Optional[int],
store_id: str,
changelist: List[Dict[str, str]],
fee: Optional[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None
if fee is not None:
final_fee = uint64(int(Decimal(fee) * units["chia"]))
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.update_data_store(store_id=store_id_bytes, changelist=changelist, fee=final_fee)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def get_keys_cmd(
rpc_port: Optional[int],
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_keys(store_id=store_id_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def get_keys_values_cmd(
rpc_port: Optional[int],
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_keys_values(store_id=store_id_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def get_root_cmd(
rpc_port: Optional[int],
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_root(store_id=store_id_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
return
async def subscribe_cmd(
rpc_port: Optional[int],
store_id: str,
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
await client.subscribe(store_id=store_id_bytes, urls=urls)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def unsubscribe_cmd(
rpc_port: Optional[int],
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
await client.unsubscribe(store_id=store_id_bytes)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def remove_subscriptions_cmd(
rpc_port: Optional[int],
store_id: str,
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
await client.remove_subscriptions(store_id=store_id_bytes, urls=urls)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_kv_diff_cmd(
rpc_port: Optional[int],
store_id: str,
hash_1: str,
hash_2: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
hash_1_bytes = bytes32.from_hexstr(hash_1)
hash_2_bytes = bytes32.from_hexstr(hash_2)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_kv_diff(store_id=store_id_bytes, hash_1=hash_1_bytes, hash_2=hash_2_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_root_history_cmd(
rpc_port: Optional[int],
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_root_history(store_id=store_id_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def add_missing_files_cmd(
rpc_port: Optional[int], ids: Optional[List[str]], override: bool, foldername: Optional[Path]
) -> None:
try:
async with get_client(rpc_port) as (client, rpc_port):
await client.add_missing_files(
store_ids=(None if ids is None else [bytes32.from_hexstr(id) for id in ids]),
override=override,
foldername=foldername,
)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def add_mirror_cmd(
rpc_port: Optional[int], store_id: str, urls: List[str], amount: int, fee: Optional[str]
) -> None:
try:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None
if fee is not None:
final_fee = uint64(int(Decimal(fee) * units["chia"]))
async with get_client(rpc_port) as (client, rpc_port):
await client.add_mirror(
store_id=store_id_bytes,
urls=urls,
amount=amount,
fee=final_fee,
)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def delete_mirror_cmd(rpc_port: Optional[int], coin_id: str, fee: Optional[str]) -> None:
try:
coin_id_bytes = bytes32.from_hexstr(coin_id)
final_fee = None
if fee is not None:
final_fee = uint64(int(Decimal(fee) * units["chia"]))
async with get_client(rpc_port) as (client, rpc_port):
await client.delete_mirror(
coin_id=coin_id_bytes,
fee=final_fee,
)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_mirrors_cmd(rpc_port: Optional[int], store_id: str) -> None:
try:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_mirrors(store_id=store_id_bytes)
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_subscriptions_cmd(rpc_port: Optional[int]) -> None:
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_subscriptions()
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_owned_stores_cmd(rpc_port: Optional[int]) -> None:
try:
async with get_client(rpc_port) as (client, rpc_port):
res = await client.get_owned_stores()
print(res)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")

View File

@@ -44,8 +44,17 @@ from chia.wallet.derive_keys import (
)
from chia.cmds.configure import configure
private_node_names: List[str] = ["full_node", "wallet", "farmer", "harvester", "timelord", "crawler", "daemon"]
public_node_names: List[str] = ["full_node", "wallet", "farmer", "introducer", "timelord"]
private_node_names: List[str] = [
"full_node",
"wallet",
"farmer",
"harvester",
"timelord",
"crawler",
"data_layer",
"daemon",
]
public_node_names: List[str] = ["full_node", "wallet", "farmer", "introducer", "timelord", "data_layer"]
def dict_add_new_default(updated: Dict, default: Dict, do_not_migrate_keys: Dict[str, Any]):

View File

@@ -10,7 +10,7 @@ from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
services: List[str] = ["crawler", "farmer", "full_node", "harvester", "timelord", "wallet"]
services: List[str] = ["crawler", "farmer", "full_node", "harvester", "timelord", "wallet", "data_layer"]
async def call_endpoint(service: str, endpoint: str, request: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:

View File

@@ -60,12 +60,12 @@ def print_transaction(tx: TransactionRecord, verbose: bool, name, address_prefix
def get_mojo_per_unit(wallet_type: WalletType) -> int:
mojo_per_unit: int
if wallet_type == WalletType.STANDARD_WALLET or wallet_type == WalletType.POOLING_WALLET:
if wallet_type in {WalletType.STANDARD_WALLET, WalletType.POOLING_WALLET, WalletType.DATA_LAYER}:
mojo_per_unit = units["chia"]
elif wallet_type == WalletType.CAT:
mojo_per_unit = units["cat"]
else:
raise LookupError("Only standard wallet, CAT wallets, and Plot NFTs are supported")
raise LookupError(f"Operation is not supported for Wallet type {wallet_type.name}")
return mojo_per_unit
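# Illustrative note: on mainnet 1 XCH == 10**12 mojo, so units["chia"] == 10**12;
# e.g. a 0.00001 XCH fee converts to 10_000_000 mojo via Decimal(fee) * units["chia"].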
@@ -87,12 +87,12 @@ async def get_name_for_wallet_id(
wallet_id: int,
wallet_client: WalletRpcClient,
):
if wallet_type == WalletType.STANDARD_WALLET or wallet_type == WalletType.POOLING_WALLET:
if wallet_type in {WalletType.STANDARD_WALLET, WalletType.POOLING_WALLET, WalletType.DATA_LAYER}:
name = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"].upper()
elif wallet_type == WalletType.CAT:
name = await wallet_client.get_cat_name(wallet_id=str(wallet_id))
else:
raise LookupError("Only standard wallet, CAT wallets, and Plot NFTs are supported")
raise LookupError(f"Operation is not supported for Wallet type {wallet_type.name}")
return name

View File

@@ -83,6 +83,7 @@ class PlotEvent(str, Enum):
if getattr(sys, "frozen", False):
name_map = {
"chia": "chia",
"chia_data_layer": "start_data_layer",
"chia_wallet": "start_wallet",
"chia_full_node": "start_full_node",
"chia_harvester": "start_harvester",

View File

@@ -0,0 +1,777 @@
import asyncio
import logging
import random
import time
import traceback
from pathlib import Path
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, Union
import aiohttp
import aiosqlite
from chia.data_layer.data_layer_errors import KeyNotFoundError
from chia.data_layer.data_layer_server import DataLayerServer
from chia.data_layer.data_layer_util import (
DiffData,
InternalNode,
KeyValue,
Layer,
Offer,
OfferStore,
Proof,
ProofOfInclusion,
ProofOfInclusionLayer,
Root,
ServerInfo,
Status,
StoreProofs,
Subscription,
TerminalNode,
leaf_hash,
)
from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror, SingletonRecord, verify_offer
from chia.data_layer.data_store import DataStore
from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.server.server import ChiaServer
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper
from chia.util.ints import uint32, uint64
from chia.util.path import path_from_root
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer as TradingOffer
from chia.wallet.transaction_record import TransactionRecord
class DataLayer:
data_store: DataStore
data_layer_server: DataLayerServer
db_wrapper: DBWrapper
batch_update_db_wrapper: DBWrapper
db_path: Path
connection: Optional[aiosqlite.Connection]
config: Dict[str, Any]
log: logging.Logger
wallet_rpc_init: Awaitable[WalletRpcClient]
state_changed_callback: Optional[Callable[..., object]]
wallet_id: uint64
initialized: bool
none_bytes: bytes32
lock: asyncio.Lock
def __init__(
self,
config: Dict[str, Any],
root_path: Path,
wallet_rpc_init: Awaitable[WalletRpcClient],
name: Optional[str] = None,
):
if name == "":
# TODO: If no code depends on "" counting as 'unspecified' then we do not
# need this.
name = None
self.initialized = False
self.config = config
self.connection = None
self.wallet_rpc_init = wallet_rpc_init
self.log = logging.getLogger(__name__ if name is None else name)
self._shut_down: bool = False
db_path_replaced: str = config["database_path"].replace("CHALLENGE", config["selected_network"])
self.db_path = path_from_root(root_path, db_path_replaced)
self.db_path.parent.mkdir(parents=True, exist_ok=True)
server_files_replaced: str = config.get(
"server_files_location", "data_layer/db/server_files_location_CHALLENGE"
).replace("CHALLENGE", config["selected_network"])
self.server_files_location = path_from_root(root_path, server_files_replaced)
self.server_files_location.mkdir(parents=True, exist_ok=True)
self.data_layer_server = DataLayerServer(root_path, self.config, self.log)
self.none_bytes = bytes32([0] * 32)
self.lock = asyncio.Lock()
def _set_state_changed_callback(self, callback: Callable[..., object]) -> None:
self.state_changed_callback = callback
def set_server(self, server: ChiaServer) -> None:
self.server = server
async def _start(self) -> bool:
self.connection = await aiosqlite.connect(self.db_path)
self.db_wrapper = DBWrapper(self.connection)
self.data_store = await DataStore.create(self.db_wrapper)
self.wallet_rpc = await self.wallet_rpc_init
self.subscription_lock: asyncio.Lock = asyncio.Lock()
if self.config.get("run_server", False):
await self.data_layer_server.start()
self.periodically_manage_data_task: asyncio.Task[Any] = asyncio.create_task(self.periodically_manage_data())
return True
def _close(self) -> None:
# TODO: review for anything else we need to do here
self._shut_down = True
async def _await_closed(self) -> None:
if self.connection is not None:
await self.connection.close()
if self.config.get("run_server", False):
await self.data_layer_server.stop()
try:
self.periodically_manage_data_task.cancel()
except asyncio.CancelledError:
pass
async def create_store(
self, fee: uint64, root: bytes32 = bytes32([0] * 32)
) -> Tuple[List[TransactionRecord], bytes32]:
txs, tree_id = await self.wallet_rpc.create_new_dl(root, fee)
res = await self.data_store.create_tree(tree_id=tree_id)
if res is None:
self.log.fatal("failed creating store")
self.initialized = True
return txs, tree_id
async def batch_update(
self,
tree_id: bytes32,
changelist: List[Dict[str, Any]],
fee: uint64,
) -> TransactionRecord:
await self.batch_insert(tree_id=tree_id, changelist=changelist)
return await self.publish_update(tree_id=tree_id, fee=fee)
async def batch_insert(
self,
tree_id: bytes32,
changelist: List[Dict[str, Any]],
lock: bool = True,
) -> bytes32:
async with self.data_store.transaction(lock=lock):
# Make sure we update based on the latest confirmed root.
async with self.lock:
await self._update_confirmation_status(tree_id=tree_id, lock=False)
pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
if pending_root is not None:
raise Exception("Already have a pending root waiting for confirmation.")
# check before any DL changes that this singleton is currently owned by this wallet
singleton_records: List[SingletonRecord] = await self.get_owned_stores()
if not any(tree_id == singleton.launcher_id for singleton in singleton_records):
raise ValueError(f"Singleton with launcher ID {tree_id} is not owned by DL Wallet")
t1 = time.monotonic()
batch_hash = await self.data_store.insert_batch(tree_id, changelist, lock=False)
t2 = time.monotonic()
self.log.info(f"Data store batch update process time: {t2 - t1}.")
# todo return empty node hash from get_tree_root
if batch_hash is not None:
node_hash = batch_hash
else:
node_hash = self.none_bytes # todo change
return node_hash
async def publish_update(
self,
tree_id: bytes32,
fee: uint64,
) -> TransactionRecord:
# Make sure we update based on the latest confirmed root.
async with self.lock:
await self._update_confirmation_status(tree_id=tree_id)
pending_root: Optional[Root] = await self.data_store.get_pending_root(tree_id=tree_id)
if pending_root is None:
raise Exception("Latest root is already confirmed.")
root_hash = self.none_bytes if pending_root.node_hash is None else pending_root.node_hash
transaction_record = await self.wallet_rpc.dl_update_root(
launcher_id=tree_id,
new_root=root_hash,
fee=fee,
)
return transaction_record
async def get_key_value_hash(
self,
store_id: bytes32,
key: bytes,
root_hash: Optional[bytes32] = None,
lock: bool = True,
) -> bytes32:
async with self.data_store.transaction(lock=lock):
async with self.lock:
await self._update_confirmation_status(tree_id=store_id, lock=False)
node = await self.data_store.get_node_by_key(tree_id=store_id, key=key, root_hash=root_hash, lock=False)
return node.hash
async def get_value(self, store_id: bytes32, key: bytes, lock: bool = True) -> Optional[bytes]:
async with self.data_store.transaction(lock=lock):
async with self.lock:
await self._update_confirmation_status(tree_id=store_id, lock=False)
res = await self.data_store.get_node_by_key(tree_id=store_id, key=key, lock=False)
if res is None:
self.log.error("Failed to fetch key")
return None
return res.value
async def get_keys_values(self, store_id: bytes32, root_hash: Optional[bytes32]) -> List[TerminalNode]:
async with self.lock:
await self._update_confirmation_status(tree_id=store_id)
res = await self.data_store.get_keys_values(store_id, root_hash)
if res is None:
self.log.error("Failed to fetch keys values")
return res
async def get_keys(self, store_id: bytes32, root_hash: Optional[bytes32]) -> List[bytes]:
async with self.lock:
await self._update_confirmation_status(tree_id=store_id)
res = await self.data_store.get_keys(store_id, root_hash)
return res
async def get_ancestors(self, node_hash: bytes32, store_id: bytes32) -> List[InternalNode]:
async with self.lock:
await self._update_confirmation_status(tree_id=store_id)
res = await self.data_store.get_ancestors(node_hash=node_hash, tree_id=store_id)
if res is None:
self.log.error("Failed to get ancestors")
return res
async def get_root(self, store_id: bytes32) -> Optional[SingletonRecord]:
latest = await self.wallet_rpc.dl_latest_singleton(store_id, True)
if latest is None:
self.log.error(f"Failed to get root for {store_id.hex()}")
return latest
async def get_local_root(self, store_id: bytes32) -> Optional[bytes32]:
async with self.lock:
await self._update_confirmation_status(tree_id=store_id)
res = await self.data_store.get_tree_root(tree_id=store_id)
if res is None:
self.log.error(f"Failed to get root for {store_id.hex()}")
return None
return res.node_hash
async def get_root_history(self, store_id: bytes32) -> List[SingletonRecord]:
records = await self.wallet_rpc.dl_history(store_id)
if records is None:
self.log.error(f"Failed to get root history for {store_id.hex()}")
root_history = []
prev: Optional[SingletonRecord] = None
for record in records:
if prev is None or record.root != prev.root:
root_history.append(record)
prev = record
return root_history
async def _update_confirmation_status(self, tree_id: bytes32, lock: bool = True) -> None:
async with self.data_store.transaction(lock=lock):
try:
root = await self.data_store.get_tree_root(tree_id=tree_id, lock=False)
except asyncio.CancelledError:
raise
except Exception:
root = None
singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(tree_id, True)
if singleton_record is None:
return
if root is None:
pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
if pending_root is not None:
if pending_root.generation == 0 and pending_root.node_hash is None:
await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False)
await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False)
return
else:
root = None
if root is None:
self.log.info(f"Don't have pending root for {tree_id}.")
return
if root.generation == singleton_record.generation:
return
if root.generation > singleton_record.generation:
self.log.info(
f"Local root ahead of chain root: {root.generation} {singleton_record.generation}. "
"Maybe we're doing a batch update."
)
return
wallet_history = await self.wallet_rpc.dl_history(
launcher_id=tree_id,
min_generation=uint32(root.generation + 1),
max_generation=singleton_record.generation,
)
new_hashes = [record.root for record in reversed(wallet_history)]
root_hash = self.none_bytes if root.node_hash is None else root.node_hash
generation_shift = 0
while len(new_hashes) > 0 and new_hashes[0] == root_hash:
generation_shift += 1
new_hashes.pop(0)
if generation_shift > 0:
await self.data_store.shift_root_generations(tree_id=tree_id, shift_size=generation_shift, lock=False)
else:
expected_root_hash = None if new_hashes[0] == self.none_bytes else new_hashes[0]
pending_root = await self.data_store.get_pending_root(tree_id=tree_id, lock=False)
if (
pending_root is not None
and pending_root.generation == root.generation + 1
and pending_root.node_hash == expected_root_hash
):
await self.data_store.change_root_status(pending_root, Status.COMMITTED, lock=False)
await self.data_store.build_ancestor_table_for_latest_root(tree_id=tree_id, lock=False)
await self.data_store.clear_pending_roots(tree_id=tree_id, lock=False)
async def fetch_and_validate(self, tree_id: bytes32) -> None:
singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(tree_id, True)
if singleton_record is None:
self.log.info(f"Fetch data: No singleton record for {tree_id}.")
return
if singleton_record.generation == uint32(0):
self.log.info(f"Fetch data: No data on chain for {tree_id}.")
return
async with self.lock:
await self._update_confirmation_status(tree_id=tree_id)
if not await self.data_store.tree_id_exists(tree_id=tree_id):
await self.data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
timestamp = int(time.time())
servers_info = await self.data_store.get_available_servers_for_store(tree_id, timestamp)
# TODO: maybe append a random object to the whole DataLayer class?
random.shuffle(servers_info)
for server_info in servers_info:
url = server_info.url
root = await self.data_store.get_tree_root(tree_id=tree_id)
if root.generation > singleton_record.generation:
self.log.info(
"Fetch data: local DL store is ahead of chain generation. "
f"Local root: {root}. Singleton: {singleton_record}"
)
break
if root.generation == singleton_record.generation:
self.log.info(f"Fetch data: wallet generation matching on-chain generation: {tree_id}.")
break
self.log.info(
f"Downloading files {tree_id}. "
f"Current wallet generation: {root.generation}. "
f"Target wallet generation: {singleton_record.generation}. "
f"Server used: {url}."
)
to_download = await self.wallet_rpc.dl_history(
launcher_id=tree_id,
min_generation=uint32(root.generation + 1),
max_generation=singleton_record.generation,
)
try:
success = await insert_from_delta_file(
self.data_store,
tree_id,
root.generation,
[record.root for record in reversed(to_download)],
server_info,
self.server_files_location,
self.log,
)
if success:
self.log.info(
f"Finished downloading and validating {tree_id}. "
f"Wallet generation saved: {singleton_record.generation}. "
f"Root hash saved: {singleton_record.root}."
)
break
except asyncio.CancelledError:
raise
except aiohttp.client_exceptions.ClientConnectorError:
self.log.warning(f"Server {url} unavailable for {tree_id}.")
except Exception as e:
self.log.warning(f"Exception while downloading files for {tree_id}: {e} {traceback.format_exc()}.")
async def upload_files(self, tree_id: bytes32) -> None:
singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(tree_id, True)
if singleton_record is None:
self.log.info(f"Upload files: no on-chain record for {tree_id}.")
return
async with self.lock:
await self._update_confirmation_status(tree_id=tree_id)
root = await self.data_store.get_tree_root(tree_id=tree_id)
publish_generation = min(singleton_record.generation, 0 if root is None else root.generation)
# If we make some batch updates, which get confirmed to the chain, we need to create the files.
# We iterate back and write the missing files, until we find the files already written.
root = await self.data_store.get_tree_root(tree_id=tree_id, generation=publish_generation)
while publish_generation > 0 and await write_files_for_root(
self.data_store,
tree_id,
root,
self.server_files_location,
):
publish_generation -= 1
root = await self.data_store.get_tree_root(tree_id=tree_id, generation=publish_generation)
async def add_missing_files(self, store_id: bytes32, override: bool, foldername: Optional[Path]) -> None:
root = await self.data_store.get_tree_root(tree_id=store_id)
singleton_record: Optional[SingletonRecord] = await self.wallet_rpc.dl_latest_singleton(store_id, True)
if singleton_record is None:
self.log.error(f"No singleton record found for: {store_id}")
return
max_generation = min(singleton_record.generation, 0 if root is None else root.generation)
server_files_location = foldername if foldername is not None else self.server_files_location
for generation in range(1, max_generation + 1):
root = await self.data_store.get_tree_root(tree_id=store_id, generation=generation)
await write_files_for_root(self.data_store, store_id, root, server_files_location, override)
async def subscribe(self, store_id: bytes32, urls: List[str]) -> None:
parsed_urls = [url.rstrip("/") for url in urls]
subscription = Subscription(store_id, [ServerInfo(url, 0, 0) for url in parsed_urls])
await self.wallet_rpc.dl_track_new(subscription.tree_id)
async with self.subscription_lock:
await self.data_store.subscribe(subscription)
self.log.info(f"Done adding subscription: {subscription.tree_id}")
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> None:
parsed_urls = [url.rstrip("/") for url in urls]
async with self.subscription_lock:
await self.data_store.remove_subscriptions(store_id, parsed_urls)
async def unsubscribe(self, tree_id: bytes32) -> None:
subscriptions = await self.get_subscriptions()
if tree_id not in (subscription.tree_id for subscription in subscriptions):
raise RuntimeError("No subscription found for the given tree_id.")
async with self.subscription_lock:
await self.data_store.unsubscribe(tree_id)
await self.wallet_rpc.dl_stop_tracking(tree_id)
self.log.info(f"Unsubscribed to {tree_id}")
async def get_subscriptions(self) -> List[Subscription]:
async with self.subscription_lock:
return await self.data_store.get_subscriptions()
async def add_mirror(self, store_id: bytes32, urls: List[str], amount: uint64, fee: uint64) -> None:
bytes_urls = [bytes(url, "utf8") for url in urls]
await self.wallet_rpc.dl_new_mirror(store_id, amount, bytes_urls, fee)
async def delete_mirror(self, coin_id: bytes32, fee: uint64) -> None:
await self.wallet_rpc.dl_delete_mirror(coin_id, fee)
async def get_mirrors(self, tree_id: bytes32) -> List[Mirror]:
return await self.wallet_rpc.dl_get_mirrors(tree_id)
async def update_subscriptions_from_wallet(self, tree_id: bytes32) -> None:
mirrors: List[Mirror] = await self.wallet_rpc.dl_get_mirrors(tree_id)
urls: List[str] = []
for mirror in mirrors:
urls = urls + [url.decode("utf8") for url in mirror.urls]
urls = [url.rstrip("/") for url in urls]
await self.data_store.update_subscriptions_from_wallet(tree_id, urls)
async def get_owned_stores(self) -> List[SingletonRecord]:
return await self.wallet_rpc.dl_owned_singletons()
async def get_kv_diff(self, tree_id: bytes32, hash_1: bytes32, hash_2: bytes32) -> Set[DiffData]:
return await self.data_store.get_kv_diff(tree_id, hash_1, hash_2)
async def periodically_manage_data(self) -> None:
manage_data_interval = self.config.get("manage_data_interval", 60)
while not self._shut_down:
async with self.subscription_lock:
try:
subscriptions = await self.data_store.get_subscriptions()
for subscription in subscriptions:
await self.wallet_rpc.dl_track_new(subscription.tree_id)
break
except aiohttp.client_exceptions.ClientConnectorError:
pass
except asyncio.CancelledError:
raise
self.log.warning("Cannot connect to the wallet. Retrying in 3s.")
delay_until = time.monotonic() + 3
while time.monotonic() < delay_until:
if self._shut_down:
break
try:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
raise
while not self._shut_down:
async with self.subscription_lock:
subscriptions = await self.data_store.get_subscriptions()
# Subscribe to all local tree_ids that we can find on chain.
local_tree_ids = await self.data_store.get_tree_ids()
subscription_tree_ids = set(subscription.tree_id for subscription in subscriptions)
for local_id in local_tree_ids:
if local_id not in subscription_tree_ids:
try:
await self.subscribe(local_id, [])
except asyncio.CancelledError:
raise
except Exception as e:
self.log.info(
f"Can't subscribe to locally stored {local_id}: {type(e)} {e} {traceback.format_exc()}"
)
async with self.subscription_lock:
for subscription in subscriptions:
try:
await self.update_subscriptions_from_wallet(subscription.tree_id)
await self.fetch_and_validate(subscription.tree_id)
await self.upload_files(subscription.tree_id)
except asyncio.CancelledError:
raise
except Exception as e:
self.log.error(f"Exception while fetching data: {type(e)} {e} {traceback.format_exc()}.")
try:
await asyncio.sleep(manage_data_interval)
except asyncio.CancelledError:
raise
async def build_offer_changelist(
self,
store_id: bytes32,
inclusions: Tuple[KeyValue, ...],
lock: bool = True,
) -> List[Dict[str, Any]]:
async with self.data_store.transaction(lock=lock):
changelist: List[Dict[str, Any]] = []
for entry in inclusions:
try:
existing_value = await self.get_value(store_id=store_id, key=entry.key, lock=False)
except KeyNotFoundError:
existing_value = None
if existing_value == entry.value:
# already present, nothing needed
continue
if existing_value is not None:
# upsert, delete the existing key and value
changelist.append(
{
"action": "delete",
"key": entry.key,
}
)
changelist.append(
{
"action": "insert",
"key": entry.key,
"value": entry.value,
}
)
return changelist
async def process_offered_stores(
self, offer_stores: Tuple[OfferStore, ...], lock: bool = True
) -> Dict[bytes32, StoreProofs]:
async with self.data_store.transaction(lock=lock):
our_store_proofs: Dict[bytes32, StoreProofs] = {}
for offer_store in offer_stores:
async with self.lock:
await self._update_confirmation_status(tree_id=offer_store.store_id, lock=False)
changelist = await self.build_offer_changelist(
store_id=offer_store.store_id,
inclusions=offer_store.inclusions,
lock=False,
)
if len(changelist) > 0:
new_root_hash = await self.batch_insert(
tree_id=offer_store.store_id,
changelist=changelist,
lock=False,
)
else:
existing_root = await self.get_root(store_id=offer_store.store_id)
if existing_root is None:
raise Exception(f"store id not available: {offer_store.store_id.hex()}")
new_root_hash = existing_root.root
if new_root_hash is None:
raise Exception("only inserts are supported so a None root hash should not be possible")
proofs: List[Proof] = []
for entry in offer_store.inclusions:
node_hash = await self.get_key_value_hash(
store_id=offer_store.store_id,
key=entry.key,
root_hash=new_root_hash,
lock=False,
)
proof_of_inclusion = await self.data_store.get_proof_of_inclusion_by_hash(
node_hash=node_hash,
tree_id=offer_store.store_id,
root_hash=new_root_hash,
lock=False,
)
proof = Proof(
key=entry.key,
value=entry.value,
node_hash=proof_of_inclusion.node_hash,
layers=tuple(
Layer(
other_hash_side=layer.other_hash_side,
other_hash=layer.other_hash,
combined_hash=layer.combined_hash,
)
for layer in proof_of_inclusion.layers
),
)
proofs.append(proof)
store_proof = StoreProofs(store_id=offer_store.store_id, proofs=tuple(proofs))
our_store_proofs[offer_store.store_id] = store_proof
return our_store_proofs
async def make_offer(
self,
maker: Tuple[OfferStore, ...],
taker: Tuple[OfferStore, ...],
fee: uint64,
) -> Offer:
async with self.data_store.transaction():
our_store_proofs = await self.process_offered_stores(offer_stores=maker, lock=False)
offer_dict: Dict[Union[uint32, str], int] = {
**{offer_store.store_id.hex(): -1 for offer_store in maker},
**{offer_store.store_id.hex(): 1 for offer_store in taker},
}
solver: Dict[str, Any] = {
"0x"
+ our_offer_store.store_id.hex(): {
"new_root": "0x" + our_store_proofs[our_offer_store.store_id].proofs[0].root().hex(),
"dependencies": [
{
"launcher_id": "0x" + their_offer_store.store_id.hex(),
"values_to_prove": [
"0x" + leaf_hash(key=entry.key, value=entry.value).hex()
for entry in their_offer_store.inclusions
],
}
for their_offer_store in taker
],
}
for our_offer_store in maker
}
wallet_offer, trade_record = await self.wallet_rpc.create_offer_for_ids(
offer_dict=offer_dict,
solver=solver,
driver_dict={},
fee=fee,
validate_only=False,
)
if wallet_offer is None:
raise Exception("offer is None despite validate_only=False")
offer = Offer(
trade_id=trade_record.trade_id,
offer=bytes(wallet_offer),
taker=taker,
maker=tuple(our_store_proofs.values()),
)
# being extra careful and verifying the offer before returning it
trading_offer = TradingOffer.from_bytes(offer.offer)
summary = await DataLayerWallet.get_offer_summary(offer=trading_offer)
verify_offer(maker=offer.maker, taker=offer.taker, summary=summary)
return offer
async def take_offer(
self,
offer_bytes: bytes,
taker: Tuple[OfferStore, ...],
maker: Tuple[StoreProofs, ...],
fee: uint64,
) -> TradeRecord:
async with self.data_store.transaction():
our_store_proofs = await self.process_offered_stores(offer_stores=taker, lock=False)
offer = TradingOffer.from_bytes(offer_bytes)
summary = await DataLayerWallet.get_offer_summary(offer=offer)
verify_offer(maker=maker, taker=taker, summary=summary)
all_store_proofs: Dict[bytes32, StoreProofs] = {
store_proofs.proofs[0].root(): store_proofs for store_proofs in [*maker, *our_store_proofs.values()]
}
proofs_of_inclusion: List[Tuple[str, str, List[str]]] = []
for root, store_proofs in all_store_proofs.items():
for proof in store_proofs.proofs:
layers = [
ProofOfInclusionLayer(
combined_hash=layer.combined_hash,
other_hash_side=layer.other_hash_side,
other_hash=layer.other_hash,
)
for layer in proof.layers
]
proof_of_inclusion = ProofOfInclusion(node_hash=proof.node_hash, layers=layers)
sibling_sides_integer = proof_of_inclusion.sibling_sides_integer()
proofs_of_inclusion.append(
(
root.hex(),
str(sibling_sides_integer),
["0x" + sibling_hash.hex() for sibling_hash in proof_of_inclusion.sibling_hashes()],
)
)
solver: Dict[str, Any] = {
"proofs_of_inclusion": proofs_of_inclusion,
**{
"0x"
+ our_offer_store.store_id.hex(): {
"new_root": "0x" + root.hex(),
"dependencies": [
{
"launcher_id": "0x" + their_offer_store.store_id.hex(),
"values_to_prove": ["0x" + entry.node_hash.hex() for entry in their_offer_store.proofs],
}
for their_offer_store in maker
],
}
for our_offer_store in taker
},
}
# Excluding wallet from transaction since failures in the wallet may occur
# after the transaction is submitted to the chain. If we roll back data we
# may lose published data.
trade_record = await self.wallet_rpc.take_offer(
offer=offer,
solver=solver,
fee=fee,
)
return trade_record
async def cancel_offer(self, trade_id: bytes32, secure: bool, fee: uint64) -> None:
store_ids: List[bytes32] = []
if not secure:
trade_record = await self.wallet_rpc.get_offer(trade_id=trade_id, file_contents=True)
trading_offer = TradingOffer.from_bytes(trade_record.offer)
summary = await DataLayerWallet.get_offer_summary(offer=trading_offer)
store_ids = [bytes32.from_hexstr(offered["launcher_id"]) for offered in summary["offered"]]
await self.wallet_rpc.cancel_offer(
trade_id=trade_id,
secure=secure,
fee=fee,
)
if not secure:
for store_id in store_ids:
await self.data_store.clear_pending_roots(tree_id=store_id)
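As a rough usage sketch of the lifecycle these methods imply (not part of the commit; all values below are placeholders): a store is created on chain, mutated through batch_update, then read back via the locally confirmed root. At this layer keys and values appear to be raw bytes; hex decoding happens in the RPC/CLI layers above.

from chia.util.ints import uint64

async def demo(dl: DataLayer) -> None:
    # Create the store's singleton on chain and mirror it in the local DB.
    txs, store_id = await dl.create_store(fee=uint64(10_000_000))
    # Apply a changelist locally, then publish the new root on chain.
    changelist = [{"action": "insert", "key": b"\x00\x01", "value": b"\xca\xfe"}]
    await dl.batch_update(tree_id=store_id, changelist=changelist, fee=uint64(10_000_000))
    # Read back through the locally confirmed root.
    root_hash = await dl.get_local_root(store_id)
    terminal_nodes = await dl.get_keys_values(store_id, root_hash)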

View File

@@ -0,0 +1,26 @@
import logging
from typing import Any
from chia.data_layer.data_layer import DataLayer
class DataLayerAPI:
data_layer: DataLayer
def __init__(self, data_layer: DataLayer) -> None:
self.data_layer = data_layer
# def _set_state_changed_callback(self, callback: Callable):
# self.full_node.state_changed_callback = callback
@property
def server(self) -> Any:
return self.data_layer.server
@property
def log(self) -> logging.Logger:
return self.data_layer.log
@property
def api_ready(self) -> bool:
return self.data_layer.initialized

View File

@@ -0,0 +1,40 @@
from typing import Iterable, List
from chia.types.blockchain_format.sized_bytes import bytes32
class IntegrityError(Exception):
pass
def build_message_with_hashes(message: str, bytes_objects: Iterable[bytes]) -> str:
return "\n".join([message, *[f" {b.hex()}" for b in bytes_objects]])
class TreeGenerationIncrementingError(IntegrityError):
def __init__(self, tree_ids: List[bytes32]) -> None:
super().__init__(
build_message_with_hashes(
message="Found trees with generations not properly incrementing:",
bytes_objects=tree_ids,
)
)
class NodeHashError(IntegrityError):
def __init__(self, node_hashes: List[bytes32]) -> None:
super().__init__(
build_message_with_hashes(
message="Found nodes with incorrect hashes:",
bytes_objects=node_hashes,
)
)
class KeyNotFoundError(Exception):
def __init__(self, key: bytes) -> None:
super().__init__(f"Key not found: {key.hex()}")
class OfferIntegrityError(Exception):
pass

View File

@@ -0,0 +1,60 @@
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict
from aiohttp import web
from chia.data_layer.download_data import is_filename_valid
from chia.server.upnp import UPnP
from chia.util.path import path_from_root
@dataclass
class DataLayerServer:
root_path: Path
config: Dict[str, Any]
log: logging.Logger
async def start(self) -> None:
self.log.info("Starting Data Layer Server.")
self.port = self.config["host_port"]
# Setup UPnP for the data_layer_service port
self.upnp: UPnP = UPnP()
self.upnp.remap(self.port)
server_files_replaced: str = self.config.get(
"server_files_location", "data_layer/db/server_files_location_CHALLENGE"
).replace("CHALLENGE", self.config["selected_network"])
self.server_dir = path_from_root(self.root_path, server_files_replaced)
app = web.Application()
app.add_routes([web.get("/{filename}", self.file_handler)])
self.runner = web.AppRunner(app)
await self.runner.setup()
self.site = web.TCPSite(self.runner, self.config["host_ip"], port=self.port)
await self.site.start()
self.log.info("Started Data Layer Server.")
async def stop(self) -> None:
self.upnp.release(self.port)
# this is a blocking call, waiting for the UPnP thread to exit
self.upnp.shutdown()
self.log.info("Stopped Data Layer Server.")
await self.runner.cleanup()
async def file_handler(self, request: web.Request) -> web.Response:
filename = request.match_info["filename"]
if not is_filename_valid(filename):
raise Exception("Invalid file format requested.")
file_path = self.server_dir.joinpath(filename)
with open(file_path, "rb") as reader:
content = reader.read()
response = web.Response(
content_type="application/octet-stream",
headers={"Content-Disposition": "attachment;filename={}".format(filename)},
body=content,
)
return response
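A minimal client-side sketch for the file handler above (host, port, and filename are placeholders; the server answers GET /<filename> with the raw file bytes as application/octet-stream):

import aiohttp

async def fetch_server_file(filename: str) -> bytes:
    # The data layer server exposes each file at /<filename>.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"http://127.0.0.1:8575/{filename}") as response:
            response.raise_for_status()
            return await response.read()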

View File

@@ -0,0 +1,615 @@
from __future__ import annotations
import dataclasses
from dataclasses import dataclass, field
from enum import IntEnum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union
# TODO: remove or formalize this
import aiosqlite as aiosqlite
from typing_extensions import final
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint64
from chia.util.streamable import Streamable, streamable
if TYPE_CHECKING:
from chia.data_layer.data_store import DataStore
def internal_hash(left_hash: bytes32, right_hash: bytes32) -> bytes32:
# ignoring hint error here for:
# https://github.com/Chia-Network/clvm/pull/102
# https://github.com/Chia-Network/clvm/pull/106
return Program.to((left_hash, right_hash)).get_tree_hash(left_hash, right_hash) # type: ignore[no-any-return]
def calculate_internal_hash(hash: bytes32, other_hash_side: Side, other_hash: bytes32) -> bytes32:
if other_hash_side == Side.LEFT:
return internal_hash(left_hash=other_hash, right_hash=hash)
elif other_hash_side == Side.RIGHT:
return internal_hash(left_hash=hash, right_hash=other_hash)
raise Exception(f"Invalid side: {other_hash_side!r}")
def leaf_hash(key: bytes, value: bytes) -> bytes32:
# ignoring hint error here for:
# https://github.com/Chia-Network/clvm/pull/102
# https://github.com/Chia-Network/clvm/pull/106
return Program.to((key, value)).get_tree_hash() # type: ignore[no-any-return]
async def _debug_dump(db: aiosqlite.Connection, description: str = "") -> None:
cursor = await db.execute("SELECT name FROM sqlite_master WHERE type='table';")
print("-" * 50, description, flush=True)
for [name] in await cursor.fetchall():
cursor = await db.execute(f"SELECT * FROM {name}")
print(f"\n -- {name} ------", flush=True)
async for row in cursor:
print(f" {dict(row)}")
async def _dot_dump(data_store: DataStore, store_id: bytes32, root_hash: bytes32) -> str:
terminal_nodes = await data_store.get_keys_values(tree_id=store_id, root_hash=root_hash)
internal_nodes = await data_store.get_internal_nodes(tree_id=store_id, root_hash=root_hash)
n = 8
dot_nodes: List[str] = []
dot_connections: List[str] = []
dot_pair_boxes: List[str] = []
for terminal_node in terminal_nodes:
hash = terminal_node.hash.hex()
key = terminal_node.key.hex()
value = terminal_node.value.hex()
dot_nodes.append(f"""node_{hash} [shape=box, label="{hash[:n]}\\nkey: {key}\\nvalue: {value}"];""")
for internal_node in internal_nodes:
hash = internal_node.hash.hex()
left = internal_node.left_hash.hex()
right = internal_node.right_hash.hex()
dot_nodes.append(f"""node_{hash} [label="{hash[:n]}"]""")
dot_connections.append(f"""node_{hash} -> node_{left} [label="L"];""")
dot_connections.append(f"""node_{hash} -> node_{right} [label="R"];""")
dot_pair_boxes.append(
f"node [shape = box]; " f"{{rank = same; node_{left}->node_{right}[style=invis]; rankdir = LR}}"
)
lines = [
"digraph {",
*dot_nodes,
*dot_connections,
*dot_pair_boxes,
"}",
]
return "\n".join(lines)
def row_to_node(row: aiosqlite.Row) -> Node:
cls = node_type_to_class[row["node_type"]]
return cls.from_row(row=row)
class Status(IntEnum):
PENDING = 1
COMMITTED = 2
class NodeType(IntEnum):
INTERNAL = 1
TERMINAL = 2
@final
class Side(IntEnum):
LEFT = 0
RIGHT = 1
def other(self) -> "Side":
if self == Side.LEFT:
return Side.RIGHT
return Side.LEFT
@classmethod
def unmarshal(cls, o: str) -> Side:
return getattr(cls, o.upper()) # type: ignore[no-any-return]
def marshal(self) -> str:
return self.name.lower()
class OperationType(IntEnum):
INSERT = 0
DELETE = 1
class CommitState(IntEnum):
OPEN = 0
FINALIZED = 1
ROLLED_BACK = 2
Node = Union["TerminalNode", "InternalNode"]
@dataclass(frozen=True)
class TerminalNode:
hash: bytes32
# generation: int
key: bytes
value: bytes
atom: None = field(init=False, default=None)
@property
def pair(self) -> Tuple[bytes32, bytes32]:
return Program.to(self.key), Program.to(self.value)
@classmethod
def from_row(cls, row: aiosqlite.Row) -> "TerminalNode":
return cls(
hash=bytes32(row["hash"]),
# generation=row["generation"],
key=row["key"],
value=row["value"],
)
@final
@dataclass(frozen=True)
class ProofOfInclusionLayer:
other_hash_side: Side
other_hash: bytes32
combined_hash: bytes32
@classmethod
def from_internal_node(
cls,
internal_node: "InternalNode",
traversal_child_hash: bytes32,
) -> "ProofOfInclusionLayer":
return ProofOfInclusionLayer(
other_hash_side=internal_node.other_child_side(hash=traversal_child_hash),
other_hash=internal_node.other_child_hash(hash=traversal_child_hash),
combined_hash=internal_node.hash,
)
@classmethod
def from_hashes(cls, primary_hash: bytes32, other_hash_side: Side, other_hash: bytes32) -> "ProofOfInclusionLayer":
combined_hash = calculate_internal_hash(
hash=primary_hash,
other_hash_side=other_hash_side,
other_hash=other_hash,
)
return cls(other_hash_side=other_hash_side, other_hash=other_hash, combined_hash=combined_hash)
other_side_to_bit = {Side.LEFT: 1, Side.RIGHT: 0}
@dataclass(frozen=True)
class ProofOfInclusion:
node_hash: bytes32
# children before parents
layers: List[ProofOfInclusionLayer]
@property
def root_hash(self) -> bytes32:
if len(self.layers) == 0:
return self.node_hash
return self.layers[-1].combined_hash
def sibling_sides_integer(self) -> int:
return sum(other_side_to_bit[layer.other_hash_side] << index for index, layer in enumerate(self.layers))
def sibling_hashes(self) -> List[bytes32]:
return [layer.other_hash for layer in self.layers]
def as_program(self) -> Program:
# https://github.com/Chia-Network/clvm/pull/102
# https://github.com/Chia-Network/clvm/pull/106
return Program.to([self.sibling_sides_integer(), self.sibling_hashes()]) # type: ignore[no-any-return]
def valid(self) -> bool:
existing_hash = self.node_hash
for layer in self.layers:
calculated_hash = calculate_internal_hash(
hash=existing_hash, other_hash_side=layer.other_hash_side, other_hash=layer.other_hash
)
if calculated_hash != layer.combined_hash:
return False
existing_hash = calculated_hash
if existing_hash != self.root_hash:
return False
return True
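A one-layer sketch tying these classes together (hypothetical leaves): build a layer with from_hashes, then check valid() and the packed sides integer.
node = leaf_hash(key=b"k", value=b"v")
sibling = leaf_hash(key=b"k2", value=b"v2")
layer = ProofOfInclusionLayer.from_hashes(primary_hash=node, other_hash_side=Side.LEFT, other_hash=sibling)
proof = ProofOfInclusion(node_hash=node, layers=[layer])
assert proof.valid() and proof.root_hash == layer.combined_hash
# A LEFT sibling contributes a 1-bit, so the packed integer for this single layer is 1.
assert proof.sibling_sides_integer() == 1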
@dataclass(frozen=True)
class InternalNode:
hash: bytes32
# generation: int
left_hash: bytes32
right_hash: bytes32
pair: Optional[Tuple[Node, Node]] = None
atom: None = None
@classmethod
def from_row(cls, row: aiosqlite.Row) -> "InternalNode":
return cls(
hash=bytes32(row["hash"]),
# generation=row["generation"],
left_hash=bytes32(row["left"]),
right_hash=bytes32(row["right"]),
)
def other_child_hash(self, hash: bytes32) -> bytes32:
if self.left_hash == hash:
return self.right_hash
elif self.right_hash == hash:
return self.left_hash
# TODO: real exception considerations
raise Exception("provided hash not present")
def other_child_side(self, hash: bytes32) -> Side:
if self.left_hash == hash:
return Side.RIGHT
elif self.right_hash == hash:
return Side.LEFT
# TODO: real exception considerations
raise Exception("provided hash not present")
@dataclass(frozen=True)
class Root:
tree_id: bytes32
node_hash: Optional[bytes32]
generation: int
status: Status
@classmethod
def from_row(cls, row: aiosqlite.Row) -> "Root":
raw_node_hash = row["node_hash"]
if raw_node_hash is None:
node_hash = None
else:
node_hash = bytes32(raw_node_hash)
return cls(
tree_id=bytes32(row["tree_id"]),
node_hash=node_hash,
generation=row["generation"],
status=Status(row["status"]),
)
node_type_to_class: Dict[NodeType, Union[Type[InternalNode], Type[TerminalNode]]] = {
NodeType.INTERNAL: InternalNode,
NodeType.TERMINAL: TerminalNode,
}
@dataclass(frozen=True)
class ServerInfo:
url: str
num_consecutive_failures: int
ignore_till: int
@dataclass(frozen=True)
class Subscription:
tree_id: bytes32
servers_info: List[ServerInfo]
@dataclass(frozen=True)
class DiffData:
type: OperationType
key: bytes
value: bytes
@streamable
@dataclass(frozen=True)
class SerializedNode(Streamable):
is_terminal: bool
value1: bytes
value2: bytes
@final
@dataclasses.dataclass(frozen=True)
class KeyValue:
key: bytes
value: bytes
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> KeyValue:
return cls(
key=hexstr_to_bytes(marshalled["key"]),
value=hexstr_to_bytes(marshalled["value"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"key": self.key.hex(),
"value": self.value.hex(),
}
@dataclasses.dataclass(frozen=True)
class OfferStore:
store_id: bytes32
inclusions: Tuple[KeyValue, ...]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> OfferStore:
return cls(
store_id=bytes32.from_hexstr(marshalled["store_id"]),
inclusions=tuple(KeyValue.unmarshal(key_value) for key_value in marshalled["inclusions"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"store_id": self.store_id.hex(),
"inclusions": [key_value.marshal() for key_value in self.inclusions],
}
@dataclasses.dataclass(frozen=True)
class Layer:
# This class is similar to chia.data_layer.data_layer_util.ProofOfInclusionLayer
# but is being retained for now to keep the API schema definition localized here.
other_hash_side: Side
other_hash: bytes32
combined_hash: bytes32
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> Layer:
return cls(
other_hash_side=Side.unmarshal(marshalled["other_hash_side"]),
other_hash=bytes32.from_hexstr(marshalled["other_hash"]),
combined_hash=bytes32.from_hexstr(marshalled["combined_hash"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"other_hash_side": self.other_hash_side.marshal(),
"other_hash": self.other_hash.hex(),
"combined_hash": self.combined_hash.hex(),
}
@dataclasses.dataclass(frozen=True)
class MakeOfferRequest:
maker: Tuple[OfferStore, ...]
taker: Tuple[OfferStore, ...]
fee: Optional[uint64]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> MakeOfferRequest:
return cls(
maker=tuple(OfferStore.unmarshal(offer_store) for offer_store in marshalled["maker"]),
taker=tuple(OfferStore.unmarshal(offer_store) for offer_store in marshalled["taker"]),
fee=None if marshalled["fee"] is None else uint64(marshalled["fee"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"maker": [offer_store.marshal() for offer_store in self.maker],
"taker": [offer_store.marshal() for offer_store in self.taker],
"fee": None if self.fee is None else int(self.fee),
}
@dataclasses.dataclass(frozen=True)
class Proof:
key: bytes
value: bytes
node_hash: bytes32
layers: Tuple[Layer, ...]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> Proof:
return cls(
key=hexstr_to_bytes(marshalled["key"]),
value=hexstr_to_bytes(marshalled["value"]),
node_hash=bytes32.from_hexstr(marshalled["node_hash"]),
layers=tuple(Layer.unmarshal(layer) for layer in marshalled["layers"]),
)
def root(self) -> bytes32:
if len(self.layers) == 0:
return self.node_hash
return self.layers[-1].combined_hash
def marshal(self) -> Dict[str, Any]:
return {
"key": self.key.hex(),
"value": self.value.hex(),
"node_hash": self.node_hash.hex(),
"layers": [layer.marshal() for layer in self.layers],
}
@dataclasses.dataclass(frozen=True)
class StoreProofs:
store_id: bytes32
proofs: Tuple[Proof, ...]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> StoreProofs:
return cls(
store_id=bytes32.from_hexstr(marshalled["store_id"]),
proofs=tuple(Proof.unmarshal(proof) for proof in marshalled["proofs"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"store_id": self.store_id.hex(),
"proofs": [proof.marshal() for proof in self.proofs],
}
@dataclasses.dataclass(frozen=True)
class Offer:
trade_id: bytes
offer: bytes
taker: Tuple[OfferStore, ...]
maker: Tuple[StoreProofs, ...]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> Offer:
return cls(
trade_id=bytes32.from_hexstr(marshalled["trade_id"]),
offer=hexstr_to_bytes(marshalled["offer"]),
taker=tuple(OfferStore.unmarshal(offer_store) for offer_store in marshalled["taker"]),
maker=tuple(StoreProofs.unmarshal(store_proof) for store_proof in marshalled["maker"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"trade_id": self.trade_id.hex(),
"offer": self.offer.hex(),
"taker": [offer_store.marshal() for offer_store in self.taker],
"maker": [store_proofs.marshal() for store_proofs in self.maker],
}
@dataclasses.dataclass(frozen=True)
class MakeOfferResponse:
success: bool
offer: Offer
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> MakeOfferResponse:
return cls(
success=marshalled["success"],
offer=Offer.unmarshal(marshalled["offer"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"success": self.success,
"offer": self.offer.marshal(),
}
@dataclasses.dataclass(frozen=True)
class TakeOfferRequest:
offer: Offer
fee: Optional[uint64]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> TakeOfferRequest:
return cls(
offer=Offer.unmarshal(marshalled["offer"]),
fee=None if marshalled["fee"] is None else uint64(marshalled["fee"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"offer": self.offer.marshal(),
"fee": None if self.fee is None else int(self.fee),
}
@dataclasses.dataclass(frozen=True)
class TakeOfferResponse:
success: bool
trade_id: bytes32
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> TakeOfferResponse:
return cls(
success=marshalled["success"],
trade_id=bytes32.from_hexstr(marshalled["trade_id"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"success": self.success,
"trade_id": self.trade_id.hex(),
}
@final
@dataclasses.dataclass(frozen=True)
class VerifyOfferResponse:
success: bool
valid: bool
error: Optional[str] = None
fee: Optional[uint64] = None
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> VerifyOfferResponse:
return cls(
success=marshalled["success"],
valid=marshalled["valid"],
error=marshalled["error"],
fee=None if marshalled["fee"] is None else uint64(marshalled["fee"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"success": self.success,
"valid": self.valid,
"error": self.error,
"fee": None if self.fee is None else int(self.fee),
}
@dataclasses.dataclass(frozen=True)
class CancelOfferRequest:
trade_id: bytes32
# cancel on chain (secure) vs. just locally
secure: bool
fee: Optional[uint64]
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> CancelOfferRequest:
return cls(
trade_id=bytes32.from_hexstr(marshalled["trade_id"]),
secure=marshalled["secure"],
fee=None if marshalled["fee"] is None else uint64(marshalled["fee"]),
)
def marshal(self) -> Dict[str, Any]:
return {
"trade_id": self.trade_id.hex(),
"secure": self.secure,
"fee": None if self.fee is None else int(self.fee),
}
@dataclasses.dataclass(frozen=True)
class CancelOfferResponse:
success: bool
@classmethod
def unmarshal(cls, marshalled: Dict[str, Any]) -> CancelOfferResponse:
return cls(
success=marshalled["success"],
)
def marshal(self) -> Dict[str, Any]:
return {
"success": self.success,
}
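A round-trip sketch for the offer marshalling above, using hypothetical hex values:
request = MakeOfferRequest.unmarshal(
    {
        "maker": [{"store_id": "11" * 32, "inclusions": [{"key": "00", "value": "ff"}]}],
        "taker": [],
        "fee": 100,
    }
)
assert request.fee == uint64(100)
assert request.marshal()["maker"][0]["inclusions"][0]["value"] == "ff"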

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,329 @@
import dataclasses
from typing import List, Optional, Type, TypeVar, Union
from aiosqlite import Row
from chia.data_layer.data_layer_wallet import Mirror, SingletonRecord
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper2
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.lineage_proof import LineageProof
_T_DataLayerStore = TypeVar("_T_DataLayerStore", bound="DataLayerStore")
def _row_to_singleton_record(row: Row) -> SingletonRecord:
return SingletonRecord(
bytes32(row[0]),
bytes32(row[1]),
bytes32(row[2]),
bytes32(row[3]),
bool(row[4]),
uint32(row[5]),
LineageProof.from_bytes(row[6]),
uint32(row[7]),
uint64(row[8]),
)
def _row_to_mirror(row: Row) -> Mirror:
urls: List[bytes] = []
byte_list: bytes = row[3]
while byte_list != b"":
length = uint16.from_bytes(byte_list[0:2])
url = byte_list[2 : length + 2]
byte_list = byte_list[length + 2 :]
urls.append(url)
return Mirror(bytes32(row[0]), bytes32(row[1]), uint64.from_bytes(row[2]), urls, bool(row[4]))
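The urls blob packs each URL behind a two-byte length prefix; a round-trip sketch with hypothetical URLs, mirroring the encoding used by add_mirror below:
urls: List[bytes] = [b"http://a.example", b"http://b.example"]
blob = b"".join(bytes(uint16(len(url))) + url for url in urls)
decoded: List[bytes] = []
while blob != b"":
    length = uint16.from_bytes(blob[0:2])
    decoded.append(blob[2 : length + 2])
    blob = blob[length + 2 :]
assert decoded == urls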
class DataLayerStore:
"""
DataLayerStore keeps track of all data layer singleton records, launcher coins, and mirrors
"""
db_wrapper: DBWrapper2
@classmethod
async def create(cls: Type[_T_DataLayerStore], db_wrapper: DBWrapper2) -> _T_DataLayerStore:
self = cls()
self.db_wrapper = db_wrapper
async with self.db_wrapper.writer_maybe_transaction() as conn:
await conn.execute(
(
"CREATE TABLE IF NOT EXISTS singleton_records("
"coin_id blob PRIMARY KEY,"
" launcher_id blob,"
" root blob,"
" inner_puzzle_hash blob,"
" confirmed tinyint,"
" confirmed_at_height int,"
" proof blob,"
" generation int," # This first singleton will be 0, then 1, and so on. This is handled by the DB.
" timestamp int)"
)
)
await conn.execute(
(
"CREATE TABLE IF NOT EXISTS mirrors("
"coin_id blob PRIMARY KEY,"
"launcher_id blob,"
"amount blob,"
"urls blob,"
"ours tinyint)"
)
)
await conn.execute("CREATE INDEX IF NOT EXISTS coin_id on singleton_records(coin_id)")
await conn.execute("CREATE INDEX IF NOT EXISTS launcher_id on singleton_records(launcher_id)")
await conn.execute("CREATE INDEX IF NOT EXISTS root on singleton_records(root)")
await conn.execute("CREATE INDEX IF NOT EXISTS inner_puzzle_hash on singleton_records(inner_puzzle_hash)")
await conn.execute("CREATE INDEX IF NOT EXISTS confirmed_at_height on singleton_records(root)")
await conn.execute("CREATE INDEX IF NOT EXISTS generation on singleton_records(generation)")
await conn.execute(("CREATE TABLE IF NOT EXISTS launchers(id blob PRIMARY KEY, coin blob)"))
await conn.execute("CREATE INDEX IF NOT EXISTS id on launchers(id)")
return self
async def _clear_database(self) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
await (await conn.execute("DELETE FROM singleton_records")).close()
async def add_singleton_record(self, record: SingletonRecord) -> None:
"""
Store SingletonRecord in DB.
"""
async with self.db_wrapper.writer_maybe_transaction() as conn:
await conn.execute_insert(
"INSERT OR REPLACE INTO singleton_records VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
record.coin_id,
record.launcher_id,
record.root,
record.inner_puzzle_hash,
int(record.confirmed),
record.confirmed_at_height,
bytes(record.lineage_proof),
record.generation,
record.timestamp,
),
)
async def get_all_singletons_for_launcher(
self,
launcher_id: bytes32,
min_generation: Optional[uint32] = None,
max_generation: Optional[uint32] = None,
num_results: Optional[uint32] = None,
) -> List[SingletonRecord]:
"""
Returns stored singletons with a specific launcher ID.
"""
query_params: List[Union[bytes32, uint32]] = [launcher_id]
for optional_param in (min_generation, max_generation, num_results):
if optional_param is not None:
query_params.append(optional_param)
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute(
"SELECT * from singleton_records WHERE launcher_id=? "
f"{'AND generation >=? ' if min_generation is not None else ''}"
f"{'AND generation <=? ' if max_generation is not None else ''}"
"ORDER BY generation DESC"
f"{' LIMIT ?' if num_results is not None else ''}",
tuple(query_params),
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
records.append(_row_to_singleton_record(row))
return records
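For example (hypothetical arguments), min_generation=uint32(5), max_generation=None, and num_results=uint32(10) assemble "SELECT * from singleton_records WHERE launcher_id=? AND generation >=? ORDER BY generation DESC LIMIT ?" with parameters [launcher_id, 5, 10]; each optional clause and its placeholder are appended only when the corresponding argument is supplied, keeping clause order and parameter order in sync.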
async def get_singleton_record(self, coin_id: bytes32) -> Optional[SingletonRecord]:
"""
Checks DB for SingletonRecord with coin_id: coin_id and returns it.
"""
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute("SELECT * from singleton_records WHERE coin_id=?", (coin_id,))
row = await cursor.fetchone()
await cursor.close()
if row is not None:
return _row_to_singleton_record(row)
return None
async def get_latest_singleton(
self, launcher_id: bytes32, only_confirmed: bool = False
) -> Optional[SingletonRecord]:
"""
Checks DB for SingletonRecords with launcher_id: launcher_id and returns the most recent.
"""
async with self.db_wrapper.reader_no_transaction() as conn:
if only_confirmed:
# get latest confirmed root
cursor = await conn.execute(
"SELECT * from singleton_records WHERE launcher_id=? and confirmed = TRUE "
"ORDER BY generation DESC LIMIT 1",
(launcher_id,),
)
else:
cursor = await conn.execute(
"SELECT * from singleton_records WHERE launcher_id=? ORDER BY generation DESC LIMIT 1",
(launcher_id,),
)
row = await cursor.fetchone()
await cursor.close()
if row is not None:
return _row_to_singleton_record(row)
return None
async def get_unconfirmed_singletons(self, launcher_id: bytes32) -> List[SingletonRecord]:
"""
Returns all singletons with a specific launcher id that have not yet been marked confirmed
"""
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute(
"SELECT * from singleton_records WHERE launcher_id=? AND confirmed=0", (launcher_id,)
)
rows = await cursor.fetchall()
await cursor.close()
records = [_row_to_singleton_record(row) for row in rows]
return records
async def get_singletons_by_root(self, launcher_id: bytes32, root: bytes32) -> List[SingletonRecord]:
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute(
"SELECT * from singleton_records WHERE launcher_id=? AND root=? ORDER BY generation DESC",
(launcher_id, root),
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
records.append(_row_to_singleton_record(row))
return records
async def set_confirmed(self, coin_id: bytes32, height: uint32, timestamp: uint64) -> None:
"""
Updates singleton record to be confirmed.
"""
current: Optional[SingletonRecord] = await self.get_singleton_record(coin_id)
if current is None or current.confirmed_at_height == height:
return
await self.add_singleton_record(
dataclasses.replace(current, confirmed=True, confirmed_at_height=height, timestamp=timestamp)
)
async def delete_singleton_record(self, coin_id: bytes32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
await (await conn.execute("DELETE FROM singleton_records WHERE coin_id=?", (coin_id,))).close()
async def delete_singleton_records_by_launcher_id(self, launcher_id: bytes32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
await (await conn.execute("DELETE FROM singleton_records WHERE launcher_id=?", (launcher_id,))).close()
async def add_launcher(self, launcher: Coin) -> None:
"""
Add a new launcher coin's information to the DB
"""
launcher_bytes: bytes = launcher.parent_coin_info + launcher.puzzle_hash + bytes(uint64(launcher.amount))
async with self.db_wrapper.writer_maybe_transaction() as conn:
await conn.execute_insert(
"INSERT OR REPLACE INTO launchers VALUES (?, ?)",
(launcher.name(), launcher_bytes),
)
async def get_launcher(self, launcher_id: bytes32) -> Optional[Coin]:
"""
Checks DB for a launcher with the specified ID and returns it.
"""
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute("SELECT * from launchers WHERE id=?", (launcher_id,))
row = await cursor.fetchone()
await cursor.close()
if row is not None:
return Coin(bytes32(row[1][0:32]), bytes32(row[1][32:64]), uint64(int.from_bytes(row[1][64:72], "big")))
return None
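The launcher blob is a fixed 72-byte concatenation: parent_coin_info (32 bytes) + puzzle_hash (32 bytes) + amount (8 big-endian bytes). A round-trip sketch with hypothetical values:
coin = Coin(bytes32(b"\x01" * 32), bytes32(b"\x02" * 32), uint64(1))
blob = coin.parent_coin_info + coin.puzzle_hash + bytes(uint64(coin.amount))
assert len(blob) == 72
assert coin == Coin(bytes32(blob[0:32]), bytes32(blob[32:64]), uint64(int.from_bytes(blob[64:72], "big")))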
async def get_all_launchers(self) -> List[bytes32]:
"""
Checks DB for all launchers.
"""
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute("SELECT id from launchers")
rows = await cursor.fetchall()
await cursor.close()
return [bytes32(row[0]) for row in rows]
async def delete_launcher(self, launcher_id: bytes32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
await (await conn.execute("DELETE FROM launchers WHERE id=?", (launcher_id,))).close()
async def add_mirror(self, mirror: Mirror) -> None:
"""
Add a mirror coin to the DB
"""
async with self.db_wrapper.writer_maybe_transaction() as conn:
await conn.execute_insert(
"INSERT OR REPLACE INTO mirrors VALUES (?, ?, ?, ?, ?)",
(
mirror.coin_id,
mirror.launcher_id,
bytes(mirror.amount),
b"".join([bytes(uint16(len(url))) + url for url in mirror.urls]), # prefix each item with a length
1 if mirror.ours else 0,
),
)
async def get_mirrors(self, launcher_id: bytes32) -> List[Mirror]:
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute(
"SELECT * from mirrors WHERE launcher_id=?",
(launcher_id,),
)
rows = await cursor.fetchall()
await cursor.close()
mirrors: List[Mirror] = []
for row in rows:
mirrors.append(_row_to_mirror(row))
return mirrors
async def get_mirror(self, coin_id: bytes32) -> Mirror:
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute(
"SELECT * from mirrors WHERE coin_id=?",
(coin_id,),
)
row = await cursor.fetchone()
await cursor.close()
assert row is not None
return _row_to_mirror(row)
async def delete_mirror(self, coin_id: bytes32) -> None:
async with self.db_wrapper.writer_maybe_transaction() as conn:
await (await conn.execute("DELETE FROM mirrors WHERE coin_id=?", (coin_id,))).close()

View File

@@ -0,0 +1,184 @@
import asyncio
import logging
import os
import time
from pathlib import Path
from typing import List, Optional
import aiohttp
from typing_extensions import Literal
from chia.data_layer.data_layer_util import NodeType, Root, SerializedNode, ServerInfo, Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32
def get_full_tree_filename(tree_id: bytes32, node_hash: bytes32, generation: int) -> str:
return f"{tree_id}-{node_hash}-full-{generation}-v1.0.dat"
def get_delta_filename(tree_id: bytes32, node_hash: bytes32, generation: int) -> str:
return f"{tree_id}-{node_hash}-delta-{generation}-v1.0.dat"
def is_filename_valid(filename: str) -> bool:
split = filename.split("-")
try:
raw_tree_id, raw_node_hash, file_type, raw_generation, raw_version, *rest = split
tree_id = bytes32(bytes.fromhex(raw_tree_id))
node_hash = bytes32(bytes.fromhex(raw_node_hash))
generation = int(raw_generation)
except ValueError:
return False
if len(rest) > 0:
return False
# TODO: versions should probably be centrally defined
if raw_version != "v1.0.dat":
return False
if file_type not in {"delta", "full"}:
return False
generate_file_func = get_delta_filename if file_type == "delta" else get_full_tree_filename
reformatted = generate_file_func(tree_id=tree_id, node_hash=node_hash, generation=generation)
return reformatted == filename
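A quick sketch of the round-trip property this check relies on (hypothetical ids): names produced by the generators validate, and tampering breaks validation.
tree_id = bytes32(b"\x11" * 32)
node_hash = bytes32(b"\x22" * 32)
name = get_delta_filename(tree_id, node_hash, 5)
assert is_filename_valid(name)
assert not is_filename_valid(name.replace("delta", "partial"))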
async def insert_into_data_store_from_file(
data_store: DataStore,
tree_id: bytes32,
root_hash: Optional[bytes32],
filename: Path,
) -> None:
with open(filename, "rb") as reader:
while True:
chunk = b""
while len(chunk) < 4:
size_to_read = 4 - len(chunk)
cur_chunk = reader.read(size_to_read)
if cur_chunk is None or cur_chunk == b"":
if size_to_read < 4:
raise Exception("Incomplete read of length.")
break
chunk += cur_chunk
if chunk == b"":
break
size = int.from_bytes(chunk, byteorder="big")
serialize_nodes_bytes = b""
while len(serialize_nodes_bytes) < size:
size_to_read = size - len(serialize_nodes_bytes)
cur_chunk = reader.read(size_to_read)
if cur_chunk is None or cur_chunk == b"":
raise Exception("Incomplete read of blob.")
serialize_nodes_bytes += cur_chunk
serialized_node = SerializedNode.from_bytes(serialize_nodes_bytes)
node_type = NodeType.TERMINAL if serialized_node.is_terminal else NodeType.INTERNAL
await data_store.insert_node(node_type, serialized_node.value1, serialized_node.value2)
await data_store.insert_root_with_ancestor_table(tree_id=tree_id, node_hash=root_hash, status=Status.COMMITTED)
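Each record in a delta file is framed as a 4-byte big-endian length followed by the SerializedNode bytes. A minimal writer sketch for that framing (write_serialized_nodes is a hypothetical helper mirroring the reader above):
def write_serialized_nodes(path: Path, nodes: List[SerializedNode]) -> None:
    # Frame each node as <4-byte big-endian length><SerializedNode bytes>.
    with open(path, "wb") as writer:
        for node in nodes:
            blob = bytes(node)
            writer.write(len(blob).to_bytes(4, byteorder="big"))
            writer.write(blob)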
async def write_files_for_root(
data_store: DataStore,
tree_id: bytes32,
root: Root,
foldername: Path,
override: bool = False,
) -> bool:
if root.node_hash is not None:
node_hash = root.node_hash
else:
node_hash = bytes32([0] * 32) # todo change
filename_full_tree = foldername.joinpath(get_full_tree_filename(tree_id, node_hash, root.generation))
filename_diff_tree = foldername.joinpath(get_delta_filename(tree_id, node_hash, root.generation))
written = False
mode: Literal["wb", "xb"] = "wb" if override else "xb"
try:
with open(filename_full_tree, mode) as writer:
await data_store.write_tree_to_file(root, node_hash, tree_id, False, writer)
written = True
except FileExistsError:
pass
try:
last_seen_generation = await data_store.get_last_tree_root_by_hash(
tree_id, root.node_hash, max_generation=root.generation
)
if last_seen_generation is None:
with open(filename_diff_tree, mode) as writer:
await data_store.write_tree_to_file(root, node_hash, tree_id, True, writer)
else:
open(filename_diff_tree, mode).close()
written = True
except FileExistsError:
pass
return written
async def insert_from_delta_file(
data_store: DataStore,
tree_id: bytes32,
existing_generation: int,
root_hashes: List[bytes32],
server_info: ServerInfo,
client_foldername: Path,
log: logging.Logger,
) -> bool:
for root_hash in root_hashes:
timestamp = int(time.time())
existing_generation += 1
filename = get_delta_filename(tree_id, root_hash, existing_generation)
try:
async with aiohttp.ClientSession() as session:
async with session.get(server_info.url + "/" + filename) as resp:
resp.raise_for_status()
target_filename = client_foldername.joinpath(filename)
data = await resp.read()
target_filename.write_bytes(data)
except Exception:
await data_store.server_misses_file(tree_id, server_info, timestamp)
raise
log.info(f"Successfully downloaded delta file {filename}.")
try:
await insert_into_data_store_from_file(
data_store,
tree_id,
None if root_hash == bytes32([0] * 32) else root_hash,
client_foldername.joinpath(filename),
)
log.info(
f"Successfully inserted hash {root_hash} from delta file. "
f"Generation: {existing_generation}. Tree id: {tree_id}."
)
filename_full_tree = client_foldername.joinpath(
get_full_tree_filename(tree_id, root_hash, existing_generation)
)
root = await data_store.get_tree_root(tree_id=tree_id)
with open(filename_full_tree, "wb") as writer:
await data_store.write_tree_to_file(root, root_hash, tree_id, False, writer)
log.info(f"Successfully written full tree filename {filename_full_tree}.")
await data_store.received_correct_file(tree_id, server_info)
except asyncio.CancelledError:
raise
except Exception:
target_filename = client_foldername.joinpath(filename)
os.remove(target_filename)
await data_store.received_incorrect_file(tree_id, server_info, timestamp)
await data_store.rollback_to_generation(tree_id, existing_generation - 1)
raise
return True

View File

@@ -0,0 +1,121 @@
import asyncio
import os
import sys
import tempfile
import time
from pathlib import Path
from typing import Dict, Optional
import aiosqlite
from chia.data_layer.data_layer_util import Side, TerminalNode, leaf_hash
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.db_wrapper import DBWrapper
async def generate_datastore(num_nodes: int, slow_mode: bool) -> None:
with tempfile.TemporaryDirectory() as temp_directory:
temp_directory_path = Path(temp_directory)
db_path = temp_directory_path.joinpath("dl_benchmark.sqlite")
print(f"Writing DB to {db_path}")
if os.path.exists(db_path):
os.remove(db_path)
connection = await aiosqlite.connect(db_path)
db_wrapper = DBWrapper(connection)
data_store = await DataStore.create(db_wrapper=db_wrapper)
hint_keys_values: Dict[bytes, bytes] = {}
tree_id = bytes32(b"0" * 32)
await data_store.create_tree(tree_id)
insert_time = 0.0
insert_count = 0
autoinsert_time = 0.0
autoinsert_count = 0
delete_time = 0.0
delete_count = 0
for i in range(num_nodes):
key = i.to_bytes(4, byteorder="big")
value = (2 * i).to_bytes(4, byteorder="big")
seed = leaf_hash(key=key, value=value)
reference_node_hash: Optional[bytes32] = await data_store.get_terminal_node_for_seed(tree_id, seed)
side: Optional[Side] = data_store.get_side_for_seed(seed)
if i == 0:
reference_node_hash = None
side = None
if i % 3 == 0:
t1 = time.time()
if not slow_mode:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
hint_keys_values=hint_keys_values,
)
else:
await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
use_optimized=False,
)
t2 = time.time()
insert_time += t2 - t1
insert_count += 1
elif i % 3 == 1:
t1 = time.time()
if not slow_mode:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
)
else:
await data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
use_optimized=False,
)
t2 = time.time()
autoinsert_time += t2 - t1
autoinsert_count += 1
else:
t1 = time.time()
assert reference_node_hash is not None
node = await data_store.get_node(reference_node_hash)
assert isinstance(node, TerminalNode)
if not slow_mode:
await data_store.delete(key=node.key, tree_id=tree_id, hint_keys_values=hint_keys_values)
else:
await data_store.delete(key=node.key, tree_id=tree_id, use_optimized=False)
t2 = time.time()
delete_time += t2 - t1
delete_count += 1
print(f"Average insert time: {insert_time / insert_count}")
print(f"Average autoinsert time: {autoinsert_time / autoinsert_count}")
print(f"Average delete time: {delete_time / delete_count}")
print(f"Total time for {num_nodes} operations: {insert_time + autoinsert_time + delete_time}")
root = await data_store.get_tree_root(tree_id=tree_id)
print(f"Root hash: {root.node_hash}")
await connection.close()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
slow_mode = False
if len(sys.argv) > 2 and sys.argv[2] == "slow":
slow_mode = True
loop.run_until_complete(generate_datastore(int(sys.argv[1]), slow_mode))
loop.close()
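Invocation sketch (the script's filename is not shown in this view, so the path below is a placeholder): the first argument is the operation count, and an optional second argument of slow disables the optimized code paths.
python <path-to-benchmark-script> 25000 slow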

View File

@@ -55,6 +55,7 @@ version_data = copy_metadata(get_distribution("chia-blockchain"))[0]
block_cipher = None
SERVERS = [
"data_layer",
"wallet",
"full_node",
"harvester",

View File

@@ -0,0 +1,402 @@
from __future__ import annotations
import dataclasses
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from chia.data_layer.data_layer_errors import OfferIntegrityError
from chia.data_layer.data_layer_util import (
CancelOfferRequest,
CancelOfferResponse,
MakeOfferRequest,
MakeOfferResponse,
Side,
Subscription,
TakeOfferRequest,
TakeOfferResponse,
VerifyOfferResponse,
)
from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror, verify_offer
from chia.rpc.data_layer_rpc_util import marshal
from chia.rpc.rpc_server import Endpoint, EndpointResult
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
# TODO: input assertions for all RPCs
from chia.util.ints import uint64
from chia.util.streamable import recurse_jsonify
from chia.wallet.trading.offer import Offer as TradingOffer
if TYPE_CHECKING:
from chia.data_layer.data_layer import DataLayer
def process_change(change: Dict[str, Any]) -> Dict[str, Any]:
# TODO: A full class would likely be nice for this so downstream doesn't
# have to deal with maybe-present attributes or Dict[str, Any] hints.
reference_node_hash = change.get("reference_node_hash")
if reference_node_hash is not None:
reference_node_hash = bytes32(hexstr_to_bytes(reference_node_hash))
side = change.get("side")
if side is not None:
side = Side(side)
value = change.get("value")
if value is not None:
value = hexstr_to_bytes(value)
return {
**change,
"key": hexstr_to_bytes(change["key"]),
"value": value,
"reference_node_hash": reference_node_hash,
"side": side,
}
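For example (hypothetical input), an insert change arrives with hex strings and leaves with bytes and enum values:
processed = process_change({"action": "insert", "key": "0xab", "value": "0xcd", "side": 0})
assert processed["key"] == b"\xab" and processed["value"] == b"\xcd"
assert processed["side"] == Side.LEFT and processed["reference_node_hash"] is None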
def get_fee(config: Dict[str, Any], request: Dict[str, Any]) -> uint64:
fee = request.get("fee")
if fee is None:
config_fee = config.get("fee", 0)
return uint64(config_fee)
return uint64(fee)
class DataLayerRpcApi:
# TODO: other RPC APIs do not accept a wallet and the service start does not expect to provide one
def __init__(self, data_layer: DataLayer): # , wallet: DataLayerWallet):
self.service: DataLayer = data_layer
self.service_name = "chia_data_layer"
def get_routes(self) -> Dict[str, Endpoint]:
return {
"/create_data_store": self.create_data_store,
"/get_owned_stores": self.get_owned_stores,
"/batch_update": self.batch_update,
"/get_value": self.get_value,
"/get_keys": self.get_keys,
"/get_keys_values": self.get_keys_values,
"/get_ancestors": self.get_ancestors,
"/get_root": self.get_root,
"/get_local_root": self.get_local_root,
"/get_roots": self.get_roots,
"/delete_key": self.delete_key,
"/insert": self.insert,
"/subscribe": self.subscribe,
"/unsubscribe": self.unsubscribe,
"/add_mirror": self.add_mirror,
"/delete_mirror": self.delete_mirror,
"/get_mirrors": self.get_mirrors,
"/remove_subscriptions": self.remove_subscriptions,
"/subscriptions": self.subscriptions,
"/get_kv_diff": self.get_kv_diff,
"/get_root_history": self.get_root_history,
"/add_missing_files": self.add_missing_files,
"/make_offer": self.make_offer,
"/take_offer": self.take_offer,
"/verify_offer": self.verify_offer,
"/cancel_offer": self.cancel_offer,
}
async def create_data_store(self, request: Dict[str, Any]) -> EndpointResult:
if self.service is None:
raise Exception("Data layer not created")
fee = get_fee(self.service.config, request)
txs, value = await self.service.create_store(uint64(fee))
return {"txs": txs, "id": value.hex()}
async def get_owned_stores(self, request: Dict[str, Any]) -> EndpointResult:
if self.service is None:
raise Exception("Data layer not created")
singleton_records = await self.service.get_owned_stores()
return {"store_ids": [singleton.launcher_id.hex() for singleton in singleton_records]}
async def get_value(self, request: Dict[str, Any]) -> EndpointResult:
store_id = bytes32.from_hexstr(request["id"])
key = hexstr_to_bytes(request["key"])
if self.service is None:
raise Exception("Data layer not created")
value = await self.service.get_value(store_id=store_id, key=key)
hex_value = None
if value is not None:
    hex_value = value.hex()
return {"value": hex_value}
async def get_keys(self, request: Dict[str, Any]) -> EndpointResult:
store_id = bytes32.from_hexstr(request["id"])
root_hash = request.get("root_hash")
if root_hash is not None:
root_hash = bytes32.from_hexstr(root_hash)
if self.service is None:
raise Exception("Data layer not created")
keys = await self.service.get_keys(store_id, root_hash)
return {"keys": [f"0x{key.hex()}" for key in keys]}
async def get_keys_values(self, request: Dict[str, Any]) -> EndpointResult:
store_id = bytes32(hexstr_to_bytes(request["id"]))
root_hash = request.get("root_hash")
if root_hash is not None:
root_hash = bytes32.from_hexstr(root_hash)
if self.service is None:
raise Exception("Data layer not created")
res = await self.service.get_keys_values(store_id, root_hash)
json_nodes = []
for node in res:
json = recurse_jsonify(dataclasses.asdict(node))
json_nodes.append(json)
return {"keys_values": json_nodes}
async def get_ancestors(self, request: Dict[str, Any]) -> EndpointResult:
store_id = bytes32(hexstr_to_bytes(request["id"]))
node_hash = bytes32.from_hexstr(request["hash"])
if self.service is None:
raise Exception("Data layer not created")
value = await self.service.get_ancestors(node_hash, store_id)
return {"ancestors": value}
async def batch_update(self, request: Dict[str, Any]) -> EndpointResult:
"""
id - the id of the store we are operating on
changelist - a list of changes to apply on store
"""
fee = get_fee(self.service.config, request)
changelist = [process_change(change) for change in request["changelist"]]
store_id = bytes32(hexstr_to_bytes(request["id"]))
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
transaction_record = await self.service.batch_update(store_id, changelist, uint64(fee))
if transaction_record is None:
raise Exception(f"Batch update failed for: {store_id}")
return {"tx_id": transaction_record.name}
async def insert(self, request: Dict[str, Any]) -> EndpointResult:
"""
id - the id of the store we are operating on
key - hex-encoded key to insert
value - hex-encoded value to insert
"""
fee = get_fee(self.service.config, request)
key = hexstr_to_bytes(request["key"])
value = hexstr_to_bytes(request["value"])
store_id = bytes32(hexstr_to_bytes(request["id"]))
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
changelist = [{"action": "insert", "key": key, "value": value}]
transaction_record = await self.service.batch_update(store_id, changelist, uint64(fee))
return {"tx_id": transaction_record.name}
async def delete_key(self, request: Dict[str, Any]) -> EndpointResult:
"""
id - the id of the store we are operating on
key - hex-encoded key to delete
"""
fee = get_fee(self.service.config, request)
key = hexstr_to_bytes(request["key"])
store_id = bytes32(hexstr_to_bytes(request["id"]))
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
changelist = [{"action": "delete", "key": key}]
transaction_record = await self.service.batch_update(store_id, changelist, uint64(fee))
return {"tx_id": transaction_record.name}
async def get_root(self, request: Dict[str, Any]) -> EndpointResult:
"""get hash of latest tree root"""
store_id = bytes32(hexstr_to_bytes(request["id"]))
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
rec = await self.service.get_root(store_id)
if rec is None:
raise Exception(f"Failed to get root for {store_id.hex()}")
return {"hash": rec.root, "confirmed": rec.confirmed, "timestamp": rec.timestamp}
async def get_local_root(self, request: Dict[str, Any]) -> EndpointResult:
"""get hash of latest tree root saved in our local datastore"""
store_id = bytes32(hexstr_to_bytes(request["id"]))
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
res = await self.service.get_local_root(store_id)
return {"hash": res}
async def get_roots(self, request: Dict[str, Any]) -> EndpointResult:
"""
get the latest root record for each store id in a list
"""
store_ids = request["ids"]
# todo input checks
if self.service is None:
raise Exception("Data layer not created")
roots = []
for id in store_ids:
id_bytes = bytes32.from_hexstr(id)
rec = await self.service.get_root(id_bytes)
if rec is not None:
roots.append({"id": id_bytes, "hash": rec.root, "confirmed": rec.confirmed, "timestamp": rec.timestamp})
return {"root_hashes": roots}
async def subscribe(self, request: Dict[str, Any]) -> EndpointResult:
"""
subscribe to singleton
"""
store_id = request.get("id")
if store_id is None:
raise Exception("missing store id in request")
if self.service is None:
raise Exception("Data layer not created")
store_id_bytes = bytes32.from_hexstr(store_id)
urls = request.get("urls", [])
await self.service.subscribe(store_id=store_id_bytes, urls=urls)
return {}
async def unsubscribe(self, request: Dict[str, Any]) -> EndpointResult:
"""
unsubscribe from singleton
"""
store_id = request.get("id")
if store_id is None:
raise Exception("missing store id in request")
if self.service is None:
raise Exception("Data layer not created")
store_id_bytes = bytes32.from_hexstr(store_id)
await self.service.unsubscribe(store_id_bytes)
return {}
async def subscriptions(self, request: Dict[str, Any]) -> EndpointResult:
"""
List current subscriptions
"""
if self.service is None:
raise Exception("Data layer not created")
subscriptions: List[Subscription] = await self.service.get_subscriptions()
return {"store_ids": [sub.tree_id.hex() for sub in subscriptions]}
async def remove_subscriptions(self, request: Dict[str, Any]) -> EndpointResult:
if self.service is None:
raise Exception("Data layer not created")
store_id = request.get("id")
if store_id is None:
raise Exception("missing store id in request")
store_id_bytes = bytes32.from_hexstr(store_id)
urls = request["urls"]
await self.service.remove_subscriptions(store_id=store_id_bytes, urls=urls)
return {}
async def add_missing_files(self, request: Dict[str, Any]) -> EndpointResult:
"""
ensure the data server files for the given store ids (or all current subscriptions) exist, writing any that are missing.
"""
if "ids" in request:
store_ids = request["ids"]
ids_bytes = [bytes32.from_hexstr(id) for id in store_ids]
else:
subscriptions: List[Subscription] = await self.service.get_subscriptions()
ids_bytes = [subscription.tree_id for subscription in subscriptions]
override = request.get("override", False)
foldername: Optional[Path] = None
if "foldername" in request:
foldername = Path(request["foldername"])
for tree_id in ids_bytes:
await self.service.add_missing_files(tree_id, override, foldername)
return {}
async def get_root_history(self, request: Dict[str, Any]) -> EndpointResult:
"""
get history of state hashes for a store
"""
if self.service is None:
raise Exception("Data layer not created")
store_id = request["id"]
id_bytes = bytes32.from_hexstr(store_id)
records = await self.service.get_root_history(id_bytes)
res: List[Dict[str, Any]] = []
for rec in records:
res.insert(0, {"root_hash": rec.root, "confirmed": rec.confirmed, "timestamp": rec.timestamp})
return {"root_history": res}
async def get_kv_diff(self, request: Dict[str, Any]) -> EndpointResult:
"""
get kv diff between two root hashes
"""
if self.service is None:
raise Exception("Data layer not created")
store_id = request["id"]
id_bytes = bytes32.from_hexstr(store_id)
hash_1 = request["hash_1"]
hash_1_bytes = bytes32.from_hexstr(hash_1)
hash_2 = request["hash_2"]
hash_2_bytes = bytes32.from_hexstr(hash_2)
records = await self.service.get_kv_diff(id_bytes, hash_1_bytes, hash_2_bytes)
res: List[Dict[str, Any]] = []
for rec in records:
res.insert(0, {"type": rec.type.name, "key": rec.key.hex(), "value": rec.value.hex()})
return {"diff": res}
async def add_mirror(self, request: Dict[str, Any]) -> EndpointResult:
store_id = request["id"]
id_bytes = bytes32.from_hexstr(store_id)
urls = request["urls"]
amount = request["amount"]
fee = get_fee(self.service.config, request)
await self.service.add_mirror(id_bytes, urls, amount, fee)
return {}
async def delete_mirror(self, request: Dict[str, Any]) -> EndpointResult:
coin_id = request["id"]
id_bytes = bytes32.from_hexstr(coin_id)
fee = get_fee(self.service.config, request)
await self.service.delete_mirror(id_bytes, fee)
return {}
async def get_mirrors(self, request: Dict[str, Any]) -> EndpointResult:
store_id = request["id"]
id_bytes = bytes32.from_hexstr(store_id)
mirrors: List[Mirror] = await self.service.get_mirrors(id_bytes)
return {"mirrors": [mirror.to_json_dict() for mirror in mirrors]}
@marshal() # type: ignore[arg-type]
async def make_offer(self, request: MakeOfferRequest) -> MakeOfferResponse:
fee = get_fee(self.service.config, {"fee": request.fee})
offer = await self.service.make_offer(maker=request.maker, taker=request.taker, fee=fee)
return MakeOfferResponse(success=True, offer=offer)
@marshal() # type: ignore[arg-type]
async def take_offer(self, request: TakeOfferRequest) -> TakeOfferResponse:
fee = get_fee(self.service.config, {"fee": request.fee})
trade_record = await self.service.take_offer(
offer_bytes=request.offer.offer,
maker=request.offer.maker,
taker=request.offer.taker,
fee=fee,
)
return TakeOfferResponse(success=True, trade_id=trade_record.trade_id)
@marshal() # type: ignore[arg-type]
async def verify_offer(self, request: TakeOfferRequest) -> VerifyOfferResponse:
fee = get_fee(self.service.config, {"fee": request.fee})
offer = TradingOffer.from_bytes(request.offer.offer)
summary = await DataLayerWallet.get_offer_summary(offer=offer)
try:
verify_offer(maker=request.offer.maker, taker=request.offer.taker, summary=summary)
except OfferIntegrityError as e:
return VerifyOfferResponse(success=True, valid=False, error=str(e))
return VerifyOfferResponse(success=True, valid=True, fee=fee)
@marshal() # type: ignore[arg-type]
async def cancel_offer(self, request: CancelOfferRequest) -> CancelOfferResponse:
fee = get_fee(self.service.config, {"fee": request.fee})
await self.service.cancel_offer(
trade_id=request.trade_id,
secure=request.secure,
fee=fee,
)
return CancelOfferResponse(success=True)

View File

@@ -0,0 +1,112 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from chia.rpc.rpc_client import RpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
class DataLayerRpcClient(RpcClient):
async def create_data_store(self, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("create_data_store", {"fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_value(self, store_id: bytes32, key: bytes) -> Dict[str, Any]:
response = await self.fetch("get_value", {"id": store_id.hex(), "key": key.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def update_data_store(
self, store_id: bytes32, changelist: List[Dict[str, str]], fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("batch_update", {"id": store_id.hex(), "changelist": changelist, "fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_keys_values(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_keys_values", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_keys(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_keys", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_ancestors(self, store_id: bytes32, hash: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_ancestors", {"id": store_id.hex(), "hash": hash})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_local_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_local_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def get_roots(self, store_ids: List[bytes32]) -> Dict[str, Any]:
response = await self.fetch("get_roots", {"ids": store_ids})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
async def subscribe(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("subscribe", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("remove_subscriptions", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
async def unsubscribe(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("unsubscribe", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
async def add_missing_files(
self, store_ids: Optional[List[bytes32]], override: Optional[bool], foldername: Optional[Path]
) -> Dict[str, Any]:
request: Dict[str, Any] = {}
if store_ids is not None:
request["ids"] = [store_id.hex() for store_id in store_ids]
if override is not None:
request["override"] = override
if foldername is not None:
request["foldername"] = str(foldername)
response = await self.fetch("add_missing_files", request)
return response # type: ignore[no-any-return]
async def get_kv_diff(self, store_id: bytes32, hash_1: bytes32, hash_2: bytes32) -> Dict[str, Any]:
response = await self.fetch(
"get_kv_diff", {"id": store_id.hex(), "hash_1": hash_1.hex(), "hash_2": hash_2.hex()}
)
return response # type: ignore[no-any-return]
async def get_root_history(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root_history", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
async def add_mirror(
self, store_id: bytes32, urls: List[str], amount: int, fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("add_mirror", {"id": store_id.hex(), "urls": urls, "amount": amount, "fee": fee})
return response # type: ignore[no-any-return]
async def delete_mirror(self, coin_id: bytes32, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("delete_mirror", {"id": coin_id.hex(), "fee": fee})
return response # type: ignore[no-any-return]
async def get_mirrors(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_mirrors", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
async def get_subscriptions(self) -> Dict[str, Any]:
response = await self.fetch("subscriptions", {})
return response # type: ignore[no-any-return]
async def get_owned_stores(self) -> Dict[str, Any]:
response = await self.fetch("get_owned_stores", {})
return response # type: ignore[no-any-return]
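A usage sketch under stated assumptions: a data layer service running locally on the default RPC port 8562, with the root path and config loaded from a standard install; the hostname and values are illustrative, not prescribed by this client.
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16

async def example() -> None:
    config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
    client = await DataLayerRpcClient.create("localhost", uint16(8562), DEFAULT_ROOT_PATH, config)
    try:
        created = await client.create_data_store(fee=None)
        store_id = bytes32.from_hexstr(created["id"])
        print(await client.get_root(store_id))
    finally:
        client.close()
        await client.await_closed()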

View File

@@ -0,0 +1,62 @@
from typing import Any, Dict, Type, TypeVar
from typing_extensions import Protocol
_T = TypeVar("_T")
# If accepted for general use then this should be moved to a common location
# and probably implemented by the framework instead of manual decoration.
class MarshallableProtocol(Protocol):
@classmethod
def unmarshal(cls: Type[_T], marshalled: Dict[str, Any]) -> _T:
...
def marshal(self) -> Dict[str, Any]:
...
class UnboundRoute(Protocol):
async def __call__(self, request: Dict[str, Any]) -> Dict[str, Any]:
pass
class UnboundMarshalledRoute(Protocol):
# Ignoring pylint complaint about the name of the first argument since this is a
# special case.
async def __call__( # pylint: disable=E0213
protocol_self, self: Any, request: MarshallableProtocol
) -> MarshallableProtocol:
pass
class RouteDecorator(Protocol):
def __call__(self, route: UnboundMarshalledRoute) -> UnboundRoute:
pass
def marshal() -> RouteDecorator:
def decorator(route: UnboundMarshalledRoute) -> UnboundRoute:
from typing import get_type_hints
hints = get_type_hints(route)
request_class: Type[MarshallableProtocol] = hints["request"]
async def wrapper(self: object, request: Dict[str, object]) -> Dict[str, object]:
# import json
# name = route.__name__
# print(f"\n ==== {name} request.json\n{json.dumps(request, indent=4)}")
unmarshalled_request = request_class.unmarshal(request)
response = await route(self, request=unmarshalled_request)
marshalled_response = response.marshal()
# print(f"\n ==== {name} response.json\n{json.dumps(marshalled_response, indent=4)}")
return marshalled_response
# type ignoring since mypy is having issues with bound vs. unbound methods
return wrapper # type: ignore[return-value]
return decorator
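A minimal sketch of a route opting in (EchoRequest and ExampleApi are hypothetical; Dict and Any come from the imports at the top of this module): the decorator reads the request annotation, unmarshals the incoming dict, and marshals the returned object.
class EchoRequest:
    def __init__(self, text: str) -> None:
        self.text = text

    @classmethod
    def unmarshal(cls, marshalled: Dict[str, Any]) -> "EchoRequest":
        return cls(text=marshalled["text"])

    def marshal(self) -> Dict[str, Any]:
        return {"text": self.text}

class ExampleApi:
    @marshal()  # type: ignore[arg-type]
    async def echo(self, request: EchoRequest) -> EchoRequest:
        return request

# await ExampleApi().echo({"text": "hi"}) now returns {"text": "hi"}.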

View File

@@ -10,7 +10,7 @@ from ssl import SSLContext
from typing import Any, Awaitable, Callable, Dict, List, Optional
from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from typing_extensions import final
from typing_extensions import Protocol, final
from chia.rpc.util import wrap_http_handler
from chia.server.outbound_message import NodeType
@@ -37,6 +37,11 @@ class RpcEnvironment:
listen_port: uint16
class RpcApiProtocol(Protocol):
def get_routes(self) -> Dict[str, Endpoint]:
pass
@final
@dataclass
class RpcServer:
@@ -382,5 +387,5 @@ async def start_rpc_server(
return rpc_server
except Exception:
tb = traceback.format_exc()
log.error(f"Starting RPC server failed. Exception {tb}.")
log.error(f"Starting RPC server failed. Exception {tb}")
raise

View File

@@ -5,9 +5,10 @@ import logging
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from blspy import G1Element, PrivateKey
from blspy import G1Element, G2Element, PrivateKey
from chia.consensus.block_rewards import calculate_base_farmer_reward
from chia.data_layer.data_layer_wallet import DataLayerWallet
from chia.pools.pool_wallet import PoolWallet
from chia.pools.pool_wallet_info import FARMING_TO_POOL, PoolState, PoolWalletInfo, create_pool_state
from chia.protocols.protocol_message_types import ProtocolMessageTypes
@@ -46,7 +47,7 @@ from chia.wallet.nft_wallet.nft_puzzles import get_metadata_and_phs
from chia.wallet.nft_wallet.nft_wallet import NFTWallet
from chia.wallet.nft_wallet.uncurry_nft import UncurriedNFT
from chia.wallet.outer_puzzles import AssetType
from chia.wallet.puzzle_drivers import PuzzleInfo
from chia.wallet.puzzle_drivers import PuzzleInfo, Solver
from chia.wallet.rl_wallet.rl_wallet import RLWallet
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer
@@ -162,6 +163,19 @@ class WalletRpcApi:
"/pw_self_pool": self.pw_self_pool,
"/pw_absorb_rewards": self.pw_absorb_rewards,
"/pw_status": self.pw_status,
# DL Wallet
"/create_new_dl": self.create_new_dl,
"/dl_track_new": self.dl_track_new,
"/dl_stop_tracking": self.dl_stop_tracking,
"/dl_latest_singleton": self.dl_latest_singleton,
"/dl_singletons_by_root": self.dl_singletons_by_root,
"/dl_update_root": self.dl_update_root,
"/dl_update_multiple": self.dl_update_multiple,
"/dl_history": self.dl_history,
"/dl_owned_singletons": self.dl_owned_singletons,
"/dl_get_mirrors": self.dl_get_mirrors,
"/dl_new_mirror": self.dl_new_mirror,
"/dl_delete_mirror": self.dl_delete_mirror,
}
async def _state_changed(self, change: str, change_data: Dict[str, Any]) -> List[WsRpcMessage]:
@@ -1023,6 +1037,12 @@ class WalletRpcApi:
validate_only: bool = request.get("validate_only", False)
driver_dict_str: Optional[Dict[str, Any]] = request.get("driver_dict", None)
min_coin_amount: uint64 = uint64(request.get("min_coin_amount", 0))
marshalled_solver = request.get("solver")
solver: Optional[Solver]
if marshalled_solver is None:
solver = None
else:
solver = Solver(info=marshalled_solver)
# This driver_dict construction is to maintain backward compatibility where everything is assumed to be a CAT
driver_dict: Dict[bytes32, PuzzleInfo] = {}
@@ -1048,7 +1068,12 @@ class WalletRpcApi:
async with self.service.wallet_state_manager.lock:
result = await self.service.wallet_state_manager.trade_manager.create_offer_for_ids(
modified_offer, driver_dict, fee=fee, validate_only=validate_only, min_coin_amount=min_coin_amount
modified_offer,
driver_dict,
solver=solver,
fee=fee,
validate_only=validate_only,
min_coin_amount=min_coin_amount,
)
if result[0]:
success, trade_record, error = result
@@ -1087,7 +1112,12 @@ class WalletRpcApi:
raise ValueError("CAT1s are no longer supported")
###
return {"summary": {"offered": offered, "requested": requested, "fees": offer.bundle.fees(), "infos": infos}}
if request.get("advanced", False):
return {
"summary": {"offered": offered, "requested": requested, "fees": offer.bundle.fees(), "infos": infos}
}
else:
return {"summary": await self.service.wallet_state_manager.trade_manager.get_offer_summary(offer)}
async def check_offer_validity(self, request) -> EndpointResult:
offer_hex: str = request["offer"]
@@ -1102,13 +1132,19 @@ class WalletRpcApi:
offer = Offer.from_bech32(offer_hex)
fee: uint64 = uint64(request.get("fee", 0))
min_coin_amount: uint64 = uint64(request.get("min_coin_amount", 0))
maybe_marshalled_solver: Optional[Dict[str, Any]] = request.get("solver")
solver: Optional[Solver]
if maybe_marshalled_solver is None:
solver = None
else:
solver = Solver(info=maybe_marshalled_solver)
async with self.service.wallet_state_manager.lock:
peer: Optional[WSChiaConnection] = self.service.get_full_node_peer()
if peer is None:
raise ValueError("No peer connected")
result = await self.service.wallet_state_manager.trade_manager.respond_to_offer(
offer, peer, fee=fee, min_coin_amount=min_coin_amount
offer, peer, fee=fee, min_coin_amount=min_coin_amount, solver=solver
)
if not result[0]:
raise ValueError(result[2])
@@ -2019,3 +2055,262 @@ class WalletRpcApi:
"state": state.to_json_dict(),
"unconfirmed_transactions": unconfirmed_transactions,
}
##########################################################################################
# DataLayer Wallet
##########################################################################################
async def create_new_dl(self, request) -> Dict:
"""Initialize the DataLayer Wallet (only one can exist)"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
dl_wallet = wallet
break
else:
async with self.service.wallet_state_manager.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(
self.service.wallet_state_manager,
self.service.wallet_state_manager.main_wallet,
)
try:
async with self.service.wallet_state_manager.lock:
dl_tx, std_tx, launcher_id = await dl_wallet.generate_new_reporter(
bytes32.from_hexstr(request["root"]), fee=request.get("fee", uint64(0))
)
await self.service.wallet_state_manager.add_pending_transaction(dl_tx)
await self.service.wallet_state_manager.add_pending_transaction(std_tx)
except ValueError as e:
log.error(f"Error while generating new reporter {e}")
return {"success": False, "error": str(e)}
return {
"success": True,
"transactions": [tx.to_json_dict_convenience(self.service.config) for tx in (dl_tx, std_tx)],
"launcher_id": launcher_id,
}
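For reference, the handler above reads only two fields from the request; a minimal illustrative payload (the root value is a placeholder, not real data):

request = {
    "root": "00" * 32,  # hex-encoded initial merkle root (placeholder value)
    "fee": 0,           # optional fee in mojos, defaults to 0
}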
async def dl_track_new(self, request) -> Dict:
"""Initialize the DataLayer Wallet (only one can exist)"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
peer: Optional[WSChiaConnection] = self.service.get_full_node_peer()
if peer is None:
raise ValueError("No peer connected")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
dl_wallet = wallet
break
else:
async with self.service.wallet_state_manager.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(
self.service.wallet_state_manager,
self.service.wallet_state_manager.main_wallet,
)
await dl_wallet.track_new_launcher_id(bytes32.from_hexstr(request["launcher_id"]), peer)
return {}
async def dl_stop_tracking(self, request) -> Dict:
"""Initialize the DataLayer Wallet (only one can exist)"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
dl_wallet = self.service.wallet_state_manager.get_dl_wallet()
if dl_wallet is None:
raise ValueError("The DataLayer wallet has not been initialized")
await dl_wallet.stop_tracking_singleton(bytes32.from_hexstr(request["launcher_id"]))
return {}
async def dl_latest_singleton(self, request) -> Dict:
"""Get the singleton record for the latest singleton of a launcher ID"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
only_confirmed = request.get("only_confirmed")
if only_confirmed is None:
only_confirmed = False
record = await wallet.get_latest_singleton(bytes32.from_hexstr(request["launcher_id"]), only_confirmed)
return {"singleton": None if record is None else record.to_json_dict()}
raise ValueError("No DataLayer wallet has been initialized")
async def dl_singletons_by_root(self, request) -> Dict:
"""Get the singleton records that contain the specified root"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for wallet in self.service.wallet_state_manager.wallets.values():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
records = await wallet.get_singletons_by_root(
bytes32.from_hexstr(request["launcher_id"]), bytes32.from_hexstr(request["root"])
)
records_json = [rec.to_json_dict() for rec in records]
return {"singletons": records_json}
raise ValueError("No DataLayer wallet has been initialized")
async def dl_update_root(self, request) -> Dict:
"""Get the singleton record for the latest singleton of a launcher ID"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
async with self.service.wallet_state_manager.lock:
records = await wallet.create_update_state_spend(
bytes32.from_hexstr(request["launcher_id"]),
bytes32.from_hexstr(request["new_root"]),
fee=uint64(request.get("fee", 0)),
)
for record in records:
await self.service.wallet_state_manager.add_pending_transaction(record)
return {"tx_record": records[0].to_json_dict_convenience(self.service.config)}
raise ValueError("No DataLayer wallet has been initialized")
async def dl_update_multiple(self, request) -> Dict:
"""Update multiple singletons with new merkle roots"""
if self.service.wallet_state_manager is None:
return {"success": False, "error": "not_initialized"}
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
async with self.service.wallet_state_manager.lock:
# TODO: This method should optionally link the singletons with announcements.
# Otherwise spends are vulnerable to signature subtraction.
tx_records: List[TransactionRecord] = []
for launcher, root in request["updates"].items():
records = await wallet.create_update_state_spend(
bytes32.from_hexstr(launcher), bytes32.from_hexstr(root)
)
tx_records.extend(records)
# Now that we have all the txs, we need to aggregate them all into just one spend
modified_txs: List[TransactionRecord] = []
aggregate_spend = SpendBundle([], G2Element())
for tx in tx_records:
if tx.spend_bundle is not None:
aggregate_spend = SpendBundle.aggregate([aggregate_spend, tx.spend_bundle])
modified_txs.append(dataclasses.replace(tx, spend_bundle=None))
modified_txs[0] = dataclasses.replace(modified_txs[0], spend_bundle=aggregate_spend)
for tx in modified_txs:
await self.service.wallet_state_manager.add_pending_transaction(tx)
return {"tx_records": [rec.to_json_dict_convenience(self.service.config) for rec in modified_txs]}
raise ValueError("No DataLayer wallet has been initialized")
async def dl_history(self, request) -> Dict:
"""Get the singleton record for the latest singleton of a launcher ID"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
additional_kwargs = {}
if "min_generation" in request:
additional_kwargs["min_generation"] = uint32(request["min_generation"])
if "max_generation" in request:
additional_kwargs["max_generation"] = uint32(request["max_generation"])
if "num_results" in request:
additional_kwargs["num_results"] = uint32(request["num_results"])
history = await wallet.get_history(bytes32.from_hexstr(request["launcher_id"]), **additional_kwargs)
history_json = [rec.to_json_dict() for rec in history]
return {"history": history_json, "count": len(history_json)}
raise ValueError("No DataLayer wallet has been initialized")
async def dl_owned_singletons(self, request) -> Dict:
"""Get all owned singleton records"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
break
else:
raise ValueError("No DataLayer wallet has been initialized")
singletons = await wallet.get_owned_singletons()
singletons_json = [singleton.to_json_dict() for singleton in singletons]
return {"singletons": singletons_json, "count": len(singletons_json)}
async def dl_get_mirrors(self, request) -> Dict:
"""Get all of the mirrors for a specific singleton"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
break
else:
raise ValueError("No DataLayer wallet has been initialized")
mirrors_json = []
for mirror in await wallet.get_mirrors_for_launcher(bytes32.from_hexstr(request["launcher_id"])):
mirrors_json.append(mirror.to_json_dict())
return {"mirrors": mirrors_json}
async def dl_new_mirror(self, request) -> Dict:
"""Add a new on chain message for a specific singleton"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
dl_wallet = wallet
break
else:
raise ValueError("No DataLayer wallet has been initialized")
async with self.service.wallet_state_manager.lock:
txs = await dl_wallet.create_new_mirror(
bytes32.from_hexstr(request["launcher_id"]),
request["amount"],
[bytes(url, "utf8") for url in request["urls"]],
fee=request.get("fee", uint64(0)),
)
for tx in txs:
await self.service.wallet_state_manager.add_pending_transaction(tx)
return {
"transactions": [tx.to_json_dict_convenience(self.service.config) for tx in txs],
}
async def dl_delete_mirror(self, request) -> Dict:
"""Remove an existing mirror for a specific singleton"""
if self.service.wallet_state_manager is None:
raise ValueError("The wallet service is not currently initialized")
peer: Optional[WSChiaConnection] = self.service.get_full_node_peer()
if peer is None:
raise ValueError("No peer connected")
for _, wallet in self.service.wallet_state_manager.wallets.items():
if WalletType(wallet.type()) == WalletType.DATA_LAYER:
dl_wallet = wallet
break
else:
raise ValueError("No DataLayer wallet has been initialized")
async with self.service.wallet_state_manager.lock:
txs = await dl_wallet.delete_mirror(
bytes32.from_hexstr(request["coin_id"]),
peer,
fee=request.get("fee", uint64(0)),
)
for tx in txs:
await self.service.wallet_state_manager.add_pending_transaction(tx)
return {
"transactions": [tx.to_json_dict_convenience(self.service.config) for tx in txs],
}

View File

@ -1,5 +1,6 @@
from typing import Dict, List, Optional, Any, Tuple, Union
from chia.data_layer.data_layer_wallet import Mirror, SingletonRecord
from chia.pools.pool_wallet_info import PoolWalletInfo
from chia.rpc.rpc_client import RpcClient
from chia.types.announcement import Announcement
@ -542,13 +543,12 @@ class WalletRpcClient(RpcClient):
self,
offer_dict: Dict[Union[uint32, str], int],
driver_dict: Optional[Dict[str, Any]] = None,
solver: Optional[Dict[str, Any]] = None,
fee=uint64(0),
validate_only: bool = False,
min_coin_amount: uint64 = uint64(0),
) -> Tuple[Optional[Offer], TradeRecord]:
send_dict: Dict[str, int] = {}
for key in offer_dict:
send_dict[str(key)] = offer_dict[key]
send_dict: Dict[str, int] = {str(key): value for key, value in offer_dict.items()}
req = {
"offer": send_dict,
@ -558,23 +558,28 @@ class WalletRpcClient(RpcClient):
}
if driver_dict is not None:
req["driver_dict"] = driver_dict
if solver is not None:
req["solver"] = solver
res = await self.fetch("create_offer_for_ids", req)
offer: Optional[Offer] = None if validate_only else Offer.from_bech32(res["offer"])
offer_str: str = "" if offer is None else bytes(offer).hex()
return offer, TradeRecord.from_json_dict_convenience(res["trade_record"], offer_str)
async def get_offer_summary(self, offer: Offer) -> Dict[str, Dict[str, int]]:
res = await self.fetch("get_offer_summary", {"offer": offer.to_bech32()})
async def get_offer_summary(self, offer: Offer, advanced: bool = False) -> Dict[str, Dict[str, int]]:
res = await self.fetch("get_offer_summary", {"offer": offer.to_bech32(), "advanced": advanced})
return res["summary"]
async def check_offer_validity(self, offer: Offer) -> bool:
res = await self.fetch("check_offer_validity", {"offer": offer.to_bech32()})
return res["valid"]
async def take_offer(self, offer: Offer, fee=uint64(0), min_coin_amount: uint64 = uint64(0)) -> TradeRecord:
res = await self.fetch(
"take_offer", {"offer": offer.to_bech32(), "fee": fee, "min_coin_amount": min_coin_amount}
)
async def take_offer(
self, offer: Offer, solver: Optional[Dict[str, Any]] = None, fee=uint64(0), min_coin_amount: uint64 = uint64(0)
) -> TradeRecord:
req = {"offer": offer.to_bech32(), "fee": fee, "min_coin_amount": min_coin_amount}
if solver is not None:
req["solver"] = solver
res = await self.fetch("take_offer", req)
return TradeRecord.from_json_dict_convenience(res["trade_record"])
async def get_offer(self, trade_id: bytes32, file_contents: bool = False) -> TradeRecord:
@ -722,3 +727,99 @@ class WalletRpcClient(RpcClient):
request: Dict[str, Any] = {"wallet_id": wallet_id}
response = await self.fetch("nft_get_wallet_did", request)
return response
# DataLayer
async def create_new_dl(self, root: bytes32, fee: uint64) -> Tuple[List[TransactionRecord], bytes32]:
request = {"root": root.hex(), "fee": fee}
response = await self.fetch("create_new_dl", request)
txs: List[TransactionRecord] = [
TransactionRecord.from_json_dict_convenience(tx) for tx in response["transactions"]
]
launcher_id: bytes32 = bytes32.from_hexstr(response["launcher_id"])
return txs, launcher_id
async def dl_track_new(self, launcher_id: bytes32) -> None:
request = {"launcher_id": launcher_id.hex()}
await self.fetch("dl_track_new", request)
return None
async def dl_stop_tracking(self, launcher_id: bytes32) -> None:
request = {"launcher_id": launcher_id.hex()}
await self.fetch("dl_stop_tracking", request)
return None
async def dl_latest_singleton(
self, launcher_id: bytes32, only_confirmed: bool = False
) -> Optional[SingletonRecord]:
request = {"launcher_id": launcher_id.hex(), "only_confirmed": only_confirmed}
response = await self.fetch("dl_latest_singleton", request)
return None if response["singleton"] is None else SingletonRecord.from_json_dict(response["singleton"])
async def dl_singletons_by_root(self, launcher_id: bytes32, root: bytes32) -> List[SingletonRecord]:
request = {"launcher_id": launcher_id.hex(), "root": root.hex()}
response = await self.fetch("dl_singletons_by_root", request)
return [SingletonRecord.from_json_dict(single) for single in response["singletons"]]
async def dl_update_root(self, launcher_id: bytes32, new_root: bytes32, fee: uint64) -> TransactionRecord:
request = {"launcher_id": launcher_id.hex(), "new_root": new_root.hex(), "fee": fee}
response = await self.fetch("dl_update_root", request)
return TransactionRecord.from_json_dict_convenience(response["tx_record"])
async def dl_update_multiple(self, update_dictionary: Dict[bytes32, bytes32]) -> List[TransactionRecord]:
updates_as_strings: Dict[str, str] = {}
for lid, root in update_dictionary.items():
updates_as_strings[str(lid)] = str(root)
request = {"updates": updates_as_strings}
response = await self.fetch("dl_update_multiple", request)
return [TransactionRecord.from_json_dict_convenience(tx) for tx in response["tx_records"]]
async def dl_history(
self,
launcher_id: bytes32,
min_generation: Optional[uint32] = None,
max_generation: Optional[uint32] = None,
num_results: Optional[uint32] = None,
) -> List[SingletonRecord]:
request = {"launcher_id": launcher_id.hex()}
if min_generation is not None:
request["min_generation"] = str(min_generation)
if max_generation is not None:
request["max_generation"] = str(max_generation)
if num_results is not None:
request["num_results"] = str(num_results)
response = await self.fetch("dl_history", request)
return [SingletonRecord.from_json_dict(single) for single in response["history"]]
async def dl_owned_singletons(self) -> List[SingletonRecord]:
response = await self.fetch(path="dl_owned_singletons", request_json={})
return [SingletonRecord.from_json_dict(singleton) for singleton in response["singletons"]]
async def dl_get_mirrors(self, launcher_id: bytes32) -> List[Mirror]:
response = await self.fetch(path="dl_get_mirrors", request_json={"launcher_id": launcher_id.hex()})
return [Mirror.from_json_dict(mirror) for mirror in response["mirrors"]]
async def dl_new_mirror(
self, launcher_id: bytes32, amount: uint64, urls: List[bytes], fee: uint64 = uint64(0)
) -> List[TransactionRecord]:
response = await self.fetch(
path="dl_new_mirror",
request_json={
"launcher_id": launcher_id.hex(),
"amount": amount,
"urls": [url.decode("utf8") for url in urls],
"fee": fee,
},
)
return [TransactionRecord.from_json_dict_convenience(tx) for tx in response["transactions"]]
async def dl_delete_mirror(self, coin_id: bytes32, fee: uint64 = uint64(0)) -> List[TransactionRecord]:
response = await self.fetch(
path="dl_delete_mirror",
request_json={
"coin_id": coin_id.hex(),
"fee": fee,
},
)
return [TransactionRecord.from_json_dict_convenience(tx) for tx in response["transactions"]]
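Taken together, these client methods cover the basic singleton lifecycle. A hedged usage sketch, where wallet_client, initial_root, and new_root are hypothetical stand-ins:

txs, launcher_id = await wallet_client.create_new_dl(root=initial_root, fee=uint64(0))
await wallet_client.dl_track_new(launcher_id)
record = await wallet_client.dl_latest_singleton(launcher_id, only_confirmed=True)
if record is not None:
    await wallet_client.dl_update_root(launcher_id, new_root, fee=uint64(0))
history = await wallet_client.dl_history(launcher_id)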

View File

@ -14,6 +14,7 @@ class NodeType(IntEnum):
TIMELORD = 4
INTRODUCER = 5
WALLET = 6
DATA_LAYER = 7
class Delivery(IntEnum):

View File

@ -123,6 +123,7 @@ class ChiaServer:
self.connection_by_type: Dict[NodeType, Dict[bytes32, WSChiaConnection]] = {
NodeType.FULL_NODE: {},
NodeType.DATA_LAYER: {},
NodeType.WALLET: {},
NodeType.HARVESTER: {},
NodeType.FARMER: {},

View File

@ -0,0 +1,82 @@
import logging
import pathlib
import sys
from typing import Any, Dict, Optional
from chia.data_layer.data_layer import DataLayer
from chia.data_layer.data_layer_api import DataLayerAPI
from chia.rpc.data_layer_rpc_api import DataLayerRpcApi
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.server.outbound_message import NodeType
from chia.server.start_service import RpcInfo, Service, async_run
from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config, load_config_cli
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.ints import uint16
# See: https://bugs.python.org/issue29288
"".encode("idna")
SERVICE_NAME = "data_layer"
log = logging.getLogger(__name__)
# TODO: Review need for config and if retained then hint it properly.
def create_data_layer_service(
root_path: pathlib.Path,
config: Dict[str, Any],
connect_to_daemon: bool = True,
) -> Service:
service_config = config[SERVICE_NAME]
self_hostname = config["self_hostname"]
wallet_rpc_port = service_config["wallet_peer"]["port"]
wallet_rpc_init = WalletRpcClient.create(self_hostname, uint16(wallet_rpc_port), root_path, config)
data_layer = DataLayer(config=service_config, root_path=root_path, wallet_rpc_init=wallet_rpc_init)
api = DataLayerAPI(data_layer)
network_id = service_config["selected_network"]
rpc_port = service_config.get("rpc_port")
rpc_info: Optional[RpcInfo] = None
if rpc_port is not None:
rpc_info = (DataLayerRpcApi, service_config["rpc_port"])
return Service(
server_listen_ports=[service_config["port"]],
root_path=root_path,
config=config,
node=data_layer,
# TODO: not for peers...
peer_api=api,
node_type=NodeType.DATA_LAYER,
# TODO: no publicly advertised port, at least not yet
advertised_port=service_config["port"],
service_name=SERVICE_NAME,
network_id=network_id,
max_request_body_size=service_config.get("rpc_server_max_request_body_size", 26214400),
rpc_info=rpc_info,
connect_to_daemon=connect_to_daemon,
)
async def async_main() -> int:
# TODO: refactor to avoid the double load
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
service_config = load_config_cli(DEFAULT_ROOT_PATH, "config.yaml", SERVICE_NAME)
config[SERVICE_NAME] = service_config
initialize_logging(
service_name=SERVICE_NAME,
logging_config=service_config["logging"],
root_path=DEFAULT_ROOT_PATH,
)
service = create_data_layer_service(DEFAULT_ROOT_PATH, config)
await service.setup_process_global_state()
await service.run()
return 0
def main() -> int:
return async_run(async_main())
if __name__ == "__main__":
sys.exit(main())

View File

@ -1,20 +1,64 @@
import asyncio
import itertools
import time
from typing import Dict, List, Optional, Tuple
from typing import Collection, Dict, Iterator, List, Optional, Set, Tuple
from chia.consensus.block_record import BlockRecord
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.consensus.multiprocess_validation import PreValidationResult
from chia.full_node.full_node import FullNode
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.full_node_protocol import RespondBlock
from chia.simulator.block_tools import BlockTools
from chia.simulator.simulator_protocol import FarmNewBlockProtocol, GetAllCoinsProtocol, ReorgProtocol
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.full_block import FullBlock
from chia.util.api_decorators import api_request
from chia.util.config import lock_and_load_config, save_config
from chia.util.ints import uint8, uint32, uint128
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.wallet_types import AmountWithPuzzlehash
from chia.wallet.wallet import Wallet
def backoff_times(
initial: float = 0.001,
final: float = 0.100,
time_to_final: float = 0.5,
clock=time.monotonic,
) -> Iterator[float]:
# initially implemented as a simple linear backoff
start = clock()
delta = 0
result_range = final - initial
while True:
yield min(final, initial + ((delta / time_to_final) * result_range))
delta = clock() - start
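A sketch of the intended consumption pattern: sleep between polls with a delay that ramps linearly from `initial` to `final`. The poll_until helper and its condition argument are hypothetical:

async def poll_until(condition) -> None:
    # Poll `condition` with linearly ramping sleeps from backoff_times().
    for delay in backoff_times():
        if condition():
            return
        await asyncio.sleep(delay)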
async def wait_for_coins_in_wallet(coins: Set[Coin], wallet: Wallet):
"""Wait until all of the specified coins are simultaneously reported as spendable
in by the wallet.
Arguments:
coins: The coins expected to be received.
wallet: The wallet expected to receive the coins.
"""
while True:
spendable_wallet_coin_records = await wallet.wallet_state_manager.get_spendable_coins_for_wallet(
wallet_id=wallet.id()
)
spendable_wallet_coins = {record.coin for record in spendable_wallet_coin_records}
if coins.issubset(spendable_wallet_coins):
return
await asyncio.sleep(0.050)
class FullNodeSimulator(FullNodeAPI):
@ -23,7 +67,7 @@ class FullNodeSimulator(FullNodeAPI):
self.bt = block_tools
self.full_node = full_node
self.config = config
self.time_per_block = None
self.time_per_block: Optional[float] = None
self.full_node.simulator_transaction_callback = self.autofarm_transaction
self.use_current_time: bool = self.config.get("simulator", {}).get("use_current_time", False)
self.auto_farm: bool = self.config.get("simulator", {}).get("auto_farm", False)
@ -113,7 +157,9 @@ class FullNodeSimulator(FullNodeAPI):
return ph_total_amount
@api_request
async def farm_new_transaction_block(self, request: FarmNewBlockProtocol, force_wait_for_timestamp: bool = False):
async def farm_new_transaction_block(
self, request: FarmNewBlockProtocol, force_wait_for_timestamp: bool = False
) -> FullBlock:
async with self.full_node._blockchain_lock_high_priority:
self.log.info("Farming new block!")
current_blocks = await self.get_all_full_blocks()
@ -162,6 +208,7 @@ class FullNodeSimulator(FullNodeAPI):
)
rr = RespondBlock(more[-1])
await self.full_node.respond_block(rr)
return more[-1]
@api_request
async def farm_new_block(self, request: FarmNewBlockProtocol, force_wait_for_timestamp: bool = False):
@ -235,3 +282,224 @@ class FullNodeSimulator(FullNodeAPI):
for block in more_blocks:
await self.full_node.respond_block(RespondBlock(block))
async def process_blocks(self, count: int, farm_to: bytes32 = bytes32([0] * 32)) -> int:
"""Process the requested number of blocks including farming to the passed puzzle
hash. Note that the rewards for the last block will not have been processed.
Consider `.farm_blocks()` or `.farm_rewards()` if the goal is to receive XCH at
an address.
Arguments:
count: The number of blocks to process.
farm_to: The puzzle hash to farm the block rewards to.
Returns:
The total number of reward mojos for the processed blocks.
"""
rewards = 0
height = uint32(0)
if count == 0:
return rewards
for _ in range(count):
block: FullBlock = await self.farm_new_transaction_block(FarmNewBlockProtocol(farm_to))
height = uint32(block.height)
rewards += calculate_pool_reward(height) + calculate_base_farmer_reward(height)
while True:
peak_height = self.full_node.blockchain.get_peak_height()
if peak_height is None:
raise RuntimeError("Peak height still None after processing at least one block")
if peak_height >= height:
break
await asyncio.sleep(0.050)
return rewards
async def farm_blocks(self, count: int, wallet: Wallet) -> int:
"""Farm the requested number of blocks to the passed wallet. This will
process additional blocks as needed to confirm the reward transactions
and also wait for the rewards to be present in the wallet.
Arguments:
count: The number of blocks to farm.
wallet: The wallet to farm the block rewards to.
Returns:
The total number of reward mojos farmed to the requested address.
"""
if count == 0:
return 0
rewards = await self.process_blocks(count=count, farm_to=await wallet.get_new_puzzlehash())
await self.process_blocks(count=1)
peak_height = self.full_node.blockchain.get_peak_height()
if peak_height is None:
raise RuntimeError("Peak height still None after processing at least one block")
coin_records = await self.full_node.coin_store.get_coins_added_at_height(height=peak_height)
block_reward_coins = {record.coin for record in coin_records}
await wait_for_coins_in_wallet(coins=block_reward_coins, wallet=wallet)
return rewards
async def farm_rewards(self, amount: int, wallet: Wallet) -> int:
"""Farm at least the requested amount of mojos to the passed wallet. Extra
mojos will be received based on the block rewards at the present block height.
The rewards will be present in the wallet before returning.
Arguments:
amount: The minimum number of mojos to farm.
wallet: The wallet to farm the block rewards to.
Returns:
The total number of reward mojos farmed to the requested wallet.
"""
rewards = 0
if amount == 0:
return rewards
height_before = self.full_node.blockchain.get_peak_height()
if height_before is None:
height_before = uint32(0)
for count in itertools.count(1):
height = uint32(height_before + count)
rewards += calculate_pool_reward(height) + calculate_base_farmer_reward(height)
if rewards >= amount:
await self.farm_blocks(count=count, wallet=wallet)
return rewards
raise Exception("internal error")
async def wait_transaction_records_entered_mempool(self, records: Collection[TransactionRecord]) -> None:
"""Wait until the transaction records have entered the mempool. Transaction
records with no spend bundle are ignored.
Arguments:
records: The transaction records to wait for.
"""
ids_to_check: Set[bytes32] = set()
for record in records:
if record.spend_bundle is None:
continue
ids_to_check.add(record.spend_bundle.name())
while True:
found = set()
for spend_bundle_name in ids_to_check:
tx = self.full_node.mempool_manager.get_spendbundle(spend_bundle_name)
if tx is not None:
found.add(spend_bundle_name)
ids_to_check = ids_to_check.difference(found)
if len(ids_to_check) == 0:
return
await asyncio.sleep(0.050)
async def process_transaction_records(self, records: Collection[TransactionRecord]) -> None:
"""Process the specified transaction records and wait until they have been
included in a block.
Arguments:
records: The transaction records to process.
"""
coins_to_wait_for: Set[Coin] = set()
for record in records:
if record.spend_bundle is None:
continue
coins_to_wait_for.update(record.spend_bundle.additions())
coin_store = self.full_node.coin_store
await self.wait_transaction_records_entered_mempool(records=records)
while True:
await self.process_blocks(count=1)
found: Set[Coin] = set()
for coin in coins_to_wait_for:
# TODO: is this the proper check?
if await coin_store.get_coin_record(coin.name()) is not None:
found.add(coin)
coins_to_wait_for = coins_to_wait_for.difference(found)
if len(coins_to_wait_for) == 0:
return
async def create_coins_with_amounts(
self,
amounts: List[int],
wallet: Wallet,
per_transaction_record_group: int = 50,
) -> Set[Coin]:
"""Create coins with the requested amount. This is useful when you need a
bunch of coins for a test and don't need to farm that many.
Arguments:
amounts: A list with entries of mojo amounts corresponding to each
coin to create.
wallet: The wallet to send the new coins to.
per_transaction_record_group: The maximum number of coins to create in each
transaction record.
Returns:
A set of the generated coins. Note that this does not include any change
coins that were created.
"""
invalid_amounts = [amount for amount in amounts if amount <= 0]
if len(invalid_amounts) > 0:
invalid_amounts_string = ", ".join(str(amount) for amount in invalid_amounts)
raise Exception(f"Coins must have a positive value, request included: {invalid_amounts_string}")
if len(amounts) == 0:
return set()
# TODO: This is a poor duplication of code in
# WalletRpcApi.create_signed_transaction(). Perhaps it should be moved
# somewhere more reusable.
outputs: List[AmountWithPuzzlehash] = []
for amount in amounts:
puzzle_hash = await wallet.get_new_puzzlehash()
outputs.append({"puzzlehash": puzzle_hash, "amount": uint64(amount), "memos": []})
transaction_records: List[TransactionRecord] = []
outputs_iterator = iter(outputs)
while True:
# The outputs iterator must be second in the zip() call; otherwise we lose
# an element when reaching the end of the range object.
outputs_group = [output for _, output in zip(range(per_transaction_record_group), outputs_iterator)]
if len(outputs_group) > 0:
async with wallet.wallet_state_manager.lock:
tx = await wallet.generate_signed_transaction(
amount=outputs_group[0]["amount"],
puzzle_hash=outputs_group[0]["puzzlehash"],
primaries=outputs_group[1:],
)
await wallet.push_transaction(tx=tx)
transaction_records.append(tx)
else:
break
await self.process_transaction_records(records=transaction_records)
output_coins = {coin for transaction_record in transaction_records for coin in transaction_record.additions}
puzzle_hashes = {output["puzzlehash"] for output in outputs}
change_coins = {coin for coin in output_coins if coin.puzzle_hash not in puzzle_hashes}
coins_to_receive = output_coins - change_coins
await wait_for_coins_in_wallet(coins=coins_to_receive, wallet=wallet)
return coins_to_receive
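A test-style sketch of the method above; the simulator and wallet fixtures are assumed:

coins = await full_node_simulator.create_coins_with_amounts(amounts=[100, 200, 300], wallet=wallet)
assert sorted(coin.amount for coin in coins) == [100, 200, 300]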

View File

@ -38,6 +38,26 @@ class DBWrapper:
async def commit_transaction(self) -> None:
await self.db.commit()
@contextlib.asynccontextmanager
async def locked_transaction(self, *, lock=True):
# TODO: look into contextvars perhaps instead of this manual lock tracking
if not lock:
yield
return
# TODO: add a lock acquisition timeout
# maybe https://docs.python.org/3/library/asyncio-task.html#asyncio.wait_for
async with self.lock:
await self.begin_transaction()
try:
yield
except BaseException:
await self.rollback_transaction()
raise
else:
await self.commit_transaction()
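A minimal usage sketch of the context manager above; db_wrapper and the SQL statement are illustrative:

async with db_wrapper.locked_transaction():
    # Commits on a clean exit; rolls back if the block raises.
    await db_wrapper.db.execute("INSERT INTO example VALUES (?)", (1,))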
async def execute_fetchone(
c: aiosqlite.Connection, sql: str, parameters: Iterable[Any] = None

View File

@ -545,6 +545,43 @@ wallet:
# Interval to resend unconfirmed transactions, even if previously accepted into Mempool
tx_resend_timeout_secs: 1800
data_layer:
# TODO: consider name
# TODO: organize consistently with other sections
# TODO: shouldn't we not need this since we have no actual public interface (yet)?
port: 8561
wallet_peer:
host: localhost
port: 9256
database_path: "data_layer/db/data_layer_CHALLENGE.sqlite"
# The location where the server files will be stored.
server_files_location: "data_layer/db/server_files_location_CHALLENGE"
# Data for running a data layer server.
host_ip: 0.0.0.0
host_port: 8575
# Set to True to run the data propagation server.
run_server: True
# Data for running a data layer client.
manage_data_interval: 60
selected_network: *selected_network
# If True, starts an RPC server at the following port
start_rpc_server: True
# TODO: what considerations are there in choosing this?
rpc_port: 8562
rpc_server_max_request_body_size: 26214400
fee: 1000000000
logging: *logging
# TODO: which of these are really appropriate?
ssl:
private_crt: "config/ssl/data_layer/private_data_layer.crt"
private_key: "config/ssl/data_layer/private_data_layer.key"
public_crt: "config/ssl/data_layer/public_data_layer.crt"
public_key: "config/ssl/data_layer/public_data_layer.key"
simulator:
# Should the simulator farm a block whenever a transaction is in mempool

View File

@ -1,7 +1,11 @@
from typing import KeysView, Generator
SERVICES_FOR_GROUP = {
"all": "chia_harvester chia_timelord_launcher chia_timelord chia_farmer chia_full_node chia_wallet".split(),
"all": (
"chia_harvester chia_timelord_launcher chia_timelord chia_farmer chia_full_node chia_wallet chia_data_layer"
).split(),
# TODO: should this be `data_layer`?
"data": "chia_wallet chia_data_layer".split(),
"node": "chia_full_node".split(),
"harvester": "chia_harvester".split(),
"farmer": "chia_harvester chia_farmer chia_full_node chia_wallet".split(),

View File

@ -23,6 +23,8 @@ CERT_CONFIG_KEY_PATHS = [
"farmer:ssl:public_crt",
"full_node:ssl:private_crt",
"full_node:ssl:public_crt",
"data_layer:ssl:private_crt",
"data_layer:ssl:public_crt",
"harvester:chia_ssl_ca:crt",
"harvester:private_ssl_ca:crt",
"harvester:ssl:private_crt",

View File

@ -816,7 +816,7 @@ class CATWallet:
and puzzle_driver.also() is None
)
def get_puzzle_info(self, asset_id: bytes32) -> PuzzleInfo:
async def get_puzzle_info(self, asset_id: bytes32) -> PuzzleInfo:
return PuzzleInfo({"type": AssetType.CAT.value, "tail": "0x" + self.get_asset_id()})
async def get_coins_to_offer(

View File

@ -0,0 +1,99 @@
from typing import Iterator, List, Tuple, Union
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.condition_opcodes import ConditionOpcode
from chia.util.ints import uint64
from chia.wallet.nft_wallet.nft_puzzles import NFT_STATE_LAYER_MOD, create_nft_layer_puzzle_with_curry_params
from chia.wallet.puzzles.load_clvm import load_clvm
# from chia.types.condition_opcodes import ConditionOpcode
# from chia.wallet.util.merkle_tree import MerkleTree, TreeType
ACS_MU = Program.to(11) # returns the third argument a.k.a the full solution
ACS_MU_PH = ACS_MU.get_tree_hash()
SINGLETON_TOP_LAYER_MOD = load_clvm("singleton_top_layer_v1_1.clvm")
SINGLETON_LAUNCHER = load_clvm("singleton_launcher.clvm")
GRAFTROOT_DL_OFFERS = load_clvm("graftroot_dl_offers.clvm")
P2_PARENT = load_clvm("p2_parent.clvm")
def create_host_fullpuz(innerpuz: Union[Program, bytes32], current_root: bytes32, genesis_id: bytes32) -> Program:
db_layer = create_host_layer_puzzle(innerpuz, current_root)
mod_hash = SINGLETON_TOP_LAYER_MOD.get_tree_hash()
singleton_struct = Program.to((mod_hash, (genesis_id, SINGLETON_LAUNCHER.get_tree_hash())))
return SINGLETON_TOP_LAYER_MOD.curry(singleton_struct, db_layer)
def create_host_layer_puzzle(innerpuz: Union[Program, bytes32], current_root: bytes32) -> Program:
# some hard coded metadata formatting and metadata updater for now
return create_nft_layer_puzzle_with_curry_params(
Program.to((current_root, None)),
ACS_MU_PH,
# TODO: the nft driver doesn't like the Union yet, but changing that is out of scope for me rn - Quex
innerpuz, # type: ignore
)
def match_dl_singleton(puzzle: Program) -> Tuple[bool, Iterator[Program]]:
"""
Given a puzzle, test if it's a data layer singleton and, if it is, return the curried arguments
"""
mod, singleton_curried_args = puzzle.uncurry()
if mod == SINGLETON_TOP_LAYER_MOD:
mod, dl_curried_args = singleton_curried_args.at("rf").uncurry()
if mod == NFT_STATE_LAYER_MOD and dl_curried_args.at("rrf") == ACS_MU_PH:
launcher_id = singleton_curried_args.at("frf")
root = dl_curried_args.at("rff")
innerpuz = dl_curried_args.at("rrrf")
return True, iter((innerpuz, root, launcher_id))
return False, iter(())
def launch_solution_to_singleton_info(launch_solution: Program) -> Tuple[bytes32, uint64, bytes32, bytes32]:
solution = launch_solution.as_python()
try:
full_puzzle_hash = bytes32(solution[0])
amount = uint64(int.from_bytes(solution[1], "big"))
root = bytes32(solution[2][0])
inner_puzzle_hash = bytes32(solution[2][1])
except (IndexError, TypeError):
raise ValueError("Launcher is not a data layer launcher")
return full_puzzle_hash, amount, root, inner_puzzle_hash
def launcher_to_struct(launcher_id: bytes32) -> Program:
struct: Program = Program.to(
(SINGLETON_TOP_LAYER_MOD.get_tree_hash(), (launcher_id, SINGLETON_LAUNCHER.get_tree_hash()))
)
return struct
def create_graftroot_offer_puz(
launcher_ids: List[bytes32], values_to_prove: List[List[bytes32]], inner_puzzle: Program
) -> Program:
return GRAFTROOT_DL_OFFERS.curry(
inner_puzzle,
[launcher_to_struct(launcher) for launcher in launcher_ids],
[NFT_STATE_LAYER_MOD.get_tree_hash()] * len(launcher_ids),
values_to_prove,
)
def create_mirror_puzzle() -> Program:
return P2_PARENT.curry(Program.to(1))
def get_mirror_info(parent_puzzle: Program, parent_solution: Program) -> Tuple[bytes32, List[bytes]]:
conditions = parent_puzzle.run(parent_solution)
for condition in conditions.as_iter():
if (
condition.first().as_python() == ConditionOpcode.CREATE_COIN
and condition.at("rf").as_python() == create_mirror_puzzle().get_tree_hash()
):
memos: List[bytes] = condition.at("rrrf").as_python()
launcher_id = bytes32(memos[0])
return launcher_id, [url for url in memos[1:]]
raise ValueError("The provided puzzle and solution do not create a mirror coin")

View File

@ -35,19 +35,6 @@ def create_nft_layer_puzzle_with_curry_params(
METADATA
METADATA_UPDATER_PUZZLE_HASH
INNER_PUZZLE"""
log.debug(
"Creating nft layer puzzle curry: mod_hash: %s, metadata: %r, metadata_hash: %s",
NFT_STATE_LAYER_MOD_HASH,
metadata,
metadata_updater_hash,
)
log.debug(
"Currying with: %s %s %s %s",
NFT_STATE_LAYER_MOD_HASH,
inner_puzzle.get_tree_hash(),
metadata_updater_hash,
metadata.get_tree_hash(),
)
return NFT_STATE_LAYER_MOD.curry(NFT_STATE_LAYER_MOD_HASH, metadata, metadata_updater_hash, inner_puzzle)

View File

@ -514,7 +514,7 @@ class NFTWallet:
return coin
return None
def get_puzzle_info(self, nft_id: bytes32) -> PuzzleInfo:
async def get_puzzle_info(self, nft_id: bytes32) -> PuzzleInfo:
nft_coin: Optional[NFTCoinInfo] = self.get_nft(nft_id)
if nft_coin is None:
raise ValueError("An asset ID was specified that this wallet doesn't track")

View File

@ -106,7 +106,7 @@ def decode_info_value(cls: Any, value: Any) -> Any:
else:
atom: bytes = expression.atom
typ = type_for_atom(atom)
if typ == Type.QUOTES:
if typ == Type.QUOTES and value[0:2] != "0x":
return bytes(atom).decode("utf8")
elif typ == Type.INT:
return int_from_bytes(atom)

View File

@ -0,0 +1,100 @@
(mod
(
INNER_PUZZLE
SINGLETON_STRUCTS
METADATA_LAYER_HASHES
VALUES_TO_PROVE ; this is a list of BRANCHES in the merkle tree to prove (so as to prove a whole subtree)
proofs_of_inclusion
new_metadatas ; (root . etc)
new_metadata_updaters
new_inner_puzs
inner_solution
)
(include condition_codes.clvm)
(include merkle_utils.clib)
(include curry-and-treehash.clinc)
(defmacro assert items
(if (r items)
(list if (f items) (c assert (r items)) (q . (x)))
(f items)
)
)
(defun-inline construct_singleton (SINGLETON_STRUCT METADATA_LAYER_HASH new_metadata new_metadata_updater new_inner_puz)
(puzzle-hash-of-curried-function (f SINGLETON_STRUCT)
(puzzle-hash-of-curried-function METADATA_LAYER_HASH
new_inner_puz
(sha256tree new_metadata_updater)
(sha256tree new_metadata)
(sha256tree METADATA_LAYER_HASH)
)
(sha256tree SINGLETON_STRUCT)
)
)
(defun verify_proofs (new_root VALUES_TO_PROVE proofs_of_inclusion)
(if proofs_of_inclusion
(assert (= new_root (simplify_merkle_proof_after_leaf (f VALUES_TO_PROVE) (f proofs_of_inclusion)))
; then
(verify_proofs new_root (r VALUES_TO_PROVE) (r proofs_of_inclusion))
)
1
)
)
(defun loop_over_curried_params
(
SINGLETON_STRUCTS
METADATA_LAYER_HASHES
VALUES_TO_PROVE
proofs_of_inclusion
new_metadatas
new_metadata_updaters
new_inner_puzs
conditions
)
(if SINGLETON_STRUCTS
(assert (verify_proofs (f (f new_metadatas)) (f VALUES_TO_PROVE) (f proofs_of_inclusion))
; then
(loop_over_curried_params
(r SINGLETON_STRUCTS)
(r METADATA_LAYER_HASHES)
(r VALUES_TO_PROVE)
(r proofs_of_inclusion)
(r new_metadatas)
(r new_metadata_updaters)
(r new_inner_puzs)
(c
(list
ASSERT_PUZZLE_ANNOUNCEMENT
(sha256
(construct_singleton (f SINGLETON_STRUCTS) (f METADATA_LAYER_HASHES) (f new_metadatas) (f new_metadata_updaters) (f new_inner_puzs))
'$'
)
)
conditions
)
)
)
conditions
)
)
(if proofs_of_inclusion
(loop_over_curried_params
SINGLETON_STRUCTS
METADATA_LAYER_HASHES
VALUES_TO_PROVE
proofs_of_inclusion
new_metadatas
new_metadata_updaters
new_inner_puzs
(a INNER_PUZZLE inner_solution)
)
; You may want to run the puzzle without a raise to examine conditions, so we'll make a "blessed" way to fail
(c (list ASSERT_MY_AMOUNT -1) (a INNER_PUZZLE inner_solution))
)
)

View File

@ -0,0 +1 @@
ff02ffff01ff02ffff03ff5fffff01ff02ff3affff04ff02ffff04ff0bffff04ff17ffff04ff2fffff04ff5fffff04ff81bfffff04ff82017fffff04ff8202ffffff04ffff02ff05ff8205ff80ff8080808080808080808080ffff01ff04ffff04ff10ffff01ff81ff8080ffff02ff05ff8205ff808080ff0180ffff04ffff01ffffff49ff3f02ff04ff0101ffff02ffff02ffff03ff05ffff01ff02ff2affff04ff02ffff04ff0dffff04ffff0bff12ffff0bff2cff1480ffff0bff12ffff0bff12ffff0bff2cff3c80ff0980ffff0bff12ff0bffff0bff2cff8080808080ff8080808080ffff010b80ff0180ff02ffff03ff05ffff01ff02ffff03ffff02ff3effff04ff02ffff04ff82011fffff04ff27ffff04ff4fff808080808080ffff01ff02ff3affff04ff02ffff04ff0dffff04ff1bffff04ff37ffff04ff6fffff04ff81dfffff04ff8201bfffff04ff82037fffff04ffff04ffff04ff28ffff04ffff0bffff02ff26ffff04ff02ffff04ff11ffff04ffff02ff26ffff04ff02ffff04ff13ffff04ff82027fffff04ffff02ff36ffff04ff02ffff04ff82013fff80808080ffff04ffff02ff36ffff04ff02ffff04ff819fff80808080ffff04ffff02ff36ffff04ff02ffff04ff13ff80808080ff8080808080808080ffff04ffff02ff36ffff04ff02ffff04ff09ff80808080ff808080808080ffff012480ff808080ff8202ff80ff8080808080808080808080ffff01ff088080ff0180ffff018202ff80ff0180ffffff0bff12ffff0bff2cff3880ffff0bff12ffff0bff12ffff0bff2cff3c80ff0580ffff0bff12ffff02ff2affff04ff02ffff04ff07ffff04ffff0bff2cff2c80ff8080808080ffff0bff2cff8080808080ff02ffff03ffff07ff0580ffff01ff0bffff0102ffff02ff36ffff04ff02ffff04ff09ff80808080ffff02ff36ffff04ff02ffff04ff0dff8080808080ffff01ff0bffff0101ff058080ff0180ffff02ffff03ff1bffff01ff02ff2effff04ff02ffff04ffff02ffff03ffff18ffff0101ff1380ffff01ff0bffff0102ff2bff0580ffff01ff0bffff0102ff05ff2b8080ff0180ffff04ffff04ffff17ff13ffff0181ff80ff3b80ff8080808080ffff010580ff0180ff02ffff03ff17ffff01ff02ffff03ffff09ff05ffff02ff2effff04ff02ffff04ff13ffff04ff27ff808080808080ffff01ff02ff3effff04ff02ffff04ff05ffff04ff1bffff04ff37ff808080808080ffff01ff088080ff0180ffff01ff010180ff0180ff018080

View File

@ -0,0 +1 @@
0893e36a88c064fddfa6f8abdb42c044584a98cb4273b80cccc83b4867b701a1

View File

@ -0,0 +1,18 @@
(
(defun simplify_merkle_proof_after_leaf (leaf_hash (bitpath . hashes_path))
(if hashes_path
(simplify_merkle_proof_after_leaf
(if (logand 1 bitpath)
(sha256 0x02 (f hashes_path) leaf_hash)
(sha256 0x02 leaf_hash (f hashes_path))
)
(c (lsh bitpath -1) (r hashes_path))
)
leaf_hash
)
)
(defun-inline simplify_merkle_proof (leaf proof)
(simplify_merkle_proof_after_leaf (sha256 0x01 leaf) proof)
)
)
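For orientation, a Python analogue of simplify_merkle_proof_after_leaf under the same hashing scheme (0x01-prefixed leaf atoms, 0x02-prefixed pairs). This is a reading aid for the clib above, not a vetted implementation:

from typing import List, Tuple

from chia.util.hash import std_hash

def simplify_merkle_proof(leaf: bytes, proof: Tuple[int, List[bytes]]) -> bytes:
    bitpath, hashes = proof
    root = std_hash(b"\x01" + leaf)  # hash the leaf atom first
    for sibling in hashes:
        if bitpath & 1:
            root = std_hash(b"\x02" + sibling + root)  # sibling on the left
        else:
            root = std_hash(b"\x02" + root + sibling)  # sibling on the right
        bitpath >>= 1
    return root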

View File

@ -0,0 +1,19 @@
(mod
(
MORPHER ; For no morphing, 1
parent_parent_id
parent_inner_puz
parent_amount
parent_solution
)
(include condition_codes.clvm)
(include curry-and-treehash.clinc)
(c
(list ASSERT_MY_PARENT_ID
(calculate_coin_id parent_parent_id (a MORPHER (sha256tree parent_inner_puz)) parent_amount)
)
(a parent_inner_puz parent_solution)
)
)

View File

@ -0,0 +1 @@
ff02ffff01ff04ffff04ff08ffff04ffff02ff0affff04ff02ffff04ff0bffff04ffff02ff05ffff02ff0effff04ff02ffff04ff17ff8080808080ffff04ff2fff808080808080ff808080ffff02ff17ff5f8080ffff04ffff01ffff4720ffff02ffff03ffff22ffff09ffff0dff0580ff0c80ffff09ffff0dff0b80ff0c80ffff15ff17ffff0181ff8080ffff01ff0bff05ff0bff1780ffff01ff088080ff0180ff02ffff03ffff07ff0580ffff01ff0bffff0102ffff02ff0effff04ff02ffff04ff09ff80808080ffff02ff0effff04ff02ffff04ff0dff8080808080ffff01ff0bffff0101ff058080ff0180ff018080

View File

@ -0,0 +1 @@
b10ce2d0b18dcf8c21ddfaf55d9b9f0adcbf1e0beb55b1a8b9cad9bbff4e5f22

View File

@ -4,9 +4,9 @@ import logging
import time
import traceback
from typing import Any, Dict, List, Optional, Set, Tuple, Union
from typing_extensions import Literal
from chia.data_layer.data_layer_wallet import DataLayerWallet
from chia.protocols.wallet_protocol import CoinState
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.coin import Coin, coin_as_list
@ -16,10 +16,11 @@ from chia.types.spend_bundle import SpendBundle
from chia.util.db_wrapper import DBWrapper2
from chia.util.hash import std_hash
from chia.util.ints import uint32, uint64
from chia.wallet.db_wallet.db_wallet_puzzles import ACS_MU_PH
from chia.wallet.nft_wallet.nft_wallet import NFTWallet
from chia.wallet.outer_puzzles import AssetType
from chia.wallet.payment import Payment
from chia.wallet.puzzle_drivers import PuzzleInfo
from chia.wallet.puzzle_drivers import PuzzleInfo, Solver
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import NotarizedPayment, Offer
from chia.wallet.trading.trade_status import TradeStatus
@ -145,20 +146,19 @@ class TradeManager:
# Then let's filter the offer into coins that WE offered
offer = Offer.from_bytes(trade.offer)
primary_coin_ids = [c.name() for c in offer.get_primary_coins()]
primary_coin_ids = [c.name() for c in offer.bundle.removals()]
our_coin_records: List[WalletCoinRecord] = await self.wallet_state_manager.coin_store.get_multiple_coin_records(
primary_coin_ids
)
our_primary_coins: List[bytes32] = [cr.coin.name() for cr in our_coin_records]
all_settlement_payments: List[Coin] = [c for coins in offer.get_offered_coins().values() for c in coins]
our_settlement_payments: List[Coin] = list(
filter(lambda c: offer.get_root_removal(c).name() in our_primary_coins, all_settlement_payments)
our_primary_coins: List[Coin] = [cr.coin for cr in our_coin_records]
our_additions: List[Coin] = list(
filter(lambda c: offer.get_root_removal(c) in our_primary_coins, offer.bundle.additions())
)
our_settlement_ids: List[bytes32] = [c.name() for c in our_settlement_payments]
our_addition_ids: List[bytes32] = [c.name() for c in our_additions]
# And get all relevant coin states
coin_states = await self.wallet_state_manager.wallet_node.get_coin_state(
our_settlement_ids,
our_addition_ids,
peer=peer,
fork_height=fork_height,
)
@ -166,8 +166,8 @@ class TradeManager:
coin_state_names: List[bytes32] = [cs.coin.name() for cs in coin_states]
# If all of our outgoing additions were created on chain, this offer was a success!
if set(our_settlement_ids) & set(coin_state_names):
height = coin_states[0].spent_height
if set(our_addition_ids) == set(coin_state_names):
height = coin_states[0].created_height
await self.trade_store.set_status(trade.trade_id, TradeStatus.CONFIRMED, height)
tx_records: List[TransactionRecord] = await self.calculate_tx_records_for_offer(offer, False)
for tx in tx_records:
@ -232,7 +232,7 @@ class TradeManager:
all_txs: List[TransactionRecord] = []
fee_to_pay: uint64 = fee
for coin in Offer.from_bytes(trade.offer).get_primary_coins():
for coin in Offer.from_bytes(trade.offer).get_cancellation_coins():
wallet = await self.wallet_state_manager.get_wallet_for_coin(coin.name())
if wallet is None:
@ -398,13 +398,16 @@ class TradeManager:
self,
offer: Dict[Union[int, bytes32], int],
driver_dict: Optional[Dict[bytes32, PuzzleInfo]] = None,
solver: Optional[Solver] = None,
fee: uint64 = uint64(0),
validate_only: bool = False,
min_coin_amount: Optional[uint64] = None,
) -> Union[Tuple[Literal[True], TradeRecord, None], Tuple[Literal[False], None, str]]:
if driver_dict is None:
driver_dict = {}
result = await self._create_offer_for_ids(offer, driver_dict, fee=fee, min_coin_amount=min_coin_amount)
if solver is None:
solver = Solver({})
result = await self._create_offer_for_ids(offer, driver_dict, solver, fee=fee, min_coin_amount=min_coin_amount)
if not result[0] or result[1] is None:
raise Exception(f"Error creating offer: {result[2]}")
@ -434,6 +437,7 @@ class TradeManager:
self,
offer_dict: Dict[Union[int, bytes32], int],
driver_dict: Optional[Dict[bytes32, PuzzleInfo]] = None,
solver: Optional[Solver] = None,
fee: uint64 = uint64(0),
min_coin_amount: Optional[uint64] = None,
) -> Union[Tuple[Literal[True], Offer, None], Tuple[Literal[False], None, str]]:
@ -442,6 +446,8 @@ class TradeManager:
"""
if driver_dict is None:
driver_dict = {}
if solver is None:
solver = Solver({})
try:
coins_to_offer: Dict[Union[int, bytes32], List[Coin]] = {}
requested_payments: Dict[Optional[bytes32], List[Payment]] = {}
@ -497,7 +503,7 @@ class TradeManager:
if asset_id is not None and wallet is not None: # if this asset is not XCH
if callable(getattr(wallet, "get_puzzle_info", None)):
puzzle_driver: PuzzleInfo = wallet.get_puzzle_info(asset_id)
puzzle_driver: PuzzleInfo = await wallet.get_puzzle_info(asset_id)
if asset_id in driver_dict and driver_dict[asset_id] != puzzle_driver:
# ignore the case if we're an nft transfering the did owner
if self.check_for_owner_change_in_drivers(puzzle_driver, driver_dict[asset_id]):
@ -512,7 +518,7 @@ class TradeManager:
raise ValueError(f"Wallet for asset id {asset_id} is not properly integrated with TradeManager")
potential_special_offer: Optional[Offer] = await self.check_for_special_offer_making(
offer_dict_no_ints, driver_dict, fee, min_coin_amount
offer_dict_no_ints, driver_dict, solver, fee, min_coin_amount
)
if potential_special_offer is not None:
@ -695,9 +701,12 @@ class TradeManager:
self,
offer: Offer,
peer: WSChiaConnection,
solver: Optional[Solver] = None,
fee: uint64 = uint64(0),
min_coin_amount: Optional[uint64] = None,
) -> Union[Tuple[Literal[True], TradeRecord, None], Tuple[Literal[False], None, str]]:
if solver is None:
solver = Solver({})
take_offer_dict: Dict[Union[bytes32, int], int] = {}
arbitrage: Dict[Optional[bytes32], int] = offer.arbitrage()
@ -710,7 +719,7 @@ class TradeManager:
wallet = await self.wallet_state_manager.get_wallet_for_asset_id(asset_id.hex())
if wallet is None and amount < 0:
return False, None, f"Do not have a wallet for asset ID: {asset_id} to fulfill offer"
elif wallet is None or wallet.type() == WalletType.NFT:
elif wallet is None or wallet.type() in [WalletType.NFT, WalletType.DATA_LAYER]:
key = asset_id
else:
key = int(wallet.id())
@ -721,14 +730,14 @@ class TradeManager:
if not valid:
return False, None, "This offer is no longer valid"
result = await self._create_offer_for_ids(
take_offer_dict, offer.driver_dict, fee=fee, min_coin_amount=min_coin_amount
take_offer_dict, offer.driver_dict, solver, fee=fee, min_coin_amount=min_coin_amount
)
if not result[0] or result[1] is None:
return False, None, result[2]
success, take_offer, error = result
complete_offer = Offer.aggregate([offer, take_offer])
complete_offer = await self.check_for_final_modifications(Offer.aggregate([offer, take_offer]), solver)
self.log.info(f"COMPLETE OFFER: {complete_offer.to_bech32()}")
assert complete_offer.is_valid()
final_spend_bundle: SpendBundle = complete_offer.to_valid_spend()
@ -781,6 +790,7 @@ class TradeManager:
self,
offer_dict: Dict[Optional[bytes32], int],
driver_dict: Dict[bytes32, PuzzleInfo],
solver: Solver,
fee: uint64 = uint64(0),
min_coin_amount: Optional[uint64] = None,
) -> Optional[Offer]:
@ -794,6 +804,18 @@ class TradeManager:
return await NFTWallet.make_nft1_offer(
self.wallet_state_manager, offer_dict, driver_dict, fee, min_coin_amount
)
elif (
puzzle_info.check_type(
[
AssetType.SINGLETON.value,
AssetType.METADATA.value,
]
)
and puzzle_info.also()["updater_hash"] == ACS_MU_PH # type: ignore
):
return await DataLayerWallet.make_update_offer(
self.wallet_state_manager, offer_dict, driver_dict, solver, fee
)
return None
def check_for_owner_change_in_drivers(self, puzzle_info: PuzzleInfo, driver_info: PuzzleInfo) -> bool:
@ -815,3 +837,34 @@ class TradeManager:
if driver_info == puzzle_info:
return True
return False
async def get_offer_summary(self, offer: Offer) -> Dict[str, Any]:
for puzzle_info in offer.driver_dict.values():
if (
puzzle_info.check_type(
[
AssetType.SINGLETON.value,
AssetType.METADATA.value,
]
)
and puzzle_info.also()["updater_hash"] == ACS_MU_PH # type: ignore
):
return await DataLayerWallet.get_offer_summary(offer)
# Otherwise just return the same thing as the RPC normally does
offered, requested, infos = offer.summary()
return {"offered": offered, "requested": requested, "fees": offer.bundle.fees(), "infos": infos}
async def check_for_final_modifications(self, offer: Offer, solver: Solver) -> Offer:
for puzzle_info in offer.driver_dict.values():
if (
puzzle_info.check_type(
[
AssetType.SINGLETON.value,
AssetType.METADATA.value,
]
)
and puzzle_info.also()["updater_hash"] == ACS_MU_PH # type: ignore
):
return await DataLayerWallet.finish_graftroot_solutions(offer, solver)
return offer

View File

@ -7,7 +7,7 @@ from clvm_tools.binutils import disassemble
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.coin import Coin, coin_as_list
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.program import Program, INFINITE_COST
from chia.types.announcement import Announcement
from chia.types.coin_spend import CoinSpend
from chia.types.spend_bundle import SpendBundle
@ -34,6 +34,19 @@ OFFER_MOD = load_clvm("settlement_payments.clvm")
ZERO_32 = bytes32([0] * 32)
def detect_dependent_coin(
names: List[bytes32], deps: Dict[bytes32, List[bytes32]], announcement_dict: Dict[bytes32, List[bytes32]]
) -> Optional[Tuple[bytes32, bytes32]]:
# First, we check for any dependencies on coins in the same bundle
for name in names:
for dependency in deps[name]:
for coin, announces in announcement_dict.items():
if dependency in announces and coin != name:
# We found one, now remove it and anything that depends on it (except the "provider")
return name, coin
return None
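A tiny worked illustration of the helper above, with hypothetical names: if coin B asserts an announcement that only coin A creates, then B depends on A and cancelling A suffices.

# announcement_dict says A makes `ann`; deps says B asserts `ann`:
#   detect_dependent_coin([A, B], {A: [], B: [ann]}, {A: [ann], B: []})
# returns (B, A): B is the dependent coin and A is its provider.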
@dataclass(frozen=True)
class NotarizedPayment(Payment):
nonce: bytes32 = ZERO_32
@ -289,6 +302,51 @@ class Offer:
primary_coins.add(self.get_root_removal(coin))
return list(primary_coins)
# This returns the minimal set of coins that, when spent, will invalidate the rest of the bundle
def get_cancellation_coins(self) -> List[Coin]:
# First, we're going to gather:
dependencies: Dict[bytes32, List[bytes32]] = {} # all of the hashes that each coin depends on
announcements: Dict[bytes32, List[bytes32]] = {}  # all of the hashes of the announcements that each coin makes
coin_names: List[bytes32] = [] # The names of all the coins
for spend in [cs for cs in self.bundle.coin_spends if cs.coin not in self.bundle.additions()]:
name = bytes32(spend.coin.name())
coin_names.append(name)
dependencies[name] = []
announcements[name] = []
conditions: Program = spend.puzzle_reveal.run_with_cost(INFINITE_COST, spend.solution)[1]
for condition in conditions.as_iter():
if condition.first() == 60: # create coin announcement
announcements[name].append(Announcement(name, condition.at("rf").as_python()).name())
elif condition.first() == 61: # assert coin announcement
dependencies[name].append(bytes32(condition.at("rf").as_python()))
# We now enter a loop that is attempting to express the following logic:
# "If I am depending on another coin in the same bundle, you may as well cancel that coin instead of me"
# By the end of the loop, we should have filtered down the list of coin_names to include only those that will
# cancel everything else
while True:
removed = detect_dependent_coin(coin_names, dependencies, announcements)
if removed is None:
break
removed_coin, provider = removed
removed_announcements: List[bytes32] = announcements[removed_coin]
remove_these_keys: List[bytes32] = [removed_coin]
while True:
for coin, deps in dependencies.items():
if set(deps) & set(removed_announcements) and coin != provider:
remove_these_keys.append(coin)
removed_announcements = []
for coin in remove_these_keys:
dependencies.pop(coin)
removed_announcements.extend(announcements.pop(coin))
coin_names = [n for n in coin_names if n not in remove_these_keys]
if removed_announcements == []:
break
else:
remove_these_keys = []
return [cs.coin for cs in self.bundle.coin_spends if cs.coin.name() in coin_names]
@classmethod
def aggregate(cls, offers: List[Offer]) -> Offer:
total_requested_payments: Dict[Optional[bytes32], List[NotarizedPayment]] = {}

View File

@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Dict
from typing import Generic, List, Optional, Tuple, TypeVar, Dict
from chia.consensus.coinbase import pool_parent_id, farmer_parent_id
from chia.types.blockchain_format.coin import Coin
@ -12,6 +12,15 @@ from chia.util.streamable import Streamable, streamable
from chia.wallet.util.transaction_type import TransactionType
T = TypeVar("T")
@dataclass
class ItemAndTransactionRecords(Generic[T]):
item: T
transaction_records: List["TransactionRecord"]
@streamable
@dataclass(frozen=True)
class TransactionRecord(Streamable):

View File

@ -0,0 +1,98 @@
import math
from enum import Enum
from typing import List, Optional, Tuple
from clvm.casts import int_to_bytes
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.hash import std_hash
ONE = int_to_bytes(1)
TWO = int_to_bytes(2)
def hash_a_pair(left: bytes32, right: bytes32) -> bytes32:
return std_hash(TWO + left + right)
def hash_an_atom(atom: bytes32) -> bytes32:
return std_hash(ONE + atom)
class TreeType(Enum):
TREE = 1
WATERFALL = 2
class MerkleTree:
type: TreeType
nodes: List[bytes32]
def __init__(self, nodes: List[bytes32], waterfall: bool = False) -> None:
self.type = TreeType.WATERFALL if waterfall else TreeType.TREE
self.nodes = nodes
def split_list(self, puzzle_hashes: List[bytes32]) -> Tuple[List[bytes32], List[bytes32]]:
if self.type == TreeType.TREE:
mid_index = math.ceil(len(puzzle_hashes) / 2)
first = puzzle_hashes[0:mid_index]
rest = puzzle_hashes[mid_index : len(puzzle_hashes)]
else:
first = puzzle_hashes[0:-1]
rest = puzzle_hashes[-1 : len(puzzle_hashes)]
return first, rest
def _root(self, puzzle_hashes: List[bytes32]) -> bytes32:
if len(puzzle_hashes) == 1:
return hash_an_atom(puzzle_hashes[0])
else:
first, rest = self.split_list(puzzle_hashes)
return hash_a_pair(self._root(first), self._root(rest))
def calculate_root(self) -> bytes32:
return self._root(self.nodes)
def _proof(
self, puzzle_hashes: List[bytes32], searching_for: bytes32
) -> Tuple[Optional[int], Optional[List[bytes32]], bytes32, Optional[int]]:
if len(puzzle_hashes) == 1:
atom_hash = hash_an_atom(puzzle_hashes[0])
if puzzle_hashes[0] == searching_for:
return (0, [], atom_hash, 0)
else:
return (None, [], atom_hash, None)
else:
first, rest = self.split_list(puzzle_hashes)
first_hash = self._proof(first, searching_for)
rest_hash = self._proof(rest, searching_for)
final_path = None
final_list = None
bit_num = None
if first_hash[0] is not None:
final_list = first_hash[1]
# TODO: handle hints
# error: Item "None" of "Optional[List[bytes32]]" has no attribute "append" [union-attr]
final_list.append(rest_hash[2]) # type: ignore[union-attr]
bit_num = first_hash[3]
final_path = first_hash[0]
elif rest_hash[0] is not None:
final_list = rest_hash[1]
# TODO: handle hints
# error: Item "None" of "Optional[List[bytes32]]" has no attribute "append" [union-attr]
final_list.append(first_hash[2]) # type: ignore[union-attr]
bit_num = rest_hash[3]
# TODO: handle hints
# error: Unsupported operand types for << ("int" and "None") [operator]
# note: Right operand is of type "Optional[int]"
final_path = rest_hash[0] | (1 << bit_num) # type: ignore[operator]
pair_hash = hash_a_pair(first_hash[2], rest_hash[2])
return (final_path, final_list, pair_hash, bit_num + 1 if bit_num is not None else None)
def generate_proof(self, leaf_reveal: bytes32) -> Tuple[Optional[int], List[Optional[List[bytes32]]]]:
proof = self._proof(self.nodes, leaf_reveal)
return (proof[0], [proof[1]])
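A short usage sketch of the class above (the import path chia.wallet.util.merkle_tree is an assumption about where this commit places the module):
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.wallet.util.merkle_tree import MerkleTree  # assumed module path
leaves = [bytes32([i] * 32) for i in range(4)]
tree = MerkleTree(leaves)
root = tree.calculate_root()
# Leaf 0 sits on the left at every split, so its path is 0; the proof list
# carries the sibling hashes from the bottom of the tree upward.
path, proof_lists = tree.generate_proof(leaves[0])
assert path == 0
# A waterfall tree peels one leaf off the end at each level instead of
# halving the list, so the same leaves produce a different root.
assert MerkleTree(leaves, waterfall=True).calculate_root() != root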

View File

@ -0,0 +1,98 @@
import hashlib
from typing import Any, Dict, List, Tuple
from chia.types.blockchain_format.sized_bytes import bytes32
TupleTree = Any # Union[bytes32, Tuple["TupleTree", "TupleTree"]]
Proof_Tree_Type = Any # Union[bytes32, Tuple[bytes32, "Proof_Tree_Type"]]
HASH_TREE_PREFIX = bytes([2])
HASH_LEAF_PREFIX = bytes([1])
# paths here are not quite the same as `NodePath` paths. We don't need the high order bit
# anymore since the proof indicates how big the path is.
def compose_paths(path_1: int, path_2: int, path_2_length: int) -> int:
return (path_1 << path_2_length) | path_2
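For example, composing an existing path 0b1 with the two-bit path 0b01 shifts the old bits up to make room for the new low-order bits:
assert compose_paths(0b1, 0b01, path_2_length=2) == 0b101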
def sha256(*args: bytes) -> bytes32:
return bytes32(hashlib.sha256(b"".join(args)).digest())
def build_merkle_tree_from_binary_tree(tuples: TupleTree) -> Tuple[bytes32, Dict[bytes32, Tuple[int, List[bytes32]]]]:
if isinstance(tuples, bytes):
tuples = bytes32(tuples)
return sha256(HASH_LEAF_PREFIX, tuples), {tuples: (0, [])}
left, right = tuples
left_root, left_proofs = build_merkle_tree_from_binary_tree(left)
right_root, right_proofs = build_merkle_tree_from_binary_tree(right)
new_root = sha256(HASH_TREE_PREFIX, left_root, right_root)
new_proofs = {}
for name, (path, proof) in left_proofs.items():
proof.append(right_root)
new_proofs[name] = (path, proof)
for name, (path, proof) in right_proofs.items():
path |= 1 << len(proof)
proof.append(left_root)
new_proofs[name] = (path, proof)
return new_root, new_proofs
def list_to_binary_tree(objects: List[Any]) -> Any:
size = len(objects)
if size == 1:
return objects[0]
midpoint = (size + 1) >> 1
first_half = objects[:midpoint]
last_half = objects[midpoint:]
return (list_to_binary_tree(first_half), list_to_binary_tree(last_half))
def build_merkle_tree(objects: List[bytes32]) -> Tuple[bytes32, Dict[bytes32, Tuple[int, List[bytes32]]]]:
"""
return (merkle_root, dict_of_proofs)
"""
objects_binary_tree = list_to_binary_tree(objects)
return build_merkle_tree_from_binary_tree(objects_binary_tree)
def merkle_proof_from_path_and_tree(node_path: int, proof_tree: Proof_Tree_Type) -> Tuple[int, List[bytes32]]:
proof_path = 0
proof = []
while not isinstance(proof_tree, bytes32):
left_vs_right = node_path & 1
path_element = proof_tree[1][1 - left_vs_right]
if isinstance(path_element, bytes32):
proof.append(path_element)
else:
proof.append(path_element[0])
node_path >>= 1
proof_tree = proof_tree[1][left_vs_right]
proof_path += proof_path + left_vs_right
proof.reverse()
return proof_path, proof
def _simplify_merkle_proof(tree_hash: bytes32, proof: Tuple[int, List[bytes32]]) -> bytes32:
# we return the expected merkle root
path, nodes = proof
for node in nodes:
if path & 1:
tree_hash = sha256(HASH_TREE_PREFIX, node, tree_hash)
else:
tree_hash = sha256(HASH_TREE_PREFIX, tree_hash, node)
path >>= 1
return tree_hash
def simplify_merkle_proof(tree_hash: bytes32, proof: Tuple[int, List[bytes32]]) -> bytes32:
return _simplify_merkle_proof(sha256(HASH_LEAF_PREFIX, tree_hash), proof)
def check_merkle_proof(merkle_root: bytes32, tree_hash: bytes32, proof: Tuple[int, List[bytes32]]) -> bool:
return merkle_root == simplify_merkle_proof(tree_hash, proof)
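A round-trip sketch of the helpers above; the module path chia.wallet.util.merkle_utils is taken from the test imports later in this commit:
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.wallet.util.merkle_utils import build_merkle_tree, check_merkle_proof
leaves = [bytes32([i] * 32) for i in range(8)]
root, proofs = build_merkle_tree(leaves)
for leaf in leaves:
    # Each proof is (path_bits, sibling_hashes); verification re-hashes up the
    # path, consuming the low-order path bit at each layer.
    assert check_merkle_proof(root, leaf, proofs[leaf])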

View File

@ -20,6 +20,8 @@ class WalletType(IntEnum):
DECENTRALIZED_ID = 8
POOLING_WALLET = 9
NFT = 10
DATA_LAYER = 11
DATA_LAYER_OFFER = 12
class AmountWithPuzzlehash(TypedDict):

View File

@ -183,7 +183,9 @@ class Wallet:
async def get_new_puzzle(self) -> Program:
dr = await self.wallet_state_manager.get_unused_derivation_record(self.id())
return puzzle_for_pk(dr.pubkey)
puzzle = puzzle_for_pk(dr.pubkey)
await self.hack_populate_secret_key_for_puzzle_hash(puzzle.get_tree_hash())
return puzzle
async def get_puzzle_hash(self, new: bool) -> bytes32:
if new:
@ -197,7 +199,9 @@ class Wallet:
return record.puzzle_hash
async def get_new_puzzlehash(self) -> bytes32:
return (await self.wallet_state_manager.get_unused_derivation_record(self.id())).puzzle_hash
puzhash = (await self.wallet_state_manager.get_unused_derivation_record(self.id())).puzzle_hash
await self.hack_populate_secret_key_for_puzzle_hash(puzhash)
return puzhash
def make_solution(
self,
@ -399,7 +403,7 @@ class Wallet:
continue
puzzle = await self.puzzle_for_puzzle_hash(coin.puzzle_hash)
solution = self.make_solution(coin_announcements_to_assert={primary_announcement_hash}, primaries=[])
solution = self.make_solution(primaries=[], coin_announcements_to_assert={primary_announcement_hash})
spends.append(
CoinSpend(
coin, SerializedProgram.from_bytes(bytes(puzzle)), SerializedProgram.from_bytes(bytes(solution))
@ -456,7 +460,7 @@ class Wallet:
puzzle_announcements_to_consume,
memos,
negative_change_allowed,
min_coin_amount,
min_coin_amount=min_coin_amount,
exclude_coins=exclude_coins,
)
assert len(transaction) > 0

View File

@ -801,12 +801,14 @@ class WalletNode:
return still_connected and self._server is not None and peer.peer_node_id in self.server.all_connections
async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]:
# TODO Use trusted peer, otherwise try untrusted
all_nodes = self.server.connection_by_type[NodeType.FULL_NODE]
if len(all_nodes.keys()) == 0:
raise ValueError("Not connected to the full node")
first_node = list(all_nodes.values())[0]
msg = wallet_protocol.RegisterForPhUpdates(puzzle_hash, uint32(0))
coin_state: Optional[RespondToPhUpdates] = await first_node.register_interest_in_puzzle_hash(msg)
# TODO validate state if received from untrusted peer
assert coin_state is not None
return coin_state.coin_states

View File

@ -38,11 +38,12 @@ class WalletPuzzleStore:
"CREATE TABLE IF NOT EXISTS derivation_paths("
"derivation_index int,"
" pubkey text,"
" puzzle_hash text PRIMARY KEY,"
" puzzle_hash text,"
" wallet_type int,"
" wallet_id int,"
" used tinyint,"
" hardened tinyint)"
" hardened tinyint,"
" PRIMARY KEY(puzzle_hash, wallet_id))"
)
)
await conn.execute(
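A note on the primary key change above: with puzzle_hash alone as the key, a given puzzle hash could be recorded for only one wallet; the composite (puzzle_hash, wallet_id) key permits the same puzzle hash to appear under several wallets, which the shared derivations of the new data layer wallet presumably require.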

View File

@ -16,6 +16,8 @@ from blspy import G1Element, PrivateKey
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.consensus.coinbase import farmer_parent_id, pool_parent_id
from chia.consensus.constants import ConsensusConstants
from chia.data_layer.data_layer_wallet import DataLayerWallet
from chia.data_layer.dl_wallet_store import DataLayerStore
from chia.pools.pool_puzzles import SINGLETON_LAUNCHER_HASH, solution_to_pool_state
from chia.pools.pool_wallet import PoolWallet
from chia.protocols import wallet_protocol
@ -123,6 +125,7 @@ class WalletStateManager:
root_path: Path
wallet_node: Any
pool_store: WalletPoolStore
dl_store: DataLayerStore
default_cats: Dict[str, Any]
asset_to_wallet_map: Dict[AssetType, Any]
initial_num_public_keys: int
@ -190,6 +193,7 @@ class WalletStateManager:
self.trade_manager = await TradeManager.create(self, self.db_wrapper)
self.user_settings = await UserSettings.create(self.basic_store)
self.pool_store = await WalletPoolStore.create(self.db_wrapper)
self.dl_store = await DataLayerStore.create(self.db_wrapper)
self.interested_store = await WalletInterestedStore.create(self.db_wrapper)
self.default_cats = DEFAULT_CATS
@ -245,6 +249,12 @@ class WalletStateManager:
self.main_wallet,
wallet_info,
)
elif wallet_info.type == WalletType.DATA_LAYER:
wallet = await DataLayerWallet.create(
self,
self.main_wallet,
wallet_info,
)
if wallet is not None:
self.wallets[wallet_info.id] = wallet
@ -389,6 +399,7 @@ class WalletStateManager:
if unused is None:
# This handles the case where the database is empty
unused = uint32(0)
if last is not None:
for index in range(unused, last):
# Since DID are not released yet we can assume they are only using unhardened keys derivation
pubkey: G1Element = self.get_public_key_unhardened(uint32(index))
@ -860,7 +871,10 @@ class WalletStateManager:
return wallet_id, wallet_type
async def new_coin_state(
self, coin_states: List[CoinState], peer: WSChiaConnection, fork_height: Optional[uint32]
self,
coin_states: List[CoinState],
peer: WSChiaConnection,
fork_height: Optional[uint32],
) -> None:
# TODO: add comment about what this method does
# Input states should already be sorted by cs_height, with reorgs at the beginning
@ -910,6 +924,11 @@ class WalletStateManager:
wallet_type = local_record.wallet_type
elif coin_state.created_height is not None:
wallet_id, wallet_type = await self.determine_coin_type(peer, coin_state, fork_height)
potential_dl = self.get_dl_wallet()
if potential_dl is not None:
if await potential_dl.get_singleton_record(coin_state.coin.name()) is not None:
wallet_id = potential_dl.id()
wallet_type = WalletType(potential_dl.type())
if wallet_id is None or wallet_type is None:
self.log.debug(f"No wallet for coin state: {coin_state}")
@ -1110,6 +1129,15 @@ class WalletStateManager:
)
assert len(new_coin_state) == 1
curr_coin_state = new_coin_state[0]
if record.wallet_type == WalletType.DATA_LAYER:
singleton_spend = await self.wallet_node.fetch_puzzle_solution(
coin_state.spent_height, coin_state.coin, peer
)
dl_wallet = self.wallets[uint32(record.wallet_id)]
await dl_wallet.singleton_removed(
singleton_spend,
coin_state.spent_height,
)
elif record.wallet_type == WalletType.NFT:
if coin_state.spent_height is not None:
@ -1135,8 +1163,30 @@ class WalletStateManager:
continue
try:
pool_state = solution_to_pool_state(launcher_spend)
except Exception as e:
assert pool_state is not None
except (AssertionError, ValueError) as e:
self.log.debug(f"Not a pool wallet launcher {e}")
matched, inner_puzhash = await DataLayerWallet.match_dl_launcher(launcher_spend)
if (
matched
and inner_puzhash is not None
and (await self.puzzle_store.puzzle_hash_exists(inner_puzhash))
):
for _, wallet in self.wallets.items():
if wallet.type() == WalletType.DATA_LAYER.value:
dl_wallet = wallet
break
else: # No DL wallet exists yet
dl_wallet = await DataLayerWallet.create_new_dl_wallet(
self,
self.main_wallet,
)
await dl_wallet.track_new_launcher_id(
child.coin.name(),
peer,
spend=launcher_spend,
height=child.spent_height,
)
continue
# solution_to_pool_state may return None but this may not be an error
@ -1330,7 +1380,7 @@ class WalletStateManager:
)
await self.coin_store.add_coin_record(coin_record_1, coin_name)
if wallet_type in [WalletType.CAT, WalletType.DECENTRALIZED_ID, WalletType.NFT]:
if wallet_type in (WalletType.CAT, WalletType.DECENTRALIZED_ID, WalletType.NFT, WalletType.DATA_LAYER):
await self.wallets[wallet_id].coin_added(coin, height, peer)
await self.create_more_puzzle_hashes()
@ -1455,6 +1505,9 @@ class WalletStateManager:
if wallet.type() == WalletType.CAT:
if bytes(wallet.cat_info.limitations_program_hash).hex() == asset_id:
return wallet
elif wallet.type() == WalletType.DATA_LAYER:
if await wallet.get_latest_singleton(bytes32.from_hexstr(asset_id)) is not None:
return wallet
elif wallet.type() == WalletType.NFT:
for nft_coin in wallet.my_nft_coins:
if nft_coin.nft_id.hex() == asset_id:
@ -1545,3 +1598,9 @@ class WalletStateManager:
return await wallet.convert_puzzle_hash(puzzle_hash)
return puzzle_hash
def get_dl_wallet(self):
for _, wallet in self.wallets.items():
if wallet.type() == WalletType.DATA_LAYER.value:
return wallet
return None

View File

@ -6,7 +6,9 @@ log_level = WARNING
console_output_style = count
log_format = %(asctime)s %(name)s: %(levelname)s %(message)s
asyncio_mode = strict
markers=benchmark
markers =
benchmark
data_layer: Mark as a data layer related test.
testpaths = tests
filterwarnings =
error
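With the marker registered here, the data layer suite can be run on its own with pytest -m data_layer, or skipped with pytest -m "not data_layer".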

View File

@ -59,7 +59,7 @@ dev_dependencies = [
"ipython", # For asyncio debugging
"pyinstaller==5.3",
"types-aiofiles",
"types-click",
"types-click~=7.1",
"types-cryptography",
"types-pkg_resources",
"types-pyyaml",
@ -88,6 +88,7 @@ kwargs = dict(
"chia.clvm",
"chia.consensus",
"chia.daemon",
"chia.data_layer",
"chia.full_node",
"chia.timelord",
"chia.farmer",
@ -106,6 +107,7 @@ kwargs = dict(
"chia.types",
"chia.util",
"chia.wallet",
"chia.wallet.db_wallet",
"chia.wallet.puzzles",
"chia.wallet.rl_wallet",
"chia.wallet.cat_wallet",
@ -131,6 +133,7 @@ kwargs = dict(
"chia_timelord = chia.server.start_timelord:main",
"chia_timelord_launcher = chia.timelord.timelord_launcher:main",
"chia_full_node_simulator = chia.simulator.start_simulator:main",
"chia_data_layer = chia.server.start_data_layer:main",
]
},
package_data={

View File

@ -3,10 +3,14 @@ from pathlib import Path
from tempfile import NamedTemporaryFile
from unittest import TestCase
import pytest
from clvm_tools.clvmc import compile_clvm
from chia.types.blockchain_format.program import Program, SerializedProgram
pytestmark = pytest.mark.data_layer
wallet_program_files = set(
[
"chia/wallet/puzzles/calculate_synthetic_public_key.clvm",
@ -48,6 +52,8 @@ wallet_program_files = set(
"chia/wallet/puzzles/nft_state_layer.clvm",
"chia/wallet/puzzles/nft_ownership_layer.clvm",
"chia/wallet/puzzles/nft_ownership_transfer_program_one_way_claim_with_royalties.clvm",
"chia/wallet/puzzles/graftroot_dl_offers.clvm",
"chia/wallet/puzzles/p2_parent.clvm",
"chia/wallet/puzzles/decompress_block_spends.clvm",
]
)

View File

@ -605,8 +605,8 @@ async def wallets_prefarm(two_wallet_nodes, self_hostname, trusted):
for i in range(0, buffer):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32(token_bytes(nbytes=32))))
await time_out_assert(10, wallet_is_synced, True, wallet_node_0, full_node_api)
await time_out_assert(10, wallet_is_synced, True, wallet_node_1, full_node_api)
await time_out_assert(30, wallet_is_synced, True, wallet_node_0, full_node_api)
await time_out_assert(30, wallet_is_synced, True, wallet_node_1, full_node_api)
return wallet_node_0, wallet_node_1, full_node_api

View File

View File

@ -0,0 +1,2 @@
job_timeout = 55
checkout_blocks_and_plots = True

View File

@ -0,0 +1,148 @@
import contextlib
import os
import pathlib
import subprocess
import sys
import sysconfig
import time
from typing import Any, AsyncIterable, Awaitable, Callable, Dict, Iterator, List
import aiosqlite
import pytest
import pytest_asyncio
# https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia.data_layer.data_layer_util import NodeType, Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.tree_hash import bytes32
from chia.util.db_wrapper import DBWrapper
from tests.core.data_layer.util import (
ChiaRoot,
Example,
add_0123_example,
add_01234567_example,
create_valid_node_values,
)
# TODO: These are more general than the data layer and should either move elsewhere or
# be replaced with an existing common approach. For now they can at least be
# shared among the data layer test files.
@pytest.fixture(name="scripts_path", scope="session")
def scripts_path_fixture() -> pathlib.Path:
scripts_string = sysconfig.get_path("scripts")
if scripts_string is None:
raise Exception("These tests depend on the scripts path existing")
return pathlib.Path(scripts_string)
@pytest.fixture(name="chia_root", scope="function")
def chia_root_fixture(tmp_path: pathlib.Path, scripts_path: pathlib.Path) -> ChiaRoot:
root = ChiaRoot(path=tmp_path.joinpath("chia_root"), scripts_path=scripts_path)
root.run(args=["init"])
root.run(args=["configure", "--set-log-level", "INFO"])
return root
@contextlib.contextmanager
def closing_chia_root_popen(chia_root: ChiaRoot, args: List[str]) -> Iterator[None]:
environment = {**os.environ, "CHIA_ROOT": os.fspath(chia_root.path)}
with subprocess.Popen(args=args, env=environment) as process:
try:
yield
finally:
process.terminate()
try:
process.wait(timeout=10)
except subprocess.TimeoutExpired:
process.kill()
@pytest.fixture(name="chia_daemon", scope="function")
def chia_daemon_fixture(chia_root: ChiaRoot) -> Iterator[None]:
with closing_chia_root_popen(chia_root=chia_root, args=[sys.executable, "-m", "chia.daemon.server"]):
# TODO: this is not pretty as a hard coded time
# let it settle
time.sleep(5)
yield
@pytest.fixture(name="chia_data", scope="function")
def chia_data_fixture(chia_root: ChiaRoot, chia_daemon: None, scripts_path: pathlib.Path) -> Iterator[None]:
with closing_chia_root_popen(chia_root=chia_root, args=[os.fspath(scripts_path.joinpath("chia_data_layer"))]):
# TODO: this is not pretty as a hard coded time
# let it settle
time.sleep(5)
yield
@pytest.fixture(name="create_example", params=[add_0123_example, add_01234567_example])
def create_example_fixture(request: SubRequest) -> Callable[[DataStore, bytes32], Awaitable[Example]]:
# https://github.com/pytest-dev/pytest/issues/8763
return request.param # type: ignore[no-any-return]
@pytest_asyncio.fixture(name="db_connection", scope="function")
async def db_connection_fixture() -> AsyncIterable[aiosqlite.Connection]:
async with aiosqlite.connect(":memory:") as connection:
# make sure this is on for tests even if we disable it at run time
await connection.execute("PRAGMA foreign_keys = ON")
yield connection
@pytest.fixture(name="db_wrapper", scope="function")
def db_wrapper_fixture(db_connection: aiosqlite.Connection) -> DBWrapper:
return DBWrapper(db_connection)
@pytest.fixture(name="tree_id", scope="function")
def tree_id_fixture() -> bytes32:
base = b"a tree id"
pad = b"." * (32 - len(base))
return bytes32(pad + base)
@pytest_asyncio.fixture(name="raw_data_store", scope="function")
async def raw_data_store_fixture(db_wrapper: DBWrapper) -> DataStore:
return await DataStore.create(db_wrapper=db_wrapper)
@pytest_asyncio.fixture(name="data_store", scope="function")
async def data_store_fixture(raw_data_store: DataStore, tree_id: bytes32) -> AsyncIterable[DataStore]:
await raw_data_store.create_tree(tree_id=tree_id, status=Status.COMMITTED)
await raw_data_store.check()
yield raw_data_store
await raw_data_store.check()
@pytest.fixture(name="node_type", params=NodeType)
def node_type_fixture(request: SubRequest) -> NodeType:
return request.param # type: ignore[no-any-return]
@pytest_asyncio.fixture(name="valid_node_values")
async def valid_node_values_fixture(
data_store: DataStore,
tree_id: bytes32,
node_type: NodeType,
) -> Dict[str, Any]:
await add_01234567_example(data_store=data_store, tree_id=tree_id)
node_a = await data_store.get_node_by_key(key=b"\x02", tree_id=tree_id)
node_b = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
return create_valid_node_values(node_type=node_type, left_hash=node_a.hash, right_hash=node_b.hash)
@pytest.fixture(name="bad_node_type", params=range(2 * len(NodeType)))
def bad_node_type_fixture(request: SubRequest, valid_node_values: Dict[str, Any]) -> int:
if request.param == valid_node_values["node_type"]:
pytest.skip("Actually, this is a valid node type")
return request.param # type: ignore[no-any-return]

View File

@ -0,0 +1,55 @@
import json
from typing import Dict, List
import pytest
from tests.core.data_layer.util import ChiaRoot
pytestmark = pytest.mark.data_layer
@pytest.mark.asyncio
async def test_help(chia_root: ChiaRoot) -> None:
"""Just a trivial test to make sure the subprocessing is at least working and the
data executable does run.
"""
completed_process = chia_root.run(args=["data", "--help"])
assert "Show this message and exit" in completed_process.stdout
@pytest.mark.xfail(strict=True)
@pytest.mark.asyncio
def test_round_trip(chia_root: ChiaRoot, chia_daemon: None, chia_data: None) -> None:
"""Create a table, insert a row, get the row by its hash."""
with chia_root.print_log_after():
create = chia_root.run(args=["data", "create_data_store"])
print(f"create_data_store: {create}")
dic = json.loads(create.stdout)
assert dic["success"]
tree_id = dic["id"]
key = "1a6f915513173902a7216e7d9e4a16bfd088e20683f45de3b432ce72e9cc7aa8"
value = "ffff8353594d8083616263"
changelist: List[Dict[str, str]] = [{"action": "insert", "key": key, "value": value}]
print(json.dumps(changelist))
update = chia_root.run(
args=["data", "update_data_store", "--id", tree_id, "--changelist", json.dumps(changelist)]
)
dic = json.loads(create.stdout)
assert dic["success"]
print(f"update_data_store: {update}")
completed_process = chia_root.run(args=["data", "get_value", "--id", tree_id, "--key", key])
parsed = json.loads(completed_process.stdout)
expected = {"value": value, "success": True}
assert parsed == expected
get_keys_values = chia_root.run(args=["data", "get_keys_values", "--id", tree_id])
print(f"get_keys_values: {get_keys_values}")
changelist = [{"action": "delete", "key": key}]
update = chia_root.run(
args=["data", "update_data_store", "--id", tree_id, "--changelist", json.dumps(changelist)]
)
print(f"update_data_store: {update}")
completed_process = chia_root.run(args=["data", "get_value", "--id", tree_id, "--key", key])
parsed = json.loads(completed_process.stdout)
expected = {"data": None, "success": True}
assert parsed == expected

View File

@ -0,0 +1,77 @@
import dataclasses
from typing import List
import pytest
# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia.data_layer.data_layer_util import ProofOfInclusion, ProofOfInclusionLayer, Side
from chia.types.blockchain_format.sized_bytes import bytes32
pytestmark = pytest.mark.data_layer
def create_valid_proof_of_inclusion(layer_count: int, other_hash_side: Side) -> ProofOfInclusion:
node_hash = bytes32(b"a" * 32)
layers: List[ProofOfInclusionLayer] = []
existing_hash = node_hash
other_hashes = [bytes32([i] * 32) for i in range(layer_count)]
for other_hash in other_hashes:
new_layer = ProofOfInclusionLayer.from_hashes(
primary_hash=existing_hash,
other_hash_side=other_hash_side,
other_hash=other_hash,
)
layers.append(new_layer)
existing_hash = new_layer.combined_hash
return ProofOfInclusion(node_hash=node_hash, layers=layers)
@pytest.fixture(name="side", params=[Side.LEFT, Side.RIGHT])
def side_fixture(request: SubRequest) -> Side:
# https://github.com/pytest-dev/pytest/issues/8763
return request.param # type: ignore[no-any-return]
@pytest.fixture(name="valid_proof_of_inclusion", params=[0, 1, 5])
def valid_proof_of_inclusion_fixture(request: SubRequest, side: Side) -> ProofOfInclusion:
return create_valid_proof_of_inclusion(layer_count=request.param, other_hash_side=side)
@pytest.fixture(
name="invalid_proof_of_inclusion",
params=["bad root hash", "bad other hash", "bad other side", "bad node hash"],
)
def invalid_proof_of_inclusion_fixture(request: SubRequest, side: Side) -> ProofOfInclusion:
valid_proof_of_inclusion = create_valid_proof_of_inclusion(layer_count=5, other_hash_side=side)
layers = list(valid_proof_of_inclusion.layers)
a_hash = bytes32(b"f" * 32)
if request.param == "bad root hash":
layers[-1] = dataclasses.replace(layers[-1], combined_hash=a_hash)
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
elif request.param == "bad other hash":
layers[1] = dataclasses.replace(layers[1], other_hash=a_hash)
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
elif request.param == "bad other side":
layers[1] = dataclasses.replace(layers[1], other_hash_side=layers[1].other_hash_side.other())
return dataclasses.replace(valid_proof_of_inclusion, layers=layers)
elif request.param == "bad node hash":
return dataclasses.replace(valid_proof_of_inclusion, node_hash=a_hash)
raise Exception(f"Unhandled parametrization: {request.param!r}")
def test_proof_of_inclusion_is_valid(valid_proof_of_inclusion: ProofOfInclusion) -> None:
assert valid_proof_of_inclusion.valid()
def test_proof_of_inclusion_is_invalid(invalid_proof_of_inclusion: ProofOfInclusion) -> None:
assert not invalid_proof_of_inclusion.valid()

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

View File

@ -0,0 +1,374 @@
import sqlite3
from typing import Any, Dict
import pytest
from chia.data_layer.data_layer_util import NodeType, Side, Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.tree_hash import bytes32
from tests.core.data_layer.util import add_01234567_example, create_valid_node_values
pytestmark = pytest.mark.data_layer
@pytest.mark.asyncio
async def test_node_update_fails(data_store: DataStore, tree_id: bytes32) -> None:
await add_01234567_example(data_store=data_store, tree_id=tree_id)
node = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^updates not allowed to the node table$"):
await data_store.db.execute(
"UPDATE node SET value = :value WHERE hash == :hash",
{
"hash": node.hash,
"value": node.value,
},
)
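The exact error text matched above points at a RAISE inside a trigger on the node table. A standalone sqlite3 sketch of such a guard (the real schema lives in DataStore.create; this table and trigger are assumptions for illustration):
import sqlite3
db = sqlite3.connect(":memory:")
db.execute("CREATE TABLE node(hash BLOB PRIMARY KEY, value BLOB)")
# Assumed shape of the guard: reject any UPDATE with an IntegrityError.
db.execute(
    """
    CREATE TRIGGER no_node_updates BEFORE UPDATE ON node
    BEGIN
        SELECT RAISE(ABORT, 'updates not allowed to the node table');
    END
    """
)
db.execute("INSERT INTO node VALUES (x'00', x'01')")
try:
    db.execute("UPDATE node SET value = x'02'")
except sqlite3.IntegrityError as error:
    print(error)  # updates not allowed to the node table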
@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32}))
@pytest.mark.asyncio
async def test_node_hash_must_be_32(
data_store: DataStore,
tree_id: bytes32,
length: int,
valid_node_values: Dict[str, Any],
) -> None:
valid_node_values["hash"] = bytes([0] * length)
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
valid_node_values,
)
@pytest.mark.asyncio
async def test_node_hash_must_not_be_null(
data_store: DataStore,
tree_id: bytes32,
valid_node_values: Dict[str, Any],
) -> None:
valid_node_values["hash"] = None
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: node.hash$"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
valid_node_values,
)
@pytest.mark.asyncio
async def test_node_type_must_be_valid(
data_store: DataStore,
node_type: NodeType,
bad_node_type: int,
valid_node_values: Dict[str, Any],
) -> None:
valid_node_values["node_type"] = bad_node_type
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
valid_node_values,
)
@pytest.mark.parametrize(argnames="side", argvalues=Side)
@pytest.mark.asyncio
async def test_node_internal_child_not_null(data_store: DataStore, tree_id: bytes32, side: Side) -> None:
await add_01234567_example(data_store=data_store, tree_id=tree_id)
node_a = await data_store.get_node_by_key(key=b"\x02", tree_id=tree_id)
node_b = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
values = create_valid_node_values(node_type=NodeType.INTERNAL, left_hash=node_a.hash, right_hash=node_b.hash)
if side == Side.LEFT:
values["left"] = None
elif side == Side.RIGHT:
values["right"] = None
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
values,
)
@pytest.mark.parametrize(argnames="bad_child_hash", argvalues=[b"\x01" * 32, b"\0" * 31, b""])
@pytest.mark.parametrize(argnames="side", argvalues=Side)
@pytest.mark.asyncio
async def test_node_internal_must_be_valid_reference(
data_store: DataStore,
tree_id: bytes32,
bad_child_hash: bytes,
side: Side,
) -> None:
await add_01234567_example(data_store=data_store, tree_id=tree_id)
node_a = await data_store.get_node_by_key(key=b"\x02", tree_id=tree_id)
node_b = await data_store.get_node_by_key(key=b"\x04", tree_id=tree_id)
values = create_valid_node_values(node_type=NodeType.INTERNAL, left_hash=node_a.hash, right_hash=node_b.hash)
if side == Side.LEFT:
values["left"] = bad_child_hash
elif side == Side.RIGHT:
values["right"] = bad_child_hash
else:
assert False
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
values,
)
@pytest.mark.parametrize(argnames="key_or_value", argvalues=["key", "value"])
@pytest.mark.asyncio
async def test_node_terminal_key_value_not_null(data_store: DataStore, tree_id: bytes32, key_or_value: str) -> None:
await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = create_valid_node_values(node_type=NodeType.TERMINAL)
values[key_or_value] = None
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO node(hash, node_type, left, right, key, value)
VALUES(:hash, :node_type, :left, :right, :key, :value)
""",
values,
)
@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32}))
@pytest.mark.asyncio
async def test_root_tree_id_must_be_32(data_store: DataStore, tree_id: bytes32, length: int) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
bad_tree_id = bytes([0] * length)
values = {"tree_id": bad_tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.asyncio
async def test_root_tree_id_must_not_be_null(data_store: DataStore, tree_id: bytes32) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": None, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.tree_id$"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.parametrize(argnames="generation", argvalues=[-200, -2, -1])
@pytest.mark.asyncio
async def test_root_generation_must_not_be_less_than_zero(
data_store: DataStore, tree_id: bytes32, generation: int
) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {
"tree_id": bytes32([0] * 32),
"generation": generation,
"node_hash": example.terminal_nodes[0],
"status": Status.PENDING,
}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.asyncio
async def test_root_generation_must_not_be_null(data_store: DataStore, tree_id: bytes32) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {
"tree_id": bytes32([0] * 32),
"generation": None,
"node_hash": example.terminal_nodes[0],
"status": Status.PENDING,
}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.generation$"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.asyncio
async def test_root_node_hash_must_reference(data_store: DataStore) -> None:
values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": bytes32([0] * 32), "status": Status.PENDING}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^FOREIGN KEY constraint failed$"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.parametrize(argnames="bad_status", argvalues=sorted(set(range(-20, 20)) - {*Status}))
@pytest.mark.asyncio
async def test_root_status_must_be_valid(data_store: DataStore, tree_id: bytes32, bad_status: int) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {
"tree_id": bytes32([0] * 32),
"generation": 0,
"node_hash": example.terminal_nodes[0],
"status": bad_status,
}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.asyncio
async def test_root_status_must_not_be_null(data_store: DataStore, tree_id: bytes32) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": bytes32([0] * 32), "generation": 0, "node_hash": example.terminal_nodes[0], "status": None}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^NOT NULL constraint failed: root.status$"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.asyncio
async def test_root_tree_id_generation_must_be_unique(data_store: DataStore, tree_id: bytes32) -> None:
example = await add_01234567_example(data_store=data_store, tree_id=tree_id)
values = {"tree_id": tree_id, "generation": 0, "node_hash": example.terminal_nodes[0], "status": Status.COMMITTED}
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^UNIQUE constraint failed: root.tree_id, root.generation$"):
await data_store.db.execute(
"""
INSERT INTO root(tree_id, generation, node_hash, status)
VALUES(:tree_id, :generation, :node_hash, :status)
""",
values,
)
@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32}))
@pytest.mark.asyncio
async def test_ancestors_ancestor_must_be_32(
data_store: DataStore,
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01")
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO ancestors(hash, ancestor, tree_id, generation)
VALUES(:hash, :ancestor, :tree_id, :generation)
""",
{"hash": node_hash, "ancestor": bytes([0] * length), "tree_id": bytes32([0] * 32), "generation": 0},
)
@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32}))
@pytest.mark.asyncio
async def test_ancestors_tree_id_must_be_32(
data_store: DataStore,
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
node_hash = await data_store._insert_terminal_node(key=b"\x00", value=b"\x01")
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO ancestors(hash, ancestor, tree_id, generation)
VALUES(:hash, :ancestor, :tree_id, :generation)
""",
{"hash": node_hash, "ancestor": bytes32([0] * 32), "tree_id": bytes([0] * length), "generation": 0},
)
@pytest.mark.parametrize(argnames="length", argvalues=sorted(set(range(50)) - {32}))
@pytest.mark.asyncio
async def test_subscriptions_tree_id_must_be_32(
data_store: DataStore,
tree_id: bytes32,
length: int,
) -> None:
async with data_store.db_wrapper.locked_transaction():
with pytest.raises(sqlite3.IntegrityError, match=r"^CHECK constraint failed:"):
await data_store.db.execute(
"""
INSERT INTO subscriptions(tree_id, url, ignore_till, num_consecutive_failures, from_wallet)
VALUES(:tree_id, :url, :ignore_till, :num_consecutive_failures, :from_wallet)
""",
{
"tree_id": bytes([0] * length),
"url": "",
"ignore_till": 0,
"num_consecutive_failures": 0,
"from_wallet": False,
},
)

View File

@ -0,0 +1,209 @@
import contextlib
import functools
import os
import pathlib
import subprocess
from dataclasses import dataclass
from typing import IO, TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Union
from chia.data_layer.data_layer_util import NodeType, Side, Status
from chia.data_layer.data_store import DataStore
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.tree_hash import bytes32
# from subprocess.pyi
_FILE = Union[None, int, IO[Any]]
if TYPE_CHECKING:
# these require Python 3.9 at runtime
os_PathLike_str = os.PathLike[str]
subprocess_CompletedProcess_str = subprocess.CompletedProcess[str]
else:
os_PathLike_str = os.PathLike
subprocess_CompletedProcess_str = subprocess.CompletedProcess
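Subscripting os.PathLike or subprocess.CompletedProcess raises TypeError at runtime before Python 3.9, which is why the subscripted forms above are only evaluated under TYPE_CHECKING and the bare classes are used otherwise.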
async def general_insert(
data_store: DataStore,
tree_id: bytes32,
key: bytes,
value: bytes,
reference_node_hash: bytes32,
side: Optional[Side],
) -> bytes32:
return await data_store.insert(
key=key,
value=value,
tree_id=tree_id,
reference_node_hash=reference_node_hash,
side=side,
status=Status.COMMITTED,
)
@dataclass(frozen=True)
class Example:
expected: Program
terminal_nodes: List[bytes32]
async def add_0123_example(data_store: DataStore, tree_id: bytes32) -> Example:
expected = Program.to(
(
(
(b"\x00", b"\x10\x00"),
(b"\x01", b"\x11\x01"),
),
(
(b"\x02", b"\x12\x02"),
(b"\x03", b"\x13\x03"),
),
),
)
insert = functools.partial(general_insert, data_store=data_store, tree_id=tree_id)
c_hash = await insert(key=b"\x02", value=b"\x12\x02", reference_node_hash=None, side=None)
b_hash = await insert(key=b"\x01", value=b"\x11\x01", reference_node_hash=c_hash, side=Side.LEFT)
d_hash = await insert(key=b"\x03", value=b"\x13\x03", reference_node_hash=c_hash, side=Side.RIGHT)
a_hash = await insert(key=b"\x00", value=b"\x10\x00", reference_node_hash=b_hash, side=Side.LEFT)
return Example(expected=expected, terminal_nodes=[a_hash, b_hash, c_hash, d_hash])
async def add_01234567_example(data_store: DataStore, tree_id: bytes32) -> Example:
expected = Program.to(
(
(
(
(b"\x00", b"\x10\x00"),
(b"\x01", b"\x11\x01"),
),
(
(b"\x02", b"\x12\x02"),
(b"\x03", b"\x13\x03"),
),
),
(
(
(b"\x04", b"\x14\x04"),
(b"\x05", b"\x15\x05"),
),
(
(b"\x06", b"\x16\x06"),
(b"\x07", b"\x17\x07"),
),
),
),
)
insert = functools.partial(general_insert, data_store=data_store, tree_id=tree_id)
g_hash = await insert(key=b"\x06", value=b"\x16\x06", reference_node_hash=None, side=None)
c_hash = await insert(key=b"\x02", value=b"\x12\x02", reference_node_hash=g_hash, side=Side.LEFT)
b_hash = await insert(key=b"\x01", value=b"\x11\x01", reference_node_hash=c_hash, side=Side.LEFT)
d_hash = await insert(key=b"\x03", value=b"\x13\x03", reference_node_hash=c_hash, side=Side.RIGHT)
a_hash = await insert(key=b"\x00", value=b"\x10\x00", reference_node_hash=b_hash, side=Side.LEFT)
f_hash = await insert(key=b"\x05", value=b"\x15\x05", reference_node_hash=g_hash, side=Side.LEFT)
h_hash = await insert(key=b"\x07", value=b"\x17\x07", reference_node_hash=g_hash, side=Side.RIGHT)
e_hash = await insert(key=b"\x04", value=b"\x14\x04", reference_node_hash=f_hash, side=Side.LEFT)
return Example(expected=expected, terminal_nodes=[a_hash, b_hash, c_hash, d_hash, e_hash, f_hash, g_hash, h_hash])
@dataclass
class ChiaRoot:
path: pathlib.Path
scripts_path: pathlib.Path
def run(
self,
args: List[Union[str, os_PathLike_str]],
*other_args: Any,
check: bool = True,
encoding: str = "utf-8",
stdout: Optional[_FILE] = subprocess.PIPE,
stderr: Optional[_FILE] = subprocess.PIPE,
**kwargs: Any,
) -> subprocess_CompletedProcess_str:
# TODO: --root-path doesn't seem to work here...
kwargs.setdefault("env", {})
kwargs["env"]["CHIA_ROOT"] = os.fspath(self.path)
kwargs["env"]["CHIA_KEYS_ROOT"] = os.fspath(self.path)
# This is for Windows
if "SYSTEMROOT" in os.environ:
kwargs["env"]["SYSTEMROOT"] = os.environ["SYSTEMROOT"]
modified_args: List[Union[str, os_PathLike_str]] = [
self.scripts_path.joinpath("chia"),
"--root-path",
self.path,
*args,
]
processed_args: List[str] = [os.fspath(element) for element in modified_args]
final_args = [processed_args, *other_args]
kwargs["check"] = check
kwargs["encoding"] = encoding
kwargs["stdout"] = stdout
kwargs["stderr"] = stderr
return subprocess.run(*final_args, **kwargs)
def read_log(self) -> str:
return self.path.joinpath("log", "debug.log").read_text(encoding="utf-8")
def print_log(self) -> None:
log_text: Optional[str]
try:
log_text = self.read_log()
except FileNotFoundError:
log_text = None
if log_text is None:
print(f"---- no log at: {self.path}")
else:
print(f"---- start of: {self.path}")
print(log_text)
print(f"---- end of: {self.path}")
@contextlib.contextmanager
def print_log_after(self) -> Iterator[None]:
try:
yield
finally:
self.print_log()
def create_valid_node_values(
node_type: NodeType,
left_hash: Optional[bytes32] = None,
right_hash: Optional[bytes32] = None,
) -> Dict[str, Any]:
if node_type == NodeType.INTERNAL:
return {
"hash": Program.to((left_hash, right_hash)).get_tree_hash(left_hash, right_hash),
"node_type": node_type,
"left": left_hash,
"right": right_hash,
"key": None,
"value": None,
}
elif node_type == NodeType.TERMINAL:
key = b""
value = b""
return {
"hash": Program.to((key, value)).get_tree_hash(),
"node_type": node_type,
"left": None,
"right": None,
"key": key,
"value": value,
}
raise Exception(f"Unhandled node type: {node_type!r}")

View File

@ -36,7 +36,7 @@ from chia.simulator.time_out_assert import time_out_assert, time_out_assert_not_
log = logging.getLogger(__name__)
FEE_AMOUNT = 2000000000000
MAX_WAIT_SECS = 20 # A high value for WAIT_SECS is useful when paused in the debugger
MAX_WAIT_SECS = 30 # A high value for WAIT_SECS is useful when paused in the debugger
def get_pool_plot_dir():
@ -170,7 +170,7 @@ class TestPoolWalletRpc:
await farm_blocks(full_node_api, our_ph, 6)
assert full_node_api.full_node.mempool_manager.get_spendbundle(creation_tx.name) is None
await time_out_assert(20, wallet_is_synced, True, wallet_node_0, full_node_api)
await time_out_assert(30, wallet_is_synced, True, wallet_node_0, full_node_api)
summaries_response = await client.get_wallets(WalletType.POOLING_WALLET)
assert len(summaries_response) == 1
wallet_id: int = summaries_response[0]["id"]

View File

@ -6,11 +6,15 @@ from pathlib import Path
from chia.consensus.constants import ConsensusConstants
from chia.full_node.full_node_api import FullNodeAPI
from chia.protocols.shared_protocol import Capability
from chia.server.server import ChiaServer
from chia.server.start_data_layer import create_data_layer_service
from chia.server.start_service import Service
from chia.simulator.block_tools import BlockTools, create_block_tools_async, test_constants
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.types.peer_info import PeerInfo
from chia.util.hash import std_hash
from chia.util.ints import uint16, uint32
from chia.simulator.block_tools import BlockTools, create_block_tools_async, test_constants
from chia.wallet.wallet_node import WalletNode
from tests.setup_services import (
setup_daemon,
setup_farmer,
@ -27,6 +31,9 @@ from tests.util.keyring import TempKeyring
from chia.simulator.socket import find_available_listen_port
SimulatorsAndWallets = Tuple[List[FullNodeSimulator], List[Tuple[WalletNode, ChiaServer]], BlockTools]
def cleanup_keyring(keyring: TempKeyring):
keyring.cleanup()
@ -47,6 +54,36 @@ async def _teardown_nodes(node_aiters: List) -> None:
pass
async def setup_data_layer(local_bt):
# db_path = local_bt.root_path / f"{db_name}"
# if db_path.exists():
# db_path.unlink()
config = local_bt.config["data_layer"]
# config["database_path"] = db_name
# if introducer_port is not None:
# config["introducer_peer"]["host"] = self_hostname
# config["introducer_peer"]["port"] = introducer_port
# else:
# config["introducer_peer"] = None
# config["dns_servers"] = []
# config["rpc_port"] = port + 1000
# overrides = config["network_overrides"]["constants"][config["selected_network"]]
# updated_constants = consensus_constants.replace_str_to_bytes(**overrides)
# if simulator:
# kwargs = service_kwargs_for_full_node_simulator(local_bt.root_path, config, local_bt)
# else:
# kwargs = service_kwargs_for_full_node(local_bt.root_path, config, updated_constants)
service = create_data_layer_service(local_bt.root_path, config, connect_to_daemon=False)
await service.start()
yield service._api
service.stop()
await service.wait_closed()
async def setup_two_nodes(consensus_constants: ConsensusConstants, db_version: int, self_hostname: str):
"""
Setup and teardown of two full nodes, with blockchains and separate DBs.

View File

@ -1,19 +1,26 @@
from typing import Tuple, List
from typing import AsyncIterator, List, Tuple
import pytest
import pytest_asyncio
from chia.cmds.units import units
from chia.consensus.block_rewards import calculate_pool_reward, calculate_base_farmer_reward
from chia.server.server import ChiaServer
from chia.simulator.block_tools import create_block_tools_async, BlockTools
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.simulator.simulator_protocol import FarmNewBlockProtocol, GetAllCoinsProtocol, ReorgProtocol
from chia.types.peer_info import PeerInfo
from chia.wallet.wallet_node import WalletNode
from chia.simulator.block_tools import create_block_tools_async, BlockTools
from chia.util.ints import uint16, uint32, uint64
from tests.core.node_height import node_height_at_least
from tests.setup_nodes import setup_full_node, setup_full_system, test_constants
from chia.simulator.time_out_assert import time_out_assert
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.wallet_node import WalletNode
from tests.core.node_height import node_height_at_least
from tests.setup_nodes import (
SimulatorsAndWallets,
setup_full_node,
setup_full_system,
test_constants,
setup_simulators_and_wallets,
)
from tests.util.keyring import TempKeyring
test_constants_modified = test_constants.replace(
@ -56,6 +63,12 @@ async def simulation(bt):
yield _
@pytest_asyncio.fixture(scope="function")
async def one_wallet_node() -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(simulator_count=1, wallet_count=1, dic={}):
yield _
class TestSimulation:
@pytest.mark.asyncio
async def test_simulation_1(self, simulation, extra_node, self_hostname):
@ -177,3 +190,228 @@ class TestSimulation:
reorg_spent_and_non_spent_coins = await full_node_api.get_all_coins(GetAllCoinsProtocol(True))
assert len(reorg_non_spent_coins) == 12 and len(reorg_spent_and_non_spent_coins) == 12
assert tx.additions not in spent_and_non_spent_coins # just double check that those got reverted.
@pytest.mark.asyncio
@pytest.mark.parametrize(argnames="count", argvalues=[0, 1, 2, 5, 10])
async def test_simulation_process_blocks(
self,
count,
one_wallet_node: SimulatorsAndWallets,
):
[[full_node_api], _, _] = one_wallet_node
# Starting at the beginning.
assert full_node_api.full_node.blockchain.get_peak_height() is None
await full_node_api.process_blocks(count=count)
# The requested number of blocks has been processed.
expected_height = None if count == 0 else count
assert full_node_api.full_node.blockchain.get_peak_height() == expected_height
@pytest.mark.asyncio
@pytest.mark.parametrize(argnames="count", argvalues=[0, 1, 2, 5, 10])
async def test_simulation_farm_blocks(
self,
count,
one_wallet_node: SimulatorsAndWallets,
):
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute error below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
# Starting at the beginning.
assert full_node_api.full_node.blockchain.get_peak_height() is None
rewards = await full_node_api.farm_blocks(count=count, wallet=wallet)
# The requested number of blocks has been processed, plus 1 to handle the final reward
# transactions in the case of a non-zero count.
expected_height = count
if count > 0:
expected_height += 1
peak_height = full_node_api.full_node.blockchain.get_peak_height()
if peak_height is None:
peak_height = uint32(0)
assert peak_height == expected_height
# The expected rewards have been received and confirmed.
unconfirmed_balance = await wallet.get_unconfirmed_balance()
confirmed_balance = await wallet.get_confirmed_balance()
assert [unconfirmed_balance, confirmed_balance] == [rewards, rewards]
@pytest.mark.asyncio
@pytest.mark.parametrize(
argnames=["amount", "coin_count"],
argvalues=[
[0, 0],
[1, 2],
[(2 * units["chia"]) - 1, 2],
[2 * units["chia"], 2],
[(2 * units["chia"]) + 1, 4],
[3 * units["chia"], 4],
[10 * units["chia"], 10],
],
)
async def test_simulation_farm_rewards(
self,
amount: int,
coin_count: int,
one_wallet_node: SimulatorsAndWallets,
):
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute error below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
rewards = await full_node_api.farm_rewards(amount=amount, wallet=wallet)
# At least the requested amount was farmed.
assert rewards >= amount
# The rewards amount is both received and confirmed.
unconfirmed_balance = await wallet.get_unconfirmed_balance()
confirmed_balance = await wallet.get_confirmed_balance()
assert [unconfirmed_balance, confirmed_balance] == [rewards, rewards]
# The expected number of coins were received.
spendable_coins = await wallet.wallet_state_manager.get_spendable_coins_for_wallet(wallet.id())
assert len(spendable_coins) == coin_count
@pytest.mark.asyncio
async def test_wait_transaction_records_entered_mempool(
self,
one_wallet_node: SimulatorsAndWallets,
) -> None:
repeats = 50
tx_amount = 1
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute hint issue below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
# generate some coins for repetitive testing
await full_node_api.farm_rewards(amount=repeats * tx_amount, wallet=wallet)
coins = await full_node_api.create_coins_with_amounts(amounts=[tx_amount] * repeats, wallet=wallet)
assert len(coins) == repeats
# repeating just to try to expose any flakiness
for coin in coins:
tx = await wallet.generate_signed_transaction(
amount=uint64(tx_amount),
puzzle_hash=await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash(),
coins={coin},
)
await wallet.push_transaction(tx)
await full_node_api.wait_transaction_records_entered_mempool(records=[tx])
assert tx.spend_bundle is not None
assert full_node_api.full_node.mempool_manager.get_spendbundle(tx.spend_bundle.name()) is not None
# TODO: this fails but it seems like it shouldn't when above passes
# assert tx.is_in_mempool()
@pytest.mark.asyncio
async def test_process_transaction_records(
self,
one_wallet_node: SimulatorsAndWallets,
) -> None:
repeats = 50
tx_amount = 1
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute hint issue below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
# generate some coins for repetitive testing
await full_node_api.farm_rewards(amount=repeats * tx_amount, wallet=wallet)
coins = await full_node_api.create_coins_with_amounts(amounts=[tx_amount] * repeats, wallet=wallet)
assert len(coins) == repeats
# repeating just to try to expose any flakiness
for coin in coins:
tx = await wallet.generate_signed_transaction(
amount=uint64(tx_amount),
puzzle_hash=await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash(),
coins={coin},
)
await wallet.push_transaction(tx)
await full_node_api.process_transaction_records(records=[tx])
# TODO: is this the proper check?
assert full_node_api.full_node.coin_store.get_coin_record(coin.name()) is not None
@pytest.mark.asyncio
@pytest.mark.parametrize(
argnames="amounts",
argvalues=[
*[pytest.param([1] * n, id=f"1 mojo x {n}") for n in [0, 1, 10, 49, 51, 103]],
*[pytest.param(list(range(1, n + 1)), id=f"incrementing x {n}") for n in [1, 10, 49, 51, 103]],
],
)
async def test_create_coins_with_amounts(
self,
amounts: List[int],
one_wallet_node: SimulatorsAndWallets,
) -> None:
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute hint issue below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
await full_node_api.farm_rewards(amount=sum(amounts), wallet=wallet)
# Get some more coins. The creator helper doesn't get you all the coins you
# need yet.
await full_node_api.farm_blocks(count=2, wallet=wallet)
coins = await full_node_api.create_coins_with_amounts(amounts=amounts, wallet=wallet)
assert sorted(coin.amount for coin in coins) == sorted(amounts)
@pytest.mark.asyncio
@pytest.mark.parametrize(
argnames="amounts",
argvalues=[
[0],
[5, -5],
[4, 0],
],
ids=lambda amounts: ", ".join(str(amount) for amount in amounts),
)
async def test_create_coins_with_invalid_amounts_raises(
self,
amounts: List[int],
one_wallet_node: SimulatorsAndWallets,
) -> None:
[[full_node_api], [[wallet_node, wallet_server]], _] = one_wallet_node
await wallet_server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
# Avoiding an attribute hint issue below.
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
with pytest.raises(Exception, match="Coins must have a positive value"):
await full_node_api.create_coins_with_amounts(amounts=amounts, wallet=wallet)

View File

@ -1,4 +1,8 @@
async def validate_get_routes(client, api):
from chia.rpc.rpc_client import RpcClient
from chia.rpc.rpc_server import RpcApiProtocol
async def validate_get_routes(client: RpcClient, api: RpcApiProtocol):
routes_client = (await client.fetch("get_routes", {}))["routes"]
assert len(routes_client) > 0
routes_api = list(api.get_routes().keys())

View File

View File

@ -0,0 +1 @@
checkout_blocks_and_plots = True

View File

@ -0,0 +1,140 @@
from typing import Dict, List, Tuple
import pytest
from blspy import G2Element
from chia.clvm.spend_sim import SimClient, SpendSim
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_spend import CoinSpend
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.spend_bundle import SpendBundle
from chia.util.errors import Err
from chia.wallet.puzzles.load_clvm import load_clvm
from chia.wallet.util.merkle_utils import build_merkle_tree, build_merkle_tree_from_binary_tree, simplify_merkle_proof
GRAFTROOT_MOD = load_clvm("graftroot_dl_offers.clvm")
# Always returns the last value
# (mod solution
#
# (defun recurse (solution last_value)
# (if solution
# (recurse (r solution) (f solution))
# last_value
# )
# )
#
# (recurse solution ())
# )
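#
# A rough Python equivalent of the recursion above (illustrative only,
# assuming `solution` is a plain list):
#
#     def last_value(solution):
#         last = None
#         while solution:
#             last = solution[0]
#             solution = solution[1:]
#         return last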
ACS = Program.fromhex(
"ff02ffff01ff02ff02ffff04ff02ffff04ff03ffff01ff8080808080ffff04ffff01ff02ffff03ff05ffff01ff02ff02ffff04ff02ffff04ff0dffff04ff09ff8080808080ffff010b80ff0180ff018080" # noqa
)
ACS_PH = ACS.get_tree_hash()
NIL_PH = Program.to(None).get_tree_hash()
@pytest.mark.asyncio
async def test_graftroot(setup_sim: Tuple[SpendSim, SimClient]) -> None:
sim, sim_client = setup_sim
try:
# Create the coin we're testing
all_values: List[bytes32] = [bytes32([x] * 32) for x in range(0, 100)]
root, proofs = build_merkle_tree(all_values)
        p2_conditions = Program.to((1, [[51, ACS_PH, 0]]))  # A coin to create to make sure this spend hits the blockchain
desired_key_values = ((bytes32([0] * 32), bytes32([1] * 32)), (bytes32([7] * 32), bytes32([8] * 32)))
desired_row_hashes: List[bytes32] = [build_merkle_tree_from_binary_tree(kv)[0] for kv in desired_key_values]
fake_struct: Program = Program.to((ACS_PH, NIL_PH))
graftroot_puzzle: Program = GRAFTROOT_MOD.curry(
# Do everything twice to test depending on multiple singleton updates
p2_conditions,
[fake_struct, fake_struct],
[ACS_PH, ACS_PH],
[desired_row_hashes, desired_row_hashes],
)
await sim.farm_block(graftroot_puzzle.get_tree_hash())
graftroot_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(graftroot_puzzle.get_tree_hash()))[
0
].coin
        # Build some merkle trees that won't satisfy the requirements
def filter_all(values: List[bytes32]) -> List[bytes32]:
            return [h for i, h in enumerate(values) if (h, values[min(len(values) - 1, i + 1)]) not in desired_key_values]
def filter_to_only_one(values: List[bytes32]) -> List[bytes32]:
            return [h for i, h in enumerate(values) if (h, values[min(len(values) - 1, i + 1)]) not in desired_key_values[1:]]
# And one that will
def filter_none(values: List[bytes32]) -> List[bytes32]:
return values
for list_filter in (filter_all, filter_to_only_one, filter_none):
# Create the "singleton"
filtered_values = list_filter(all_values)
root, proofs = build_merkle_tree(filtered_values)
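            # Key each entry by the parent branch hash (the leaf folded with its
            # first sibling); the value is the remainder of the proof, since the
            # graftroot is handed proofs for row hashes rather than raw leaves.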
filtered_row_hashes: Dict[bytes32, Tuple[int, List[bytes32]]] = {
simplify_merkle_proof(v, (proofs[v][0], [proofs[v][1][0]])): (proofs[v][0] >> 1, proofs[v][1][1:])
for v in filtered_values
}
fake_puzzle: Program = ACS.curry(fake_struct, ACS.curry(ACS_PH, (root, None), NIL_PH, None))
await sim.farm_block(fake_puzzle.get_tree_hash())
fake_coin: Coin = (await sim_client.get_coin_records_by_puzzle_hash(fake_puzzle.get_tree_hash()))[0].coin
# Create the spend
fake_spend = CoinSpend(
fake_coin,
fake_puzzle,
Program.to([[[62, "$"]]]),
)
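            # The solution creates a "$" announcement (opcode 62 is
            # CREATE_PUZZLE_ANNOUNCEMENT under the standard condition codes).
            # The graftroot presumably asserts an announcement from a puzzle hash
            # derived from the claimed root, which is why the bad-root spend
            # below fails with ASSERT_ANNOUNCE_CONSUMED_FAILED.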
proofs_of_inclusion = []
for row_hash in desired_row_hashes:
if row_hash in filtered_row_hashes:
proofs_of_inclusion.append(filtered_row_hashes[row_hash])
else:
proofs_of_inclusion.append((0, []))
graftroot_spend = CoinSpend(
graftroot_coin,
graftroot_puzzle,
Program.to(
[
# Again, everything twice
[proofs_of_inclusion] * 2,
[(root, None), (root, None)],
[NIL_PH, NIL_PH],
[NIL_PH, NIL_PH],
[],
]
),
)
final_bundle = SpendBundle([fake_spend, graftroot_spend], G2Element())
result = await sim_client.push_tx(final_bundle)
# If this is the satisfactory merkle tree
if filtered_values == all_values:
assert result == (MempoolInclusionStatus.SUCCESS, None)
# clear the mempool
same_height = sim.block_height
await sim.farm_block()
assert len(await sim_client.get_coin_records_by_puzzle_hash(ACS_PH)) > 0
await sim.rewind(same_height)
# try with a bad merkle root announcement
new_fake_spend = CoinSpend(
fake_coin,
ACS.curry(fake_struct, ACS.curry(ACS_PH, (bytes32([0] * 32), None), None, None)),
Program.to([[[62, "$"]]]),
)
new_final_bundle = SpendBundle([new_fake_spend, graftroot_spend], G2Element())
result = await sim_client.push_tx(new_final_bundle)
assert result == (MempoolInclusionStatus.FAILED, Err.ASSERT_ANNOUNCE_CONSUMED_FAILED)
else:
assert result == (MempoolInclusionStatus.FAILED, Err.GENERATOR_RUNTIME_ERROR)
with pytest.raises(ValueError, match="clvm raise"):
graftroot_puzzle.run(graftroot_spend.solution.to_program())
finally:
await sim.close()

View File

@ -0,0 +1,472 @@
import dataclasses
from typing import Any, List, Tuple
import pytest
from chia.data_layer.data_layer_wallet import DataLayerWallet
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import uint64
from chia.wallet.puzzle_drivers import Solver
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer
from chia.wallet.trading.trade_status import TradeStatus
from chia.wallet.util.merkle_utils import build_merkle_tree, simplify_merkle_proof
async def is_singleton_confirmed_and_root(dl_wallet: DataLayerWallet, lid: bytes32, root: bytes32) -> bool:
rec = await dl_wallet.get_latest_singleton(lid)
if rec is None:
return False
if rec.confirmed is True:
assert rec.confirmed_at_height > 0
assert rec.timestamp > 0
return rec.confirmed and rec.root == root
async def get_trade_and_status(trade_manager: Any, trade: TradeRecord) -> TradeStatus:
trade_rec = await trade_manager.get_trade_by_id(trade.trade_id)
return TradeStatus(trade_rec.status)
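# Proofs from build_merkle_tree are (path, sibling_hashes) pairs. Folding the
# first sibling into the value and shifting the path right by one yields the
# proof for the parent branch. A minimal sketch (pair ordering is decided by
# simplify_merkle_proof): a leaf proof (0b10, [s0, s1]) becomes the branch
# proof (0b1, [s1]) for the node that combines the leaf with s0.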
def get_parent_branch(value: bytes32, proof: Tuple[int, List[bytes32]]) -> Tuple[bytes32, Tuple[int, List[bytes32]]]:
branch: bytes32 = simplify_merkle_proof(value, (proof[0], [proof[1][0]]))
new_proof: Tuple[int, List[bytes32]] = (proof[0] >> 1, proof[1][1:])
return branch, new_proof
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_dl_offers(wallets_prefarm: Any, trusted: bool) -> None:
wallet_node_maker, wallet_node_taker, full_node_api = wallets_prefarm
assert wallet_node_maker.wallet_state_manager is not None
assert wallet_node_taker.wallet_state_manager is not None
wsm_maker = wallet_node_maker.wallet_state_manager
wsm_taker = wallet_node_taker.wallet_state_manager
wallet_maker = wsm_maker.main_wallet
wallet_taker = wsm_taker.main_wallet
funds = 20000000000000
await time_out_assert(10, wallet_maker.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_taker.get_confirmed_balance, funds)
async with wsm_maker.lock:
dl_wallet_maker = await DataLayerWallet.create_new_dl_wallet(wsm_maker, wallet_maker)
async with wsm_taker.lock:
dl_wallet_taker = await DataLayerWallet.create_new_dl_wallet(wsm_taker, wallet_taker)
MAKER_ROWS = [bytes32([i] * 32) for i in range(0, 10)]
TAKER_ROWS = [bytes32([i] * 32) for i in range(10, 20)]
maker_root, _ = build_merkle_tree(MAKER_ROWS)
taker_root, _ = build_merkle_tree(TAKER_ROWS)
dl_record, std_record, launcher_id_maker = await dl_wallet_maker.generate_new_reporter(
maker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_maker.get_latest_singleton(launcher_id_maker) is not None
await wsm_maker.add_pending_transaction(dl_record)
await wsm_maker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_maker, maker_root)
dl_record, std_record, launcher_id_taker = await dl_wallet_taker.generate_new_reporter(
taker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_taker.get_latest_singleton(launcher_id_taker) is not None
await wsm_taker.add_pending_transaction(dl_record)
await wsm_taker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_taker, taker_root)
peer = wallet_node_taker.get_full_node_peer()
assert peer is not None
await dl_wallet_maker.track_new_launcher_id(launcher_id_taker, peer)
await dl_wallet_taker.track_new_launcher_id(launcher_id_maker, peer)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker, maker_root)
trade_manager_maker = wsm_maker.trade_manager
trade_manager_taker = wsm_taker.trade_manager
maker_addition = bytes32([101] * 32)
taker_addition = bytes32([202] * 32)
MAKER_ROWS.append(maker_addition)
TAKER_ROWS.append(taker_addition)
maker_root, maker_proofs = build_merkle_tree(MAKER_ROWS)
taker_root, taker_proofs = build_merkle_tree(TAKER_ROWS)
maker_branch, maker_branch_proof = get_parent_branch(maker_addition, maker_proofs[maker_addition])
taker_branch, taker_branch_proof = get_parent_branch(taker_addition, taker_proofs[taker_addition])
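    # The offers prove inclusion of the parent branch (the new leaf folded with
    # its first sibling) rather than the leaf itself, matching what
    # get_parent_branch returns above.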
success, offer_maker, error = await trade_manager_maker.create_offer_for_ids(
{launcher_id_maker: -1, launcher_id_taker: 1},
solver=Solver(
{
launcher_id_maker.hex(): {
"new_root": "0x" + maker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_taker.hex(),
"values_to_prove": ["0x" + taker_branch.hex()],
},
],
}
}
),
fee=uint64(2000000000000),
)
assert error is None
assert success is True
assert offer_maker is not None
assert await trade_manager_taker.get_offer_summary(Offer.from_bytes(offer_maker.offer)) == {
"offered": [
{
"launcher_id": launcher_id_maker.hex(),
"new_root": maker_root.hex(),
"dependencies": [
{
"launcher_id": launcher_id_taker.hex(),
"values_to_prove": [taker_branch.hex()],
}
],
}
]
}
success, offer_taker, error = await trade_manager_taker.respond_to_offer(
Offer.from_bytes(offer_maker.offer),
peer,
solver=Solver(
{
launcher_id_taker.hex(): {
"new_root": "0x" + taker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_maker.hex(),
"values_to_prove": ["0x" + maker_branch.hex()],
},
],
},
"proofs_of_inclusion": [
[
maker_root.hex(),
str(maker_branch_proof[0]),
["0x" + sibling.hex() for sibling in maker_branch_proof[1]],
],
[
taker_root.hex(),
str(taker_branch_proof[0]),
["0x" + sibling.hex() for sibling in taker_branch_proof[1]],
],
],
}
),
fee=uint64(2000000000000),
)
assert error is None
assert success is True
assert offer_taker is not None
assert await trade_manager_maker.get_offer_summary(Offer.from_bytes(offer_taker.offer)) == {
"offered": [
{
"launcher_id": launcher_id_maker.hex(),
"new_root": maker_root.hex(),
"dependencies": [
{
"launcher_id": launcher_id_taker.hex(),
"values_to_prove": [taker_branch.hex()],
}
],
},
{
"launcher_id": launcher_id_taker.hex(),
"new_root": taker_root.hex(),
"dependencies": [
{
"launcher_id": launcher_id_maker.hex(),
"values_to_prove": [maker_branch.hex()],
}
],
},
]
}
await time_out_assert(15, wallet_maker.get_unconfirmed_balance, funds - 2000000000000)
await time_out_assert(15, wallet_taker.get_unconfirmed_balance, funds - 4000000000000)
# Let's hack a way to await this offer's confirmation
offer_record = dataclasses.replace(dl_record, spend_bundle=Offer.from_bytes(offer_taker.offer).bundle)
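    # dl_record is reused purely as a TransactionRecord shell here; swapping in
    # the offer's spend bundle gives process_transaction_records something it
    # can watch for confirmation.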
await full_node_api.process_transaction_records(records=[offer_record])
await time_out_assert(15, wallet_maker.get_confirmed_balance, funds - 4000000000000)
await time_out_assert(15, wallet_maker.get_unconfirmed_balance, funds - 4000000000000)
await time_out_assert(15, wallet_taker.get_confirmed_balance, funds - 4000000000000)
await time_out_assert(15, wallet_taker.get_unconfirmed_balance, funds - 4000000000000)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker, maker_root)
await time_out_assert(15, get_trade_and_status, TradeStatus.CONFIRMED, trade_manager_maker, offer_maker)
await time_out_assert(15, get_trade_and_status, TradeStatus.CONFIRMED, trade_manager_taker, offer_taker)
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_dl_offer_cancellation(wallets_prefarm: Any, trusted: bool) -> None:
wallet_node, _, full_node_api = wallets_prefarm
assert wallet_node.wallet_state_manager is not None
wsm = wallet_node.wallet_state_manager
wallet = wsm.main_wallet
funds = 20000000000000
await time_out_assert(10, wallet.get_unconfirmed_balance, funds)
async with wsm.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(wsm, wallet)
ROWS = [bytes32([i] * 32) for i in range(0, 10)]
root, _ = build_merkle_tree(ROWS)
dl_record, std_record, launcher_id = await dl_wallet.generate_new_reporter(root)
assert await dl_wallet.get_latest_singleton(launcher_id) is not None
await wsm.add_pending_transaction(dl_record)
await wsm.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet, launcher_id, root)
dl_record_2, std_record_2, launcher_id_2 = await dl_wallet.generate_new_reporter(root)
await wsm.add_pending_transaction(dl_record_2)
await wsm.add_pending_transaction(std_record_2)
await full_node_api.process_transaction_records(records=[dl_record_2, std_record_2])
trade_manager = wsm.trade_manager
addition = bytes32([101] * 32)
ROWS.append(addition)
root, proofs = build_merkle_tree(ROWS)
success, offer, error = await trade_manager.create_offer_for_ids(
{launcher_id: -1, launcher_id_2: 1},
solver=Solver(
{
launcher_id.hex(): {
"new_root": "0x" + root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_2.hex(),
"values_to_prove": ["0x" + addition.hex()],
},
],
}
}
),
fee=uint64(2000000000000),
)
assert error is None
assert success is True
assert offer is not None
cancellation_txs = await trade_manager.cancel_pending_offer_safely(offer.trade_id, fee=uint64(2000000000000))
assert len(cancellation_txs) == 3
await time_out_assert(15, get_trade_and_status, TradeStatus.PENDING_CANCEL, trade_manager, offer)
await full_node_api.process_transaction_records(records=cancellation_txs)
await time_out_assert(15, get_trade_and_status, TradeStatus.CANCELLED, trade_manager, offer)
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_multiple_dl_offers(wallets_prefarm: Any, trusted: bool) -> None:
wallet_node_maker, wallet_node_taker, full_node_api = wallets_prefarm
assert wallet_node_maker.wallet_state_manager is not None
assert wallet_node_taker.wallet_state_manager is not None
wsm_maker = wallet_node_maker.wallet_state_manager
wsm_taker = wallet_node_taker.wallet_state_manager
wallet_maker = wsm_maker.main_wallet
wallet_taker = wsm_taker.main_wallet
funds = 20000000000000
await time_out_assert(10, wallet_maker.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_taker.get_confirmed_balance, funds)
async with wsm_maker.lock:
dl_wallet_maker = await DataLayerWallet.create_new_dl_wallet(wsm_maker, wallet_maker)
async with wsm_taker.lock:
dl_wallet_taker = await DataLayerWallet.create_new_dl_wallet(wsm_taker, wallet_taker)
MAKER_ROWS = [bytes32([i] * 32) for i in range(0, 10)]
TAKER_ROWS = [bytes32([i] * 32) for i in range(10, 20)]
maker_root, _ = build_merkle_tree(MAKER_ROWS)
taker_root, _ = build_merkle_tree(TAKER_ROWS)
dl_record, std_record, launcher_id_maker_1 = await dl_wallet_maker.generate_new_reporter(
maker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_maker.get_latest_singleton(launcher_id_maker_1) is not None
await wsm_maker.add_pending_transaction(dl_record)
await wsm_maker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_maker_1, maker_root)
dl_record, std_record, launcher_id_maker_2 = await dl_wallet_maker.generate_new_reporter(
maker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_maker.get_latest_singleton(launcher_id_maker_2) is not None
await wsm_maker.add_pending_transaction(dl_record)
await wsm_maker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_maker_2, maker_root)
dl_record, std_record, launcher_id_taker_1 = await dl_wallet_taker.generate_new_reporter(
taker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_taker.get_latest_singleton(launcher_id_taker_1) is not None
await wsm_taker.add_pending_transaction(dl_record)
await wsm_taker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_taker_1, taker_root)
dl_record, std_record, launcher_id_taker_2 = await dl_wallet_taker.generate_new_reporter(
taker_root, fee=uint64(1999999999999)
)
assert await dl_wallet_taker.get_latest_singleton(launcher_id_taker_2) is not None
await wsm_taker.add_pending_transaction(dl_record)
await wsm_taker.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_taker_2, taker_root)
peer = wallet_node_taker.get_full_node_peer()
assert peer is not None
await dl_wallet_maker.track_new_launcher_id(launcher_id_taker_1, peer)
await dl_wallet_maker.track_new_launcher_id(launcher_id_taker_2, peer)
await dl_wallet_taker.track_new_launcher_id(launcher_id_maker_1, peer)
await dl_wallet_taker.track_new_launcher_id(launcher_id_maker_2, peer)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker_1, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker_2, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker_1, maker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker_2, maker_root)
trade_manager_maker = wsm_maker.trade_manager
trade_manager_taker = wsm_taker.trade_manager
maker_addition = bytes32([101] * 32)
taker_addition = bytes32([202] * 32)
MAKER_ROWS.append(maker_addition)
TAKER_ROWS.append(taker_addition)
maker_root, maker_proofs = build_merkle_tree(MAKER_ROWS)
taker_root, taker_proofs = build_merkle_tree(TAKER_ROWS)
maker_branch, maker_branch_proof = get_parent_branch(maker_addition, maker_proofs[maker_addition])
taker_branch, taker_branch_proof = get_parent_branch(taker_addition, taker_proofs[taker_addition])
success, offer_maker, error = await trade_manager_maker.create_offer_for_ids(
{launcher_id_maker_1: -1, launcher_id_taker_1: 1, launcher_id_maker_2: -1, launcher_id_taker_2: 1},
solver=Solver(
{
launcher_id_maker_1.hex(): {
"new_root": "0x" + maker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_taker_1.hex(),
"values_to_prove": ["0x" + taker_branch.hex(), "0x" + taker_branch.hex()],
}
],
},
launcher_id_maker_2.hex(): {
"new_root": "0x" + maker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_taker_1.hex(),
"values_to_prove": ["0x" + taker_branch.hex()],
},
{
"launcher_id": "0x" + launcher_id_taker_2.hex(),
"values_to_prove": ["0x" + taker_branch.hex()],
},
],
},
}
),
fee=uint64(2000000000000),
)
assert error is None
assert success is True
assert offer_maker is not None
success, offer_taker, error = await trade_manager_taker.respond_to_offer(
Offer.from_bytes(offer_maker.offer),
peer,
solver=Solver(
{
launcher_id_taker_1.hex(): {
"new_root": "0x" + taker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_maker_1.hex(),
"values_to_prove": ["0x" + maker_branch.hex(), "0x" + maker_branch.hex()],
}
],
},
launcher_id_taker_2.hex(): {
"new_root": "0x" + taker_root.hex(),
"dependencies": [
{
"launcher_id": "0x" + launcher_id_maker_1.hex(),
"values_to_prove": ["0x" + maker_branch.hex()],
},
{
"launcher_id": "0x" + launcher_id_maker_2.hex(),
"values_to_prove": ["0x" + maker_branch.hex()],
},
],
},
"proofs_of_inclusion": [
[
maker_root.hex(),
str(maker_branch_proof[0]),
["0x" + sibling.hex() for sibling in maker_branch_proof[1]],
],
[
taker_root.hex(),
str(taker_branch_proof[0]),
["0x" + sibling.hex() for sibling in taker_branch_proof[1]],
],
],
}
),
fee=uint64(2000000000000),
)
assert error is None
assert success is True
assert offer_taker is not None
await time_out_assert(15, wallet_maker.get_unconfirmed_balance, funds - 4000000000000)
await time_out_assert(15, wallet_taker.get_unconfirmed_balance, funds - 6000000000000)
# Let's hack a way to await this offer's confirmation
offer_record = dataclasses.replace(dl_record, spend_bundle=Offer.from_bytes(offer_taker.offer).bundle)
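    # Same TransactionRecord-shell trick as in test_dl_offers above.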
await full_node_api.process_transaction_records(records=[offer_record])
await time_out_assert(15, wallet_maker.get_confirmed_balance, funds - 6000000000000)
await time_out_assert(15, wallet_maker.get_unconfirmed_balance, funds - 6000000000000)
await time_out_assert(15, wallet_taker.get_confirmed_balance, funds - 6000000000000)
await time_out_assert(15, wallet_taker.get_unconfirmed_balance, funds - 6000000000000)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker_1, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_maker, launcher_id_taker_2, taker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker_1, maker_root)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_taker, launcher_id_maker_2, maker_root)
await time_out_assert(15, get_trade_and_status, TradeStatus.CONFIRMED, trade_manager_maker, offer_maker)
await time_out_assert(15, get_trade_and_status, TradeStatus.CONFIRMED, trade_manager_taker, offer_taker)

View File

@ -0,0 +1,595 @@
import asyncio
import dataclasses
from typing import Any, AsyncIterator, Iterator, List
import pytest
import pytest_asyncio
from chia.data_layer.data_layer_wallet import DataLayerWallet, Mirror
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.db_wallet.db_wallet_puzzles import create_mirror_puzzle
from chia.wallet.util.merkle_tree import MerkleTree
from tests.setup_nodes import SimulatorsAndWallets, setup_simulators_and_wallets
pytestmark = pytest.mark.data_layer
@pytest.fixture(scope="module")
def event_loop() -> Iterator[asyncio.AbstractEventLoop]:
loop = asyncio.get_event_loop()
yield loop
async def is_singleton_confirmed(dl_wallet: DataLayerWallet, lid: bytes32) -> bool:
rec = await dl_wallet.get_latest_singleton(lid)
if rec is None:
return False
if rec.confirmed is True:
assert rec.confirmed_at_height > 0
assert rec.timestamp > 0
return rec.confirmed
class TestDLWallet:
@pytest_asyncio.fixture(scope="function")
async def wallet_node(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(1, 1, {}):
yield _
@pytest_asyncio.fixture(scope="function")
async def two_wallet_nodes(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(1, 2, {}):
yield _
@pytest_asyncio.fixture(scope="function")
async def three_wallet_nodes(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(1, 3, {}):
yield _
@pytest_asyncio.fixture(scope="function")
async def two_wallet_nodes_five_freeze(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(1, 2, {}):
yield _
@pytest_asyncio.fixture(scope="function")
async def three_sim_two_wallets(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(3, 2, {}):
yield _
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_initial_creation(self, wallet_node: SimulatorsAndWallets, trusted: bool) -> None:
full_nodes, wallets, _ = wallet_node
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
wallet_node_0.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node_0.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
funds = await full_node_api.farm_blocks(count=2, wallet=wallet_0)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds)
async with wallet_node_0.wallet_state_manager.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(wallet_node_0.wallet_state_manager, wallet_0)
nodes = [Program.to("thing").get_tree_hash(), Program.to([8]).get_tree_hash()]
current_tree = MerkleTree(nodes)
current_root = current_tree.calculate_root()
for i in range(0, 2):
dl_record, std_record, launcher_id = await dl_wallet.generate_new_reporter(
current_root, fee=uint64(1999999999999)
)
assert await dl_wallet.get_latest_singleton(launcher_id) is not None
await wallet_node_0.wallet_state_manager.add_pending_transaction(dl_record)
await wallet_node_0.wallet_state_manager.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet, launcher_id)
await asyncio.sleep(0.5)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, 0)
await time_out_assert(10, wallet_0.get_confirmed_balance, 0)
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_get_owned_singletons(self, wallet_node: SimulatorsAndWallets, trusted: bool) -> None:
full_nodes, wallets, _ = wallet_node
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
wallet_node_0.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node_0.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
funds = await full_node_api.farm_blocks(count=2, wallet=wallet_0)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds)
async with wallet_node_0.wallet_state_manager.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(wallet_node_0.wallet_state_manager, wallet_0)
nodes = [Program.to("thing").get_tree_hash(), Program.to([8]).get_tree_hash()]
current_tree = MerkleTree(nodes)
current_root = current_tree.calculate_root()
expected_launcher_ids = set()
for i in range(0, 2):
dl_record, std_record, launcher_id = await dl_wallet.generate_new_reporter(
current_root, fee=uint64(1999999999999)
)
expected_launcher_ids.add(launcher_id)
assert await dl_wallet.get_latest_singleton(launcher_id) is not None
await wallet_node_0.wallet_state_manager.add_pending_transaction(dl_record)
await wallet_node_0.wallet_state_manager.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet, launcher_id)
await asyncio.sleep(0.5)
owned_singletons = await dl_wallet.get_owned_singletons()
owned_launcher_ids = sorted(singleton.launcher_id for singleton in owned_singletons)
assert owned_launcher_ids == sorted(expected_launcher_ids)
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_tracking_non_owned(self, two_wallet_nodes: SimulatorsAndWallets, trusted: bool) -> None:
full_nodes, wallets, _ = two_wallet_nodes
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_node_1, server_1 = wallets[1]
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
if trusted:
wallet_node_0.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
wallet_node_1.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node_0.config["trusted_peers"] = {}
wallet_node_1.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
await server_1.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
funds = await full_node_api.farm_blocks(count=2, wallet=wallet_0)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds)
async with wallet_node_0.wallet_state_manager.lock:
dl_wallet_0 = await DataLayerWallet.create_new_dl_wallet(wallet_node_0.wallet_state_manager, wallet_0)
async with wallet_node_1.wallet_state_manager.lock:
dl_wallet_1 = await DataLayerWallet.create_new_dl_wallet(wallet_node_1.wallet_state_manager, wallet_1)
nodes = [Program.to("thing").get_tree_hash(), Program.to([8]).get_tree_hash()]
current_tree = MerkleTree(nodes)
current_root = current_tree.calculate_root()
dl_record, std_record, launcher_id = await dl_wallet_0.generate_new_reporter(current_root)
assert await dl_wallet_0.get_latest_singleton(launcher_id) is not None
await wallet_node_0.wallet_state_manager.add_pending_transaction(dl_record)
await wallet_node_0.wallet_state_manager.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_0, launcher_id)
await asyncio.sleep(0.5)
peer = wallet_node_1.get_full_node_peer()
assert peer is not None
await dl_wallet_1.track_new_launcher_id(launcher_id, peer)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_1, launcher_id)
await asyncio.sleep(0.5)
for i in range(0, 5):
new_root = MerkleTree([Program.to("root").get_tree_hash()]).calculate_root()
txs = await dl_wallet_0.create_update_state_spend(launcher_id, new_root)
for tx in txs:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await full_node_api.process_transaction_records(records=txs)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_0, launcher_id)
await asyncio.sleep(0.5)
async def do_tips_match() -> bool:
latest_singleton_0 = await dl_wallet_0.get_latest_singleton(launcher_id)
latest_singleton_1 = await dl_wallet_1.get_latest_singleton(launcher_id)
return latest_singleton_0 == latest_singleton_1
await time_out_assert(15, do_tips_match, True)
await dl_wallet_1.stop_tracking_singleton(launcher_id)
assert await dl_wallet_1.get_latest_singleton(launcher_id) is None
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_lifecycle(self, wallet_node: SimulatorsAndWallets, trusted: bool) -> None:
full_nodes, wallets, _ = wallet_node
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
wallet_node_0.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node_0.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
funds = await full_node_api.farm_blocks(count=5, wallet=wallet_0)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds)
async with wallet_node_0.wallet_state_manager.lock:
dl_wallet = await DataLayerWallet.create_new_dl_wallet(wallet_node_0.wallet_state_manager, wallet_0)
nodes = [Program.to("thing").get_tree_hash(), Program.to([8]).get_tree_hash()]
current_tree = MerkleTree(nodes)
current_root = current_tree.calculate_root()
dl_record, std_record, launcher_id = await dl_wallet.generate_new_reporter(current_root)
assert await dl_wallet.get_latest_singleton(launcher_id) is not None
await wallet_node_0.wallet_state_manager.add_pending_transaction(dl_record)
await wallet_node_0.wallet_state_manager.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet, launcher_id)
await asyncio.sleep(0.5)
previous_record = await dl_wallet.get_latest_singleton(launcher_id)
assert previous_record is not None
assert previous_record.lineage_proof.amount is not None
new_root = MerkleTree([Program.to("root").get_tree_hash()]).calculate_root()
txs = await dl_wallet.generate_signed_transaction(
[previous_record.lineage_proof.amount],
[previous_record.inner_puzzle_hash],
launcher_id=previous_record.launcher_id,
new_root_hash=new_root,
fee=uint64(1999999999999),
)
assert txs[0].spend_bundle is not None
with pytest.raises(ValueError, match="is currently pending"):
await dl_wallet.generate_signed_transaction(
[previous_record.lineage_proof.amount],
[previous_record.inner_puzzle_hash],
coins=set([txs[0].spend_bundle.removals()[0]]),
fee=uint64(1999999999999),
)
new_record = await dl_wallet.get_latest_singleton(launcher_id)
assert new_record is not None
assert new_record != previous_record
assert not new_record.confirmed
for tx in txs:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await full_node_api.process_transaction_records(records=txs)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet, launcher_id)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds - 2000000000000)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds - 2000000000000)
await asyncio.sleep(0.5)
previous_record = await dl_wallet.get_latest_singleton(launcher_id)
new_root = MerkleTree([Program.to("new root").get_tree_hash()]).calculate_root()
txs = await dl_wallet.create_update_state_spend(launcher_id, new_root)
new_record = await dl_wallet.get_latest_singleton(launcher_id)
assert new_record is not None
assert new_record != previous_record
assert not new_record.confirmed
for tx in txs:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await full_node_api.process_transaction_records(records=txs)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet, launcher_id)
await asyncio.sleep(0.5)
@pytest.mark.skip(reason="maybe no longer relevant, needs to be rewritten at least")
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_rebase(self, two_wallet_nodes: SimulatorsAndWallets, trusted: bool) -> None:
full_nodes, wallets, _ = two_wallet_nodes
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_node_1, server_1 = wallets[1]
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
if trusted:
wallet_node_0.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
wallet_node_1.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node_0.config["trusted_peers"] = {}
wallet_node_1.config["trusted_peers"] = {}
await server_0.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
await server_1.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
funds = await full_node_api.farm_blocks(count=5, wallet=wallet_0)
await full_node_api.farm_blocks(count=5, wallet=wallet_1)
await time_out_assert(10, wallet_0.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_0.get_confirmed_balance, funds)
await time_out_assert(10, wallet_1.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_1.get_confirmed_balance, funds)
async with wallet_node_0.wallet_state_manager.lock:
dl_wallet_0 = await DataLayerWallet.create_new_dl_wallet(wallet_node_0.wallet_state_manager, wallet_0)
async with wallet_node_1.wallet_state_manager.lock:
dl_wallet_1 = await DataLayerWallet.create_new_dl_wallet(wallet_node_1.wallet_state_manager, wallet_1)
nodes = [Program.to("thing").get_tree_hash(), Program.to([8]).get_tree_hash()]
current_tree = MerkleTree(nodes)
current_root = current_tree.calculate_root()
async def is_singleton_confirmed(wallet: DataLayerWallet, lid: bytes32) -> bool:
latest_singleton = await wallet.get_latest_singleton(lid)
if latest_singleton is None:
return False
return latest_singleton.confirmed
dl_record, std_record, launcher_id = await dl_wallet_0.generate_new_reporter(current_root)
initial_record = await dl_wallet_0.get_latest_singleton(launcher_id)
assert initial_record is not None
await wallet_node_0.wallet_state_manager.add_pending_transaction(dl_record)
await wallet_node_0.wallet_state_manager.add_pending_transaction(std_record)
await asyncio.wait_for(full_node_api.process_transaction_records(records=[dl_record, std_record]), timeout=15)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_0, launcher_id)
await asyncio.sleep(0.5)
peer = wallet_node_1.get_full_node_peer()
assert peer is not None
await dl_wallet_1.track_new_launcher_id(launcher_id, peer)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_1, launcher_id)
current_record = await dl_wallet_1.get_latest_singleton(launcher_id)
assert current_record is not None
await asyncio.sleep(0.5)
# Because these have the same fee, the one that gets pushed first will win
report_txs = await dl_wallet_1.create_update_state_spend(
launcher_id, current_record.root, fee=uint64(2000000000000)
)
record_1 = await dl_wallet_1.get_latest_singleton(launcher_id)
assert record_1 is not None
assert current_record != record_1
update_txs = await dl_wallet_0.create_update_state_spend(
launcher_id, bytes32([0] * 32), fee=uint64(2000000000000)
)
record_0 = await dl_wallet_0.get_latest_singleton(launcher_id)
assert record_0 is not None
assert initial_record != record_0
assert record_0 != record_1
for tx in report_txs:
await wallet_node_1.wallet_state_manager.add_pending_transaction(tx)
await asyncio.wait_for(full_node_api.wait_transaction_records_entered_mempool(records=report_txs), timeout=15)
for tx in update_txs:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await asyncio.wait_for(full_node_api.process_transaction_records(records=report_txs), timeout=15)
funds -= 2000000000001
async def is_singleton_generation(wallet: DataLayerWallet, launcher_id: bytes32, generation: int) -> bool:
latest = await wallet.get_latest_singleton(launcher_id)
if latest is not None and latest.generation == generation:
return True
return False
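        # Two generations ahead: the winning report spend adds one, and wallet 0's
        # update is expected to be rebased on top of it for the second (an
        # assumption consistent with the same-fee comment above).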
next_generation = current_record.generation + 2
await time_out_assert(15, is_singleton_generation, True, dl_wallet_0, launcher_id, next_generation)
for i in range(0, 2):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32(32 * b"0")))
await asyncio.sleep(0.5)
await time_out_assert(15, is_singleton_confirmed, True, dl_wallet_0, launcher_id)
await time_out_assert(15, is_singleton_generation, True, dl_wallet_1, launcher_id, next_generation)
latest = await dl_wallet_0.get_latest_singleton(launcher_id)
assert latest is not None
assert latest == (await dl_wallet_1.get_latest_singleton(launcher_id))
await time_out_assert(15, wallet_0.get_confirmed_balance, funds)
await time_out_assert(15, wallet_0.get_unconfirmed_balance, funds)
assert (
len(
await dl_wallet_0.get_history(
launcher_id, min_generation=uint32(next_generation - 1), max_generation=uint32(next_generation - 1)
)
)
== 1
)
for tx in update_txs:
assert await wallet_node_0.wallet_state_manager.tx_store.get_transaction_record(tx.name) is None
assert await dl_wallet_0.get_singleton_record(record_0.coin_id) is None
update_txs_1 = await dl_wallet_0.create_update_state_spend(
launcher_id, bytes32([1] * 32), fee=uint64(2000000000000)
)
record_1 = await dl_wallet_0.get_latest_singleton(launcher_id)
assert record_1 is not None
for tx in update_txs_1:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await full_node_api.wait_transaction_records_entered_mempool(update_txs_1)
# Delete any trace of that update
await wallet_node_0.wallet_state_manager.dl_store.delete_singleton_record(record_1.coin_id)
for tx in update_txs_1:
await wallet_node_0.wallet_state_manager.tx_store.delete_transaction_record(tx.name)
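        # With no local trace of that update, the spend below is built as though
        # it never happened; the deleted update still confirms first, so the
        # wallet has to re-discover state it no longer knows about (see the
        # root check for bytes32([1] * 32) further down).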
update_txs_0 = await dl_wallet_0.create_update_state_spend(launcher_id, bytes32([2] * 32))
record_0 = await dl_wallet_0.get_latest_singleton(launcher_id)
assert record_0 is not None
assert record_0 != record_1
for tx in update_txs_0:
await wallet_node_0.wallet_state_manager.add_pending_transaction(tx)
await asyncio.wait_for(full_node_api.process_transaction_records(records=update_txs_1), timeout=15)
async def does_singleton_have_root(wallet: DataLayerWallet, lid: bytes32, root: bytes32) -> bool:
latest_singleton = await wallet.get_latest_singleton(lid)
if latest_singleton is None:
return False
return latest_singleton.root == root
funds -= 2000000000000
next_generation += 1
await time_out_assert(15, is_singleton_generation, True, dl_wallet_0, launcher_id, next_generation)
await time_out_assert(15, does_singleton_have_root, True, dl_wallet_0, launcher_id, bytes32([1] * 32))
await time_out_assert(15, wallet_0.get_confirmed_balance, funds)
await time_out_assert(15, wallet_0.get_unconfirmed_balance, funds)
assert (
len(
await dl_wallet_0.get_history(
launcher_id, min_generation=uint32(next_generation), max_generation=uint32(next_generation)
)
)
== 1
)
for tx in update_txs_0:
assert await wallet_node_0.wallet_state_manager.tx_store.get_transaction_record(tx.name) is None
assert await dl_wallet_0.get_singleton_record(record_0.coin_id) is None
async def is_singleton_confirmed_and_root(dl_wallet: DataLayerWallet, lid: bytes32, root: bytes32) -> bool:
rec = await dl_wallet.get_latest_singleton(lid)
if rec is None:
return False
if rec.confirmed is True:
assert rec.confirmed_at_height > 0
assert rec.timestamp > 0
return rec.confirmed and rec.root == root
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_mirrors(wallets_prefarm: Any, trusted: bool) -> None:
wallet_node_1, wallet_node_2, full_node_api = wallets_prefarm
assert wallet_node_1.wallet_state_manager is not None
assert wallet_node_2.wallet_state_manager is not None
wsm_1 = wallet_node_1.wallet_state_manager
wsm_2 = wallet_node_2.wallet_state_manager
wallet_1 = wsm_1.main_wallet
wallet_2 = wsm_2.main_wallet
funds = 20000000000000
await time_out_assert(10, wallet_1.get_unconfirmed_balance, funds)
await time_out_assert(10, wallet_2.get_confirmed_balance, funds)
async with wsm_1.lock:
dl_wallet_1 = await DataLayerWallet.create_new_dl_wallet(wsm_1, wallet_1)
async with wsm_2.lock:
dl_wallet_2 = await DataLayerWallet.create_new_dl_wallet(wsm_2, wallet_2)
dl_record, std_record, launcher_id_1 = await dl_wallet_1.generate_new_reporter(bytes32([0] * 32))
assert await dl_wallet_1.get_latest_singleton(launcher_id_1) is not None
await wsm_1.add_pending_transaction(dl_record)
await wsm_1.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_1, launcher_id_1, bytes32([0] * 32))
dl_record, std_record, launcher_id_2 = await dl_wallet_2.generate_new_reporter(bytes32([0] * 32))
assert await dl_wallet_2.get_latest_singleton(launcher_id_2) is not None
await wsm_2.add_pending_transaction(dl_record)
await wsm_2.add_pending_transaction(std_record)
await full_node_api.process_transaction_records(records=[dl_record, std_record])
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_2, launcher_id_2, bytes32([0] * 32))
peer_1 = wallet_node_1.get_full_node_peer()
assert peer_1 is not None
await dl_wallet_1.track_new_launcher_id(launcher_id_2, peer_1)
peer_2 = wallet_node_2.get_full_node_peer()
assert peer_2 is not None
await dl_wallet_2.track_new_launcher_id(launcher_id_1, peer_2)
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_1, launcher_id_2, bytes32([0] * 32))
await time_out_assert(15, is_singleton_confirmed_and_root, True, dl_wallet_2, launcher_id_1, bytes32([0] * 32))
txs = await dl_wallet_1.create_new_mirror(launcher_id_2, uint64(3), [b"foo", b"bar"], fee=uint64(1999999999999))
additions: List[Coin] = []
for tx in txs:
if tx.spend_bundle is not None:
additions.extend(tx.spend_bundle.additions())
await wsm_1.add_pending_transaction(tx)
await full_node_api.process_transaction_records(records=txs)
mirror_coin: Coin = [c for c in additions if c.puzzle_hash == create_mirror_puzzle().get_tree_hash()][0]
mirror = Mirror(
bytes32(mirror_coin.name()), bytes32(launcher_id_2), uint64(mirror_coin.amount), [b"foo", b"bar"], True
)
await time_out_assert(15, dl_wallet_1.get_mirrors_for_launcher, [mirror], launcher_id_2)
await time_out_assert(
15, dl_wallet_2.get_mirrors_for_launcher, [dataclasses.replace(mirror, ours=False)], launcher_id_2
)
txs = await dl_wallet_1.delete_mirror(mirror.coin_id, peer_1, fee=uint64(2000000000000))
for tx in txs:
await wsm_1.add_pending_transaction(tx)
await full_node_api.process_transaction_records(records=txs)
await time_out_assert(15, dl_wallet_1.get_mirrors_for_launcher, [], launcher_id_2)
await time_out_assert(15, dl_wallet_2.get_mirrors_for_launcher, [], launcher_id_2)

View File

@ -0,0 +1,267 @@
import asyncio
import logging
from typing import AsyncIterator
import pytest
import pytest_asyncio
from chia.consensus.block_rewards import calculate_base_farmer_reward, calculate_pool_reward
from chia.data_layer.data_layer_wallet import Mirror, SingletonRecord
from chia.rpc.full_node_rpc_api import FullNodeRpcApi
from chia.rpc.rpc_server import start_rpc_server
from chia.rpc.wallet_rpc_api import WalletRpcApi
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.simulator.time_out_assert import time_out_assert
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.peer_info import PeerInfo
from chia.util.ints import uint16, uint32, uint64
from chia.wallet.db_wallet.db_wallet_puzzles import create_mirror_puzzle
from tests.setup_nodes import SimulatorsAndWallets, setup_simulators_and_wallets
from tests.util.rpc import validate_get_routes
log = logging.getLogger(__name__)
class TestWalletRpc:
@pytest_asyncio.fixture(scope="function")
async def two_wallet_nodes(self) -> AsyncIterator[SimulatorsAndWallets]:
async for _ in setup_simulators_and_wallets(1, 2, {}):
yield _
@pytest.mark.parametrize(
"trusted",
[True, False],
)
@pytest.mark.asyncio
async def test_wallet_make_transaction(
self, two_wallet_nodes: SimulatorsAndWallets, trusted: bool, self_hostname: str
) -> None:
num_blocks = 5
full_nodes, wallets, bt = two_wallet_nodes
full_node_api = full_nodes[0]
full_node_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
wallet_node_2, server_3 = wallets[1]
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
if trusted:
wallet_node.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
wallet_node_2.config["trusted_peers"] = {full_node_server.node_id.hex(): full_node_server.node_id.hex()}
else:
wallet_node.config["trusted_peers"] = {}
wallet_node_2.config["trusted_peers"] = {}
await server_2.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
await server_3.start_client(PeerInfo("localhost", uint16(full_node_server._port)), None)
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
initial_funds = sum(
[calculate_pool_reward(uint32(i)) + calculate_base_farmer_reward(uint32(i)) for i in range(1, num_blocks)]
)
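        # Rewards are summed for blocks 1..num_blocks-1, mirroring the usual
        # pattern in these tests where the newest block's rewards are not yet
        # counted toward the wallet balance (assumption).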
wallet_rpc_api = WalletRpcApi(wallet_node)
wallet_rpc_api_2 = WalletRpcApi(wallet_node_2)
config = bt.config
hostname = config["self_hostname"]
daemon_port = config["daemon_port"]
def stop_node_cb() -> None:
pass
full_node_rpc_api = FullNodeRpcApi(full_node_api.full_node)
rpc_cleanup_node, node_rpc_port = await start_rpc_server(
full_node_rpc_api,
hostname,
daemon_port,
uint16(0),
stop_node_cb,
bt.root_path,
config,
connect_to_daemon=False,
)
rpc_cleanup_wallet, wallet_1_rpc_port = await start_rpc_server(
wallet_rpc_api,
hostname,
daemon_port,
uint16(0),
stop_node_cb,
bt.root_path,
config,
connect_to_daemon=False,
)
rpc_cleanup_wallet_2, wallet_2_rpc_port = await start_rpc_server(
wallet_rpc_api_2,
hostname,
daemon_port,
uint16(0),
stop_node_cb,
bt.root_path,
config,
connect_to_daemon=False,
)
await time_out_assert(15, wallet.get_confirmed_balance, initial_funds)
await time_out_assert(15, wallet.get_unconfirmed_balance, initial_funds)
client = await WalletRpcClient.create(self_hostname, wallet_1_rpc_port, bt.root_path, config)
await validate_get_routes(client, wallet_rpc_api)
client_2 = await WalletRpcClient.create(self_hostname, wallet_2_rpc_port, bt.root_path, config)
await validate_get_routes(client_2, wallet_rpc_api_2)
try:
merkle_root: bytes32 = bytes32([0] * 32)
txs, launcher_id = await client.create_new_dl(merkle_root, uint64(50))
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
async def is_singleton_confirmed(rpc_client: WalletRpcClient, lid: bytes32) -> bool:
rec = await rpc_client.dl_latest_singleton(lid)
if rec is None:
return False
return rec.confirmed
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id)
singleton_record: SingletonRecord = await client.dl_latest_singleton(launcher_id)
assert singleton_record.root == merkle_root
new_root: bytes32 = bytes32([1] * 32)
await client.dl_update_root(launcher_id, new_root, uint64(100))
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
new_singleton_record: SingletonRecord = await client.dl_latest_singleton(launcher_id)
assert new_singleton_record.root == new_root
assert new_singleton_record.confirmed
assert await client.dl_history(launcher_id) == [new_singleton_record, singleton_record]
await client_2.dl_track_new(launcher_id)
async def is_singleton_generation(rpc_client: WalletRpcClient, lid: bytes32, generation: int) -> bool:
if await is_singleton_confirmed(rpc_client, lid):
rec = await rpc_client.dl_latest_singleton(lid)
if rec is None:
                        raise Exception(f"No latest singleton for: {lid!r}")
return rec.generation == generation
else:
return False
await time_out_assert(15, is_singleton_generation, True, client_2, launcher_id, 1)
assert await client_2.dl_history(launcher_id) == [new_singleton_record, singleton_record]
assert await client.dl_history(launcher_id, min_generation=uint32(1)) == [new_singleton_record]
assert await client.dl_history(launcher_id, max_generation=uint32(0)) == [singleton_record]
assert await client.dl_history(launcher_id, num_results=uint32(1)) == [new_singleton_record]
assert await client.dl_history(launcher_id, num_results=uint32(2)) == [
new_singleton_record,
singleton_record,
]
assert (
await client.dl_history(
launcher_id,
min_generation=uint32(1),
max_generation=uint32(1),
)
== [new_singleton_record]
)
assert (
await client.dl_history(
launcher_id,
max_generation=uint32(0),
num_results=uint32(1),
)
== [singleton_record]
)
assert (
await client.dl_history(
launcher_id,
min_generation=uint32(1),
num_results=uint32(1),
)
== [new_singleton_record]
)
assert (
await client.dl_history(
launcher_id,
min_generation=uint32(1),
max_generation=uint32(1),
num_results=uint32(1),
)
== [new_singleton_record]
)
assert await client.dl_singletons_by_root(launcher_id, new_root) == [new_singleton_record]
txs, launcher_id_2 = await client.create_new_dl(merkle_root, uint64(50))
txs, launcher_id_3 = await client.create_new_dl(merkle_root, uint64(50))
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id_2)
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id_3)
next_root = bytes32([2] * 32)
await client.dl_update_multiple(
{
launcher_id: next_root,
launcher_id_2: next_root,
launcher_id_3: next_root,
}
)
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id)
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id_2)
await time_out_assert(15, is_singleton_confirmed, True, client, launcher_id_3)
for lid in [launcher_id, launcher_id_2, launcher_id_3]:
rec = await client.dl_latest_singleton(lid)
assert rec.root == next_root
await client_2.dl_stop_tracking(launcher_id)
            assert await client_2.dl_latest_singleton(launcher_id) is None
owned_singletons = await client.dl_owned_singletons()
owned_launcher_ids = sorted(singleton.launcher_id for singleton in owned_singletons)
assert owned_launcher_ids == sorted([launcher_id, launcher_id_2, launcher_id_3])
txs = await client.dl_new_mirror(launcher_id, uint64(1000), [b"foo", b"bar"], fee=uint64(2000000000000))
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
additions = []
for tx in txs:
if tx.spend_bundle is not None:
additions.extend(tx.spend_bundle.additions())
mirror_coin = [c for c in additions if c.puzzle_hash == create_mirror_puzzle().get_tree_hash()][0]
mirror = Mirror(mirror_coin.name(), launcher_id, uint64(1000), [b"foo", b"bar"], True)
await time_out_assert(15, client.dl_get_mirrors, [mirror], launcher_id)
await client.dl_delete_mirror(mirror_coin.name(), fee=uint64(2000000000000))
for i in range(0, 5):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(bytes32([0] * 32)))
await asyncio.sleep(0.5)
await time_out_assert(15, client.dl_get_mirrors, [], launcher_id)
finally:
            # Shut down the RPC client and the RPC servers
client.close()
await client.await_closed()
await rpc_cleanup_node()
await rpc_cleanup_wallet()
await rpc_cleanup_wallet_2()

View File

@ -671,7 +671,9 @@ async def test_offer_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment)
assert offer is not None
summary = await wallet_1_rpc.get_offer_summary(offer)
advanced_summary = await wallet_1_rpc.get_offer_summary(offer, advanced=True)
assert summary == {"offered": {"xch": 5}, "requested": {cat_asset_id.hex(): 1}, "infos": driver_dict, "fees": 1}
assert advanced_summary == summary
assert await wallet_1_rpc.check_offer_validity(offer)

View File

@ -148,8 +148,8 @@ class TestWalletSimulator:
]
)
await time_out_assert(20, wallet.get_confirmed_balance, new_funds - 10)
await time_out_assert(20, wallet.get_unconfirmed_balance, new_funds - 10)
await time_out_assert(30, wallet.get_confirmed_balance, new_funds - 10)
await time_out_assert(30, wallet.get_unconfirmed_balance, new_funds - 10)
@pytest.mark.parametrize(
"trusted",