Merge remote-tracking branch 'origin/long_lived/atari' into dl_to_nft_chialisp

This commit is contained in:
Matt Hauff 2022-07-28 15:44:03 -05:00
commit 6206a6a95b
No known key found for this signature in database
GPG Key ID: 3CBA6CFC81A00E46
71 changed files with 848 additions and 950 deletions

View File

@ -1,4 +1,4 @@
name: Benchmarks
name: ⚡️ Benchmarks
on:
push:

View File

@ -1,4 +1,4 @@
name: Build Installer - Linux DEB ARM64
name: 📦🚀 Build Installer - Linux DEB ARM64
on:
push:
@ -158,11 +158,6 @@ jobs:
sha256sum $GITHUB_WORKSPACE/build_scripts/final_installer/chia-blockchain-cli_${CHIA_INSTALLER_VERSION}-1_arm64.deb > $GITHUB_WORKSPACE/build_scripts/final_installer/chia-blockchain-cli_${CHIA_INSTALLER_VERSION}-1_arm64.deb.sha256
ls $GITHUB_WORKSPACE/build_scripts/final_installer/
- name: Install py3createtorrent
if: startsWith(github.ref, 'refs/tags/')
run: |
pip3 install py3createtorrent
- name: Create torrent
if: startsWith(github.ref, 'refs/tags/')
env:

View File

@ -1,4 +1,4 @@
name: Build Installer - Linux DEB AMD64
name: 📦🚀 Build Installer - Linux DEB AMD64
on:
workflow_dispatch:
@ -199,11 +199,6 @@ jobs:
sha256sum ${{ github.workspace }}/build_scripts/final_installer/chia-blockchain-cli_${CHIA_INSTALLER_VERSION}-1_amd64.deb > ${{ github.workspace }}/build_scripts/final_installer/chia-blockchain-cli_${CHIA_INSTALLER_VERSION}-1_amd64.deb.sha256
ls ${{ github.workspace }}/build_scripts/final_installer/
- name: Install py3createtorrent
if: startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main'
run: |
pip3 install py3createtorrent
- name: Create .deb torrent
env:
CHIA_INSTALLER_VERSION: ${{ steps.version_number.outputs.CHIA_INSTALLER_VERSION }}

View File

@ -1,4 +1,4 @@
name: Build Installer - Linux RPM AMD64
name: 📦🚀 Build Installer - Linux RPM AMD64
on:
workflow_dispatch:
@ -159,11 +159,6 @@ jobs:
sha256sum $GITHUB_WORKSPACE/build_scripts/final_installer/chia-blockchain-cli-${CHIA_INSTALLER_VERSION}-1.x86_64.rpm > $GITHUB_WORKSPACE/build_scripts/final_installer/chia-blockchain-cli-${CHIA_INSTALLER_VERSION}-1.x86_64.rpm.sha256
ls $GITHUB_WORKSPACE/build_scripts/final_installer/
- name: Install py3createtorrent
if: startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main'
run: |
pip3 install py3createtorrent
- name: Create .rpm torrent
if: startsWith(github.ref, 'refs/tags/')
env:

View File

@ -1,4 +1,4 @@
name: Build Installer - MacOS Intel
name: 📦🚀 Build Installer - MacOS Intel
on:
push:
@ -173,11 +173,6 @@ jobs:
echo "CHIA_DEV_BUILD=$CHIA_DEV_BUILD" >>$GITHUB_ENV
aws s3 cp ${{ github.workspace }}/build_scripts/final_installer/Chia-${{ steps.version_number.outputs.CHIA_INSTALLER_VERSION }}.dmg s3://download.chia.net/dev/Chia-${CHIA_DEV_BUILD}.dmg
- name: Install py3createtorrent
if: startsWith(github.ref, 'refs/tags/')
run: |
pip install py3createtorrent
- name: Create torrent
if: startsWith(github.ref, 'refs/tags/')
run: |

View File

@ -1,4 +1,4 @@
name: Build Installer - MacOS arm64
name: 📦🚀 Build Installer - MacOS arm64
on:
push:
@ -152,11 +152,6 @@ jobs:
echo "CHIA_DEV_BUILD=$CHIA_DEV_BUILD" >>$GITHUB_ENV
arch -arm64 aws s3 cp ${{ github.workspace }}/build_scripts/final_installer/Chia-${CHIA_INSTALLER_VERSION}-arm64.dmg s3://download.chia.net/dev/Chia-${CHIA_DEV_BUILD}-arm64.dmg
- name: Install py3createtorrent
if: startsWith(github.ref, 'refs/tags/')
run: |
arch -arm64 pip install py3createtorrent
- name: Create torrent
if: startsWith(github.ref, 'refs/tags/')
run: |

View File

@ -1,4 +1,4 @@
name: Build Installer - Windows 10
name: 📦🚀 Build Installer - Windows 10
on:
push:

View File

@ -1,4 +1,4 @@
name: Check Dependency Artifacts
name: 🚨 Check Dependency Artifacts
on:
push:

View File

@ -9,7 +9,7 @@
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
name: 🚨 CodeQL
on:
push:

View File

@ -1,4 +1,4 @@
name: "Conflict Check"
name: 🩹 Conflict Check
on:
# So that PRs touching the same files as the push are updated
push:

View File

@ -1,4 +1,4 @@
name: pre-commit
name: 🚨 pre-commit
on:
pull_request:
@ -15,6 +15,7 @@ concurrency:
jobs:
pre-commit:
name: pre-commit
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3

View File

@ -4,7 +4,7 @@
## Linter GitHub Actions ##
###########################
###########################
name: GitHub Super Linter
name: 🚨 GitHub Super Linter
#
# Documentation:
@ -64,9 +64,6 @@ jobs:
DEFAULT_BRANCH: main
LINTER_RULES_PATH: .
MARKDOWN_CONFIG_FILE: .markdown-lint.yml
# PYTHON_FLAKE8_CONFIG_FILE: .flake8
PYTHON_ISORT_CONFIG_FILE: .isort.cfg
PYTHON_PYLINT_CONFIG_FILE: pylintrc
VALIDATE_BASH: true
VALIDATE_CSS: true
VALIDATE_DOCKER: true
@ -76,9 +73,6 @@ jobs:
VALIDATE_JSON: true
VALIDATE_MD: true
VALIDATE_POWERSHELL: true
VALIDATE_PYTHON_PYLINT: true
# VALIDATE_PYTHON_FLAKE8: true
# VALIDATE_PYTHON_ISORT: true
VALIDATE_SHELL_SHFMT: true
VALIDATE_TYPESCRIPT_ES: true
VALIDATE_YAML: true

View File

@ -1,4 +1,4 @@
name: Test Install Scripts
name: 🏗️ Test Install Scripts
on:
push:
@ -107,14 +107,6 @@ jobs:
type: ubuntu
# https://packages.ubuntu.com/focal/python3 (20.04, 3.8)
url: "docker://ubuntu:focal"
- name: ubuntu:hirsute (21.04)
type: ubuntu
# https://packages.ubuntu.com/hirsute/python3 (21.04, 3.9)
url: "docker://ubuntu:hirsute"
- name: ubuntu:impish (21.10)
type: ubuntu
# https://packages.ubuntu.com/impish/python3 (21.10, 3.9)
url: "docker://ubuntu:impish"
- name: ubuntu:jammy (22.04)
type: ubuntu
# https://packages.ubuntu.com/jammy/python3 (22.04, 3.10)

View File

@ -1,4 +1,4 @@
name: test
name: 🧪 test
on:
push:

View File

@ -1,4 +1,4 @@
name: Trigger Dev Docker Build
name: 📦🚀 Trigger Dev Docker Build
on:
push:

View File

@ -1,4 +1,4 @@
name: Trigger Main Docker Build
name: 📦🚀 Trigger Main Docker Build
on:
push:

View File

@ -1,4 +1,4 @@
name: Lint and upload source distribution
name: 🚨🚀 Lint and upload source distribution
on:
push:
@ -62,6 +62,10 @@ jobs:
run: |
mypy
- name: Lint source with pylint
run: |
pylint benchmarks build_scripts chia tests tools *.py
- name: Build source distribution
run: |
python -m build --sdist --outdir dist .

View File

@ -53,7 +53,7 @@ to configure how the tests are run. For example, for more logging: change the lo
```bash
sh install.sh -d
. ./activate
black . && isort benchmarks build_scripts chia tests tools *.py && mypy && flake8 benchmarks build_scripts chia tests tools *.py
black . && isort benchmarks build_scripts chia tests tools *.py && mypy && flake8 benchmarks build_scripts chia tests tools *.py && pylint benchmarks build_scripts chia tests tools *.py
py.test tests -v --durations 0
```
@ -61,6 +61,7 @@ The [black library](https://black.readthedocs.io/en/stable/) is used as an autom
The [flake8 library](https://readthedocs.org/projects/flake8/) helps ensure consistent style.
The [Mypy library](https://mypy.readthedocs.io/en/stable/) is very useful for ensuring objects are of the correct type, so try to always add the type of the return value, and the type of local variables.
The [isort library](https://isort.readthedocs.io) is used to sort, group and validate imports in all python files.
The [pylint library](https://pylint.pycqa.org/en/stable/) is used to further lint all python files.
If you want verbose logging for tests, edit the `tests/pytest.ini` file.

View File

@ -6216,9 +6216,9 @@
}
},
"node_modules/plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"dependencies": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"
@ -12928,9 +12928,9 @@
}
},
"plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"requires": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"

View File

@ -7311,9 +7311,9 @@
}
},
"node_modules/plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"dependencies": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"
@ -14931,9 +14931,9 @@
}
},
"plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"requires": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"

View File

@ -7279,9 +7279,9 @@
}
},
"node_modules/plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"dependencies": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"
@ -14859,9 +14859,9 @@
}
},
"plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"requires": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"

View File

@ -5996,9 +5996,9 @@
}
},
"node_modules/plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"dependencies": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"
@ -12495,9 +12495,9 @@
}
},
"plist": {
"version": "3.0.4",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.4.tgz",
"integrity": "sha512-ksrr8y9+nXOxQB2osVNqrgvX/XQPOXaU4BQMKjYq8PvaY1U18mo+fKgBSwzK+luSyinOuPae956lSVcBwxlAMg==",
"version": "3.0.5",
"resolved": "https://registry.npmjs.org/plist/-/plist-3.0.5.tgz",
"integrity": "sha512-83vX4eYdQp3vP9SxuYgEM/G/pJQqLUz/V/xzPrzruLs7fz7jxGQ1msZ/mg1nwZxUSuOp4sb+/bEIbRrbzZRxDA==",
"requires": {
"base64-js": "^1.5.1",
"xmlbuilder": "^9.0.7"

View File

@ -1,12 +1,14 @@
import aiosqlite
import random
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any
from typing import Optional, List, Dict, Tuple, Any, Type, TypeVar
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program, SerializedProgram
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint64, uint32
from chia.util.hash import std_hash
from chia.util.errors import Err, ValidationError
@ -45,6 +47,9 @@ class SimFullBlock(Streamable):
height: uint32 # Note that height is not on a regular FullBlock
_T_SimBlockRecord = TypeVar("_T_SimBlockRecord", bound="SimBlockRecord")
@streamable
@dataclass(frozen=True)
class SimBlockRecord(Streamable):
@ -57,7 +62,7 @@ class SimBlockRecord(Streamable):
prev_transaction_block_hash: bytes32
@classmethod
def create(cls, rci: List[Coin], height: uint32, timestamp: uint64):
def create(cls: Type[_T_SimBlockRecord], rci: List[Coin], height: uint32, timestamp: uint64) -> _T_SimBlockRecord:
return cls(
rci,
height,
@ -78,6 +83,9 @@ class SimStore(Streamable):
blocks: List[SimFullBlock]
_T_SpendSim = TypeVar("_T_SpendSim", bound="SpendSim")
class SpendSim:
db_wrapper: DBWrapper2
@ -89,7 +97,9 @@ class SpendSim:
defaults: ConsensusConstants
@classmethod
async def create(cls, db_path=None, defaults=DEFAULT_CONSTANTS):
async def create(
cls: Type[_T_SpendSim], db_path: Optional[Path] = None, defaults: ConsensusConstants = DEFAULT_CONSTANTS
) -> _T_SpendSim:
self = cls()
if db_path is None:
uri = f"file:db_{random.randint(0, 99999999)}?mode=memory&cache=shared"
@ -114,15 +124,16 @@ class SpendSim:
self.block_height = store_data.block_height
self.block_records = store_data.block_records
self.blocks = store_data.blocks
self.mempool_manager.peak = self.block_records[-1]
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
self.mempool_manager.peak = self.block_records[-1] # type: ignore[assignment]
else:
self.timestamp = 1
self.block_height = 0
self.timestamp = uint64(1)
self.block_height = uint32(0)
self.block_records = []
self.blocks = []
return self
async def close(self):
async def close(self) -> None:
async with self.db_wrapper.write_db() as conn:
c = await conn.execute("DELETE FROM block_data")
await c.close()
@ -133,10 +144,11 @@ class SpendSim:
await c.close()
await self.db_wrapper.close()
async def new_peak(self):
await self.mempool_manager.new_peak(self.block_records[-1], None)
async def new_peak(self) -> None:
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
await self.mempool_manager.new_peak(self.block_records[-1], None) # type: ignore[arg-type]
def new_coin_record(self, coin: Coin, coinbase=False) -> CoinRecord:
def new_coin_record(self, coin: Coin, coinbase: bool = False) -> CoinRecord:
return CoinRecord(
coin,
uint32(self.block_height + 1),
@ -164,7 +176,7 @@ class SpendSim:
return None
return simple_solution_generator(bundle)
async def farm_block(self, puzzle_hash: bytes32 = bytes32(b"0" * 32)):
async def farm_block(self, puzzle_hash: bytes32 = bytes32(b"0" * 32)) -> Tuple[List[Coin], List[Coin]]:
# Fees get calculated
fees = uint64(0)
if self.mempool_manager.mempool.spends:
@ -234,13 +246,13 @@ class SpendSim:
def get_height(self) -> uint32:
return self.block_height
def pass_time(self, time: uint64):
def pass_time(self, time: uint64) -> None:
self.timestamp = uint64(self.timestamp + time)
def pass_blocks(self, blocks: uint32):
def pass_blocks(self, blocks: uint32) -> None:
self.block_height = uint32(self.block_height + blocks)
async def rewind(self, block_height: uint32):
async def rewind(self, block_height: uint32) -> None:
new_br_list = list(filter(lambda br: br.height <= block_height, self.block_records))
new_block_list = list(filter(lambda block: block.height <= block_height, self.blocks))
self.block_records = new_br_list
@ -255,7 +267,7 @@ class SpendSim:
class SimClient:
def __init__(self, service):
def __init__(self, service: SpendSim) -> None:
self.service = service
async def push_tx(self, spend_bundle: SpendBundle) -> Tuple[MempoolInclusionStatus, Optional[Err]]:
@ -270,7 +282,7 @@ class SimClient:
)
return status, error
async def get_coin_record_by_name(self, name: bytes32) -> CoinRecord:
async def get_coin_record_by_name(self, name: bytes32) -> Optional[CoinRecord]:
return await self.service.mempool_manager.coin_store.get_coin_record(name)
async def get_coin_records_by_names(
@ -363,8 +375,13 @@ class SimClient:
return additions, removals
async def get_puzzle_and_solution(self, coin_id: bytes32, height: uint32) -> Optional[CoinSpend]:
generator = list(filter(lambda block: block.height == height, self.service.blocks))[0].transactions_generator
coin_record = await self.service.mempool_manager.coin_store.get_coin_record(coin_id)
filtered_generators = list(filter(lambda block: block.height == height, self.service.blocks))
# real consideration should be made for the None cases instead of just hint ignoring
generator: BlockGenerator = filtered_generators[0].transactions_generator # type: ignore[assignment]
coin_record: CoinRecord
coin_record = await self.service.mempool_manager.coin_store.get_coin_record( # type: ignore[assignment]
coin_id,
)
error, puzzle, solution = get_puzzle_and_solution_for_coin(
generator,
coin_id,
@ -380,13 +397,13 @@ class SimClient:
async def get_all_mempool_tx_ids(self) -> List[bytes32]:
return list(self.service.mempool_manager.mempool.spends.keys())
async def get_all_mempool_items(self) -> Dict[bytes32, Dict]:
async def get_all_mempool_items(self) -> Dict[bytes32, MempoolItem]:
spends = {}
for tx_id, item in self.service.mempool_manager.mempool.spends.items():
spends[tx_id] = item
return spends
async def get_mempool_item_by_tx_id(self, tx_id: bytes32) -> Optional[Dict]:
async def get_mempool_item_by_tx_id(self, tx_id: bytes32) -> Optional[Dict[str, Any]]:
item = self.service.mempool_manager.get_mempool_item(tx_id)
if item is None:
return None

View File

@ -519,9 +519,13 @@ def chia_init(
db_path_replaced = config["database_path"].replace("CHALLENGE", config["selected_network"])
db_path = path_from_root(root_path, db_path_replaced)
db_path.parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(db_path) as connection:
set_db_version(connection, 2)
try:
# create new v2 db file
with sqlite3.connect(db_path) as connection:
set_db_version(connection, 2)
except sqlite3.OperationalError:
# db already exists, so we're good
pass
print("")
print("To see your keys, run 'chia keys show --show-mnemonic-seed'")

View File

@ -313,6 +313,7 @@ async def make_offer(args: dict, wallet_client: WalletRpcClient, fingerprint: in
},
}
if info.supports_did:
assert info.royalty_puzzle_hash is not None
driver_dict[id]["also"]["also"] = {
"type": "ownership",
"owner": "()",
@ -347,13 +348,13 @@ async def make_offer(args: dict, wallet_client: WalletRpcClient, fingerprint: in
print("--------------")
print()
print("OFFERING:")
for name, info in printable_dict.items():
amount, unit, multiplier = info
for name, data in printable_dict.items():
amount, unit, multiplier = data
if multiplier < 0:
print(f" - {amount} {name} ({int(Decimal(amount) * unit)} mojos)")
print("REQUESTING:")
for name, info in printable_dict.items():
amount, unit, multiplier = info
for name, data in printable_dict.items():
amount, unit, multiplier = data
if multiplier > 0:
print(f" - {amount} {name} ({int(Decimal(amount) * unit)} mojos)")

View File

@ -118,15 +118,12 @@ class DataLayerWallet:
raise ValueError("DataLayer Wallet already exists for this key")
assert name is not None
maybe_wallet_info = await wallet_state_manager.user_store.create_wallet(
self.wallet_info = await wallet_state_manager.user_store.create_wallet(
name,
WalletType.DATA_LAYER.value,
"",
in_transaction=in_transaction,
)
if maybe_wallet_info is None:
raise ValueError("Internal Error")
self.wallet_info = maybe_wallet_info
self.wallet_id = uint8(self.wallet_info.id)
await self.wallet_state_manager.add_new_wallet(self, self.wallet_info.id, in_transaction=in_transaction)

View File

@ -134,7 +134,11 @@ class Farmer:
async def setup_keys(self) -> bool:
no_keys_error_str = "No keys exist. Please run 'chia keys generate' or open the UI."
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
try:
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
except KeychainProxyConnectionFailure:
return False
self._private_keys = [master_sk_to_farmer_sk(sk) for sk in self.all_root_sks] + [
master_sk_to_pool_sk(sk) for sk in self.all_root_sks
]

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import contextlib
import dataclasses

View File

@ -1651,7 +1651,7 @@ def blue_boxed_end_of_slot(sub_slot: EndOfSubSlotBundle):
def validate_sub_epoch_sampling(rng, sub_epoch_weight_list, weight_proof):
tip = weight_proof.recent_chain_data[-1]
weight_to_check = _get_weights_for_sampling(rng, tip.weight, weight_proof.recent_chain_data)
sampled_sub_epochs: dict[int, bool] = {}
sampled_sub_epochs: Dict[int, bool] = {}
for idx in range(1, len(sub_epoch_weight_list)):
if _sample_sub_epoch(sub_epoch_weight_list[idx - 1], sub_epoch_weight_list[idx], weight_to_check):
sampled_sub_epochs[idx - 1] = True

View File

@ -1,6 +1,7 @@
import ipaddress
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
from chia.rpc.rpc_server import Endpoint
from chia.seeder.crawler import Crawler
from chia.util.ws_message import WsRpcMessage, create_payload_dict
@ -10,7 +11,7 @@ class CrawlerRpcApi:
self.service = crawler
self.service_name = "chia_crawler"
def get_routes(self) -> Dict[str, Callable]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
"/get_peer_counts": self.get_peer_counts,
"/get_ips_after_timestamp": self.get_ips_after_timestamp,

View File

@ -1,9 +1,10 @@
import dataclasses
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
from pathlib import Path
from chia.data_layer.data_layer import DataLayer
from chia.data_layer.data_layer_util import Side, Subscription
from chia.rpc.rpc_server import Endpoint
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
@ -51,7 +52,7 @@ class DataLayerRpcApi:
self.service: DataLayer = data_layer
self.service_name = "chia_data_layer"
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
"/create_data_store": self.create_data_store,
"/get_owned_stores": self.get_owned_stores,

View File

@ -7,6 +7,7 @@ from typing_extensions import Protocol
from chia.farmer.farmer import Farmer
from chia.plot_sync.receiver import Receiver
from chia.protocols.harvester_protocol import Plot
from chia.rpc.rpc_server import Endpoint
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint32
@ -81,7 +82,7 @@ class FarmerRpcApi:
self.service = farmer
self.service_name = "chia_farmer"
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
"/get_signage_point": self.get_signage_point,
"/get_signage_points": self.get_signage_points,

View File

@ -1,9 +1,10 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
from chia.consensus.block_record import BlockRecord
from chia.consensus.pos_quality import UI_ACTUAL_SPACE_CONSTANT_FACTOR
from chia.full_node.full_node import FullNode
from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin
from chia.rpc.rpc_server import Endpoint
from chia.types.blockchain_format.program import Program, SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
@ -30,7 +31,7 @@ class FullNodeRpcApi:
self.service_name = "chia_full_node"
self.cached_blockchain_state: Optional[Dict] = None
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
# Blockchain
"/get_blockchain_state": self.get_blockchain_state,
@ -307,7 +308,7 @@ class FullNodeRpcApi:
return {"signage_point": sp, "time_received": time_received, "reverted": True}
async def get_block(self, request: Dict) -> Optional[Dict]:
async def get_block(self, request: Dict) -> Dict[str, object]:
if "header_hash" not in request:
raise ValueError("No header_hash in request")
header_hash = bytes32.from_hexstr(request["header_hash"])
@ -318,7 +319,7 @@ class FullNodeRpcApi:
return {"block": block}
async def get_blocks(self, request: Dict) -> Optional[Dict]:
async def get_blocks(self, request: Dict) -> Dict[str, object]:
if "start" not in request:
raise ValueError("No start in request")
if "end" not in request:
@ -368,7 +369,7 @@ class FullNodeRpcApi:
}
}
async def get_block_records(self, request: Dict) -> Optional[Dict]:
async def get_block_records(self, request: Dict) -> Dict[str, object]:
if "start" not in request:
raise ValueError("No start in request")
if "end" not in request:
@ -398,7 +399,7 @@ class FullNodeRpcApi:
records.append(record)
return {"block_records": records}
async def get_block_record_by_height(self, request: Dict) -> Optional[Dict]:
async def get_block_record_by_height(self, request: Dict) -> Dict[str, object]:
if "height" not in request:
raise ValueError("No height in request")
height = request["height"]
@ -431,7 +432,7 @@ class FullNodeRpcApi:
return {"block_record": record}
async def get_unfinished_block_headers(self, request: Dict) -> Optional[Dict]:
async def get_unfinished_block_headers(self, request: Dict) -> Dict[str, object]:
peak: Optional[BlockRecord] = self.service.blockchain.get_peak()
if peak is None:
@ -452,7 +453,7 @@ class FullNodeRpcApi:
response_headers.append(unfinished_header_block)
return {"headers": response_headers}
async def get_network_space(self, request: Dict) -> Optional[Dict]:
async def get_network_space(self, request: Dict) -> Dict[str, object]:
"""
Retrieves an estimate of total space validating the chain
between two block header hashes.
@ -492,7 +493,7 @@ class FullNodeRpcApi:
)
return {"space": uint128(int(network_space_bytes_estimate))}
async def get_coin_records_by_puzzle_hash(self, request: Dict) -> Optional[Dict]:
async def get_coin_records_by_puzzle_hash(self, request: Dict) -> Dict[str, object]:
"""
Retrieves the coins for a given puzzlehash, by default returns unspent coins.
"""
@ -511,7 +512,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_puzzle_hashes(self, request: Dict) -> Optional[Dict]:
async def get_coin_records_by_puzzle_hashes(self, request: Dict) -> Dict[str, object]:
"""
Retrieves the coins for a given puzzlehash, by default returns unspent coins.
"""
@ -533,7 +534,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_record_by_name(self, request: Dict) -> Optional[Dict]:
async def get_coin_record_by_name(self, request: Dict) -> Dict[str, object]:
"""
Retrieves a coin record by it's name.
"""
@ -547,7 +548,7 @@ class FullNodeRpcApi:
return {"coin_record": coin_record_dict_backwards_compat(coin_record.to_json_dict())}
async def get_coin_records_by_names(self, request: Dict) -> Optional[Dict]:
async def get_coin_records_by_names(self, request: Dict) -> Dict[str, object]:
"""
Retrieves the coins for given coin IDs, by default returns unspent coins.
"""
@ -569,7 +570,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_parent_ids(self, request: Dict) -> Optional[Dict]:
async def get_coin_records_by_parent_ids(self, request: Dict) -> Dict[str, object]:
"""
Retrieves the coins for given parent coin IDs, by default returns unspent coins.
"""
@ -591,7 +592,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def get_coin_records_by_hint(self, request: Dict) -> Optional[Dict]:
async def get_coin_records_by_hint(self, request: Dict) -> Dict[str, object]:
"""
Retrieves coins by hint, by default returns unspent coins.
"""
@ -620,7 +621,7 @@ class FullNodeRpcApi:
return {"coin_records": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in coin_records]}
async def push_tx(self, request: Dict) -> Optional[Dict]:
async def push_tx(self, request: Dict) -> Dict[str, object]:
if "spend_bundle" not in request:
raise ValueError("Spend bundle not in request")
@ -645,7 +646,7 @@ class FullNodeRpcApi:
"status": status.name,
}
async def get_puzzle_and_solution(self, request: Dict) -> Optional[Dict]:
async def get_puzzle_and_solution(self, request: Dict) -> Dict[str, object]:
coin_name: bytes32 = bytes32.from_hexstr(request["coin_id"])
height = request["height"]
coin_record = await self.service.coin_store.get_coin_record(coin_name)
@ -671,7 +672,7 @@ class FullNodeRpcApi:
solution_ser: SerializedProgram = SerializedProgram.from_program(Program.to(solution))
return {"coin_solution": CoinSpend(coin_record.coin, puzzle_ser, solution_ser)}
async def get_additions_and_removals(self, request: Dict) -> Optional[Dict]:
async def get_additions_and_removals(self, request: Dict) -> Dict[str, object]:
if "header_hash" not in request:
raise ValueError("No header_hash in request")
header_hash = bytes32.from_hexstr(request["header_hash"])
@ -691,17 +692,17 @@ class FullNodeRpcApi:
"removals": [coin_record_dict_backwards_compat(cr.to_json_dict()) for cr in removals],
}
async def get_all_mempool_tx_ids(self, request: Dict) -> Optional[Dict]:
async def get_all_mempool_tx_ids(self, request: Dict) -> Dict[str, object]:
ids = list(self.service.mempool_manager.mempool.spends.keys())
return {"tx_ids": ids}
async def get_all_mempool_items(self, request: Dict) -> Optional[Dict]:
async def get_all_mempool_items(self, request: Dict) -> Dict[str, object]:
spends = {}
for tx_id, item in self.service.mempool_manager.mempool.spends.items():
spends[tx_id.hex()] = item
return {"mempool_items": spends}
async def get_mempool_item_by_tx_id(self, request: Dict) -> Optional[Dict]:
async def get_mempool_item_by_tx_id(self, request: Dict) -> Dict[str, object]:
if "tx_id" not in request:
raise ValueError("No tx_id in request")
tx_id: bytes32 = bytes32.from_hexstr(request["tx_id"])

View File

@ -1,6 +1,7 @@
from typing import Any, Callable, Dict, List
from typing import Any, Dict, List
from chia.harvester.harvester import Harvester
from chia.rpc.rpc_server import Endpoint
from chia.util.ws_message import WsRpcMessage, create_payload_dict
@ -9,7 +10,7 @@ class HarvesterRpcApi:
self.service = harvester
self.service_name = "chia_harvester"
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
"/get_plots": self.get_plots,
"/refresh_plots": self.refresh_plots,

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import asyncio
import json
import logging
@ -8,7 +10,7 @@ from ssl import SSLContext
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from aiohttp import ClientConnectorError, ClientSession, ClientWebSocketResponse, WSMsgType, web
from typing_extensions import Protocol
from typing_extensions import Protocol, final
from chia.rpc.util import wrap_http_handler
from chia.server.outbound_message import NodeType
@ -24,11 +26,15 @@ log = logging.getLogger(__name__)
max_message_size = 50 * 1024 * 1024 # 50MB
Endpoint = Callable[[Dict[str, object]], Awaitable[Dict[str, object]]]
class RpcApiProtocol(Protocol):
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
pass
@final
@dataclass
class RpcServer:
"""
@ -36,7 +42,7 @@ class RpcServer:
"""
rpc_api: Any
stop_cb: Callable
stop_cb: Callable[[], None]
service_name: str
ssl_context: SSLContext
ssl_client_context: SSLContext
@ -45,7 +51,9 @@ class RpcServer:
client_session: Optional[ClientSession] = None
@classmethod
def create(cls, rpc_api: Any, service_name: str, stop_cb: Callable, root_path, net_config: Dict[str, Any]):
def create(
cls, rpc_api: Any, service_name: str, stop_cb: Callable[[], None], root_path: Path, net_config: Dict[str, Any]
) -> RpcServer:
crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
key_path = root_path / net_config["daemon_ssl"]["private_key"]
ca_cert_path = root_path / net_config["private_ssl_ca"]["crt"]
@ -54,7 +62,7 @@ class RpcServer:
ssl_client_context = ssl_context_for_client(ca_cert_path, ca_key_path, crt_path, key_path, log=log)
return cls(rpc_api, stop_cb, service_name, ssl_context, ssl_client_context)
async def stop(self):
async def stop(self) -> None:
self.shut_down = True
if self.websocket is not None:
await self.websocket.close()
@ -93,7 +101,7 @@ class RpcServer:
return None
asyncio.create_task(self._state_changed(change, change_data))
def get_routes(self) -> Dict[str, Callable]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
**self.rpc_api.get_routes(),
"/get_connections": self.get_connections,
@ -104,13 +112,13 @@ class RpcServer:
"/healthz": self.healthz,
}
async def _get_routes(self, request: Dict) -> Dict:
async def _get_routes(self, request: Dict[str, Any]) -> Dict[str, object]:
return {
"success": "true",
"routes": list(self.get_routes().keys()),
}
async def get_connections(self, request: Dict) -> Dict:
async def get_connections(self, request: Dict[str, Any]) -> Dict[str, object]:
request_node_type: Optional[NodeType] = None
if "node_type" in request:
request_node_type = NodeType(request["node_type"])
@ -166,7 +174,7 @@ class RpcServer:
]
return {"connections": con_info}
async def open_connection(self, request: Dict):
async def open_connection(self, request: Dict[str, Any]) -> Dict[str, object]:
host = request["host"]
port = request["port"]
target_node: PeerInfo = PeerInfo(host, uint16(int(port)))
@ -179,7 +187,7 @@ class RpcServer:
raise ValueError("Start client failed, or server is not set")
return {}
async def close_connection(self, request: Dict):
async def close_connection(self, request: Dict[str, Any]) -> Dict[str, object]:
node_id = hexstr_to_bytes(request["node_id"])
if self.rpc_api.service.server is None:
raise web.HTTPInternalServerError()
@ -190,7 +198,7 @@ class RpcServer:
await connection.close()
return {}
async def stop_node(self, request):
async def stop_node(self, request: Dict[str, Any]) -> Dict[str, object]:
"""
Shuts down the node.
"""
@ -198,36 +206,36 @@ class RpcServer:
self.stop_cb()
return {}
async def healthz(self, request: Dict) -> Dict:
async def healthz(self, request: Dict[str, Any]) -> Dict[str, object]:
return {
"success": "true",
}
async def ws_api(self, message):
async def ws_api(self, message: WsRpcMessage) -> Dict[str, object]:
"""
This function gets called when new message is received via websocket.
"""
command = message["command"]
if message["ack"]:
return None
return {}
data = None
data: Dict[str, object] = {}
if "data" in message:
data = message["data"]
if command == "ping":
return pong()
f = getattr(self, command, None)
if f is not None:
return await f(data)
f = getattr(self.rpc_api, command, None)
if f is not None:
return await f(data)
f_internal: Optional[Endpoint] = getattr(self, command, None)
if f_internal is not None:
return await f_internal(data)
f_rpc_api: Optional[Endpoint] = getattr(self.rpc_api, command, None)
if f_rpc_api is not None:
return await f_rpc_api(data)
raise ValueError(f"unknown_command {command}")
async def safe_handle(self, websocket, payload):
async def safe_handle(self, websocket: ClientWebSocketResponse, payload: str) -> None:
message = None
try:
message = json.loads(payload)
@ -250,7 +258,7 @@ class RpcServer:
res = {"success": False, "error": f"{error}"}
await websocket.send_str(format_response(message, res))
async def connection(self, ws):
async def connection(self, ws: ClientWebSocketResponse) -> None:
data = {"service": self.service_name}
payload = create_payload("register_service", data, self.service_name, "daemon")
await ws.send_str(payload)
@ -279,7 +287,7 @@ class RpcServer:
break
async def connect_to_daemon(self, self_hostname: str, daemon_port: uint16):
async def connect_to_daemon(self, self_hostname: str, daemon_port: uint16) -> None:
while not self.shut_down:
try:
self.client_session = ClientSession()
@ -311,11 +319,11 @@ async def start_rpc_server(
self_hostname: str,
daemon_port: uint16,
rpc_port: uint16,
stop_cb: Callable,
stop_cb: Callable[[], None],
root_path: Path,
net_config,
connect_to_daemon=True,
max_request_body_size=None,
net_config: Dict[str, object],
connect_to_daemon: bool = True,
max_request_body_size: Optional[int] = None,
name: str = "rpc_server",
) -> Tuple[Callable[[], Awaitable[None]], uint16]:
"""
@ -344,7 +352,7 @@ async def start_rpc_server(
if rpc_port == 0:
rpc_port = select_port(root_path, runner.addresses)
async def cleanup():
async def cleanup() -> None:
await rpc_server.stop()
await runner.cleanup()
if connect_to_daemon:

View File

@ -1,5 +1,6 @@
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional
from chia.rpc.rpc_server import Endpoint
from chia.timelord.timelord import Timelord
from chia.util.ws_message import WsRpcMessage, create_payload_dict
@ -9,7 +10,7 @@ class TimelordRpcApi:
self.service = timelord
self.service_name = "chia_timelord"
def get_routes(self) -> Dict[str, Callable]:
def get_routes(self) -> Dict[str, Endpoint]:
return {}
async def _state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> List[WsRpcMessage]:

View File

@ -3,7 +3,7 @@ import dataclasses
import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
from typing import Any, Dict, List, Optional, Set, Tuple
from blspy import G1Element, G2Element, PrivateKey
@ -13,6 +13,7 @@ from chia.pools.pool_wallet import PoolWallet
from chia.pools.pool_wallet_info import FARMING_TO_POOL, PoolState, PoolWalletInfo, create_pool_state
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.wallet_protocol import CoinState
from chia.rpc.rpc_server import Endpoint
from chia.server.outbound_message import NodeType, make_msg
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.types.announcement import Announcement
@ -68,7 +69,7 @@ class WalletRpcApi:
self.service_name = "chia_wallet"
self.balance_cache: Dict[int, Any] = {}
def get_routes(self) -> Dict[str, Callable[[Any], Any]]:
def get_routes(self) -> Dict[str, Endpoint]:
return {
# Key management
"/log_in": self.log_in,
@ -194,7 +195,6 @@ class WalletRpcApi:
await peers_close_task
async def _convert_tx_puzzle_hash(self, tx: TransactionRecord) -> TransactionRecord:
assert self.service.wallet_state_manager is not None
return dataclasses.replace(
tx,
to_puzzle_hash=(
@ -228,7 +228,6 @@ class WalletRpcApi:
async def get_public_keys(self, request: Dict):
try:
assert self.service.keychain_proxy is not None # An offering to the mypy gods
fingerprints = [
sk.get_g1().get_fingerprint() for (sk, seed) in await self.service.keychain_proxy.get_all_private_keys()
]
@ -241,7 +240,6 @@ class WalletRpcApi:
async def _get_private_key(self, fingerprint) -> Tuple[Optional[PrivateKey], Optional[bytes]]:
try:
assert self.service.keychain_proxy is not None # An offering to the mypy gods
all_keys = await self.service.keychain_proxy.get_all_private_keys()
for sk, seed in all_keys:
if sk.get_g1().get_fingerprint() == fingerprint:
@ -384,7 +382,6 @@ class WalletRpcApi:
async def delete_all_keys(self, request: Dict):
await self._stop_wallet()
try:
assert self.service.keychain_proxy is not None # An offering to the mypy gods
await self.service.keychain_proxy.delete_all_keys()
except Exception as e:
log.error(f"Failed to delete all keys: {e}")
@ -399,24 +396,20 @@ class WalletRpcApi:
##########################################################################################
async def get_sync_status(self, request: Dict):
assert self.service.wallet_state_manager is not None
syncing = self.service.wallet_state_manager.sync_mode
synced = await self.service.wallet_state_manager.synced()
return {"synced": synced, "syncing": syncing, "genesis_initialized": True}
async def get_height_info(self, request: Dict):
assert self.service.wallet_state_manager is not None
height = await self.service.wallet_state_manager.blockchain.get_finished_sync_up_to()
return {"height": height}
async def get_network_info(self, request: Dict):
assert self.service.wallet_state_manager is not None
network_name = self.service.config["selected_network"]
address_prefix = self.service.config["network_overrides"]["config"][network_name]["address_prefix"]
return {"network_name": network_name, "network_prefix": address_prefix}
async def push_tx(self, request: Dict):
assert self.service.server is not None
nodes = self.service.server.get_full_node_connections()
if len(nodes) == 0:
raise ValueError("Wallet is not currently connected to any full node peers")
@ -436,7 +429,6 @@ class WalletRpcApi:
##########################################################################################
async def get_wallets(self, request: Dict):
assert self.service.wallet_state_manager is not None
include_data: bool = request.get("include_data", True)
wallet_type: Optional[WalletType] = None
if "type" in request:
@ -451,7 +443,6 @@ class WalletRpcApi:
return {"wallets": wallets}
async def create_new_wallet(self, request: Dict):
assert self.service.wallet_state_manager is not None
wallet_state_manager = self.service.wallet_state_manager
if await self.service.wallet_state_manager.synced() is False:
@ -668,7 +659,6 @@ class WalletRpcApi:
##########################################################################################
async def get_wallet_balance(self, request: Dict) -> Dict:
assert self.service.wallet_state_manager is not None
wallet_id = uint32(int(request["wallet_id"]))
wallet = self.service.wallet_state_manager.wallets[wallet_id]
@ -721,7 +711,6 @@ class WalletRpcApi:
return {"wallet_balance": wallet_balance}
async def get_transaction(self, request: Dict) -> Dict:
assert self.service.wallet_state_manager is not None
transaction_id: bytes32 = bytes32(hexstr_to_bytes(request["transaction_id"]))
tr: Optional[TransactionRecord] = await self.service.wallet_state_manager.get_transaction(transaction_id)
if tr is None:
@ -733,8 +722,6 @@ class WalletRpcApi:
}
async def get_transactions(self, request: Dict) -> Dict:
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
start = request.get("start", 0)
@ -759,8 +746,6 @@ class WalletRpcApi:
}
async def get_transaction_count(self, request: Dict) -> Dict:
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
count = await self.service.wallet_state_manager.tx_store.get_transaction_count_for_wallet(wallet_id)
return {
@ -778,8 +763,6 @@ class WalletRpcApi:
"""
Returns a new address
"""
assert self.service.wallet_state_manager is not None
if request["new_address"] is True:
create_new = True
else:
@ -803,8 +786,6 @@ class WalletRpcApi:
}
async def send_transaction(self, request):
assert self.service.wallet_state_manager is not None
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced before sending transactions")
@ -843,8 +824,6 @@ class WalletRpcApi:
}
async def send_transaction_multi(self, request) -> Dict:
assert self.service.wallet_state_manager is not None
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced before sending transactions")
@ -873,13 +852,9 @@ class WalletRpcApi:
if self.service.wallet_state_manager.wallets[wallet_id].type() == WalletType.POOLING_WALLET.value:
self.service.wallet_state_manager.wallets[wallet_id].target_state = None
await self.service.wallet_state_manager.tx_store.db_wrapper.commit_transaction()
# Update the cache
await self.service.wallet_state_manager.tx_store.rebuild_tx_cache()
return {}
async def select_coins(self, request) -> Dict[str, List[Dict]]:
assert self.service.wallet_state_manager is not None
async def select_coins(self, request) -> Dict[str, object]:
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced before selecting coins")
@ -900,14 +875,12 @@ class WalletRpcApi:
return {"cat_list": list(DEFAULT_CATS.values())}
async def cat_set_name(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id]
await wallet.set_name(str(request["name"]))
return {"wallet_id": wallet_id}
async def cat_get_name(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id]
name: str = await wallet.get_name()
@ -919,13 +892,10 @@ class WalletRpcApi:
:param request: RPC request
:return: A list of unacknowledged CATs
"""
assert self.service.wallet_state_manager is not None
cats = await self.service.wallet_state_manager.interested_store.get_unacknowledged_tokens()
return {"stray_cats": cats}
async def cat_spend(self, request):
assert self.service.wallet_state_manager is not None
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced.")
wallet_id = int(request["wallet_id"])
@ -956,14 +926,12 @@ class WalletRpcApi:
}
async def cat_get_asset_id(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
wallet: CATWallet = self.service.wallet_state_manager.wallets[wallet_id]
asset_id: str = wallet.get_asset_id()
return {"asset_id": asset_id, "wallet_id": wallet_id}
async def cat_asset_id_to_name(self, request):
assert self.service.wallet_state_manager is not None
wallet = await self.service.wallet_state_manager.get_wallet_for_asset_id(request["asset_id"])
if wallet is None:
if request["asset_id"] in DEFAULT_CATS:
@ -974,8 +942,6 @@ class WalletRpcApi:
return {"wallet_id": wallet.id(), "name": (await wallet.get_name())}
async def create_offer_for_ids(self, request):
assert self.service.wallet_state_manager is not None
offer: Dict[str, int] = request["offer"]
fee: uint64 = uint64(request.get("fee", 0))
validate_only: bool = request.get("validate_only", False)
@ -1019,7 +985,6 @@ class WalletRpcApi:
raise ValueError(error)
async def get_offer_summary(self, request):
assert self.service.wallet_state_manager is not None
offer_hex: str = request["offer"]
offer = Offer.from_bech32(offer_hex)
offered, requested, infos = offer.summary()
@ -1027,14 +992,12 @@ class WalletRpcApi:
return {"summary": {"offered": offered, "requested": requested, "fees": offer.bundle.fees(), "infos": infos}}
async def check_offer_validity(self, request):
assert self.service.wallet_state_manager is not None
offer_hex: str = request["offer"]
offer = Offer.from_bech32(offer_hex)
return {"valid": (await self.service.wallet_state_manager.trade_manager.check_offer_validity(offer))}
async def take_offer(self, request):
assert self.service.wallet_state_manager is not None
offer_hex: str = request["offer"]
offer = Offer.from_bech32(offer_hex)
fee: uint64 = uint64(request.get("fee", 0))
@ -1050,8 +1013,6 @@ class WalletRpcApi:
return {"trade_record": trade_record.to_json_dict_convenience()}
async def get_offer(self, request: Dict):
assert self.service.wallet_state_manager is not None
trade_mgr = self.service.wallet_state_manager.trade_manager
trade_id = bytes32.from_hexstr(request["trade_id"])
@ -1065,8 +1026,6 @@ class WalletRpcApi:
return {"trade_record": trade_record.to_json_dict_convenience(), "offer": offer_value}
async def get_all_offers(self, request: Dict):
assert self.service.wallet_state_manager is not None
trade_mgr = self.service.wallet_state_manager.trade_manager
start: int = request.get("start", 0)
@ -1098,8 +1057,6 @@ class WalletRpcApi:
return {"trade_records": result, "offers": offer_values}
async def get_offers_count(self, request: Dict):
assert self.service.wallet_state_manager is not None
trade_mgr = self.service.wallet_state_manager.trade_manager
(total, my_offers_count, taken_offers_count) = await trade_mgr.trade_store.get_trades_count()
@ -1107,8 +1064,6 @@ class WalletRpcApi:
return {"total": total, "my_offers_count": my_offers_count, "taken_offers_count": taken_offers_count}
async def cancel_offer(self, request: Dict):
assert self.service.wallet_state_manager is not None
wsm = self.service.wallet_state_manager
secure = request["secure"]
trade_id = bytes32.from_hexstr(request["trade_id"])
@ -1126,7 +1081,6 @@ class WalletRpcApi:
##########################################################################################
async def did_set_wallet_name(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = uint32(request["wallet_id"])
wallet: DIDWallet = self.service.wallet_state_manager.wallets[wallet_id]
if wallet.type() == WalletType.DECENTRALIZED_ID:
@ -1136,7 +1090,6 @@ class WalletRpcApi:
return {"success": False, "error": f"Wallet id {wallet_id} is not a DID wallet"}
async def did_get_wallet_name(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = uint32(request["wallet_id"])
wallet: DIDWallet = self.service.wallet_state_manager.wallets[wallet_id]
name: str = await wallet.get_name()
@ -1319,7 +1272,6 @@ class WalletRpcApi:
return {"wallet_id": wallet_id, "success": True, "backup_data": did_wallet.create_backup()}
async def did_transfer_did(self, request):
assert self.service.wallet_state_manager is not None
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced.")
wallet_id = uint32(request["wallet_id"])
@ -1401,7 +1353,6 @@ class WalletRpcApi:
async def nft_get_nfts(self, request) -> Dict:
wallet_id = uint32(request["wallet_id"])
assert self.service.wallet_state_manager is not None
nft_wallet: NFTWallet = self.service.wallet_state_manager.wallets[wallet_id]
nfts = nft_wallet.get_current_nfts()
nft_info_list = []
@ -1411,7 +1362,6 @@ class WalletRpcApi:
async def nft_set_nft_did(self, request):
try:
assert self.service.wallet_state_manager is not None
wallet_id = uint32(request["wallet_id"])
nft_wallet: NFTWallet = self.service.wallet_state_manager.wallets[wallet_id]
did_id = request.get("did_id", "")
@ -1433,7 +1383,6 @@ class WalletRpcApi:
did_id: Optional[bytes32] = None
if "did_id" in request:
did_id = decode_puzzle_hash(request["did_id"])
assert self.service.wallet_state_manager is not None
for wallet in self.service.wallet_state_manager.wallets.values():
if isinstance(wallet, NFTWallet) and wallet.get_did() == did_id:
return {"wallet_id": wallet.wallet_id, "success": True}
@ -1441,7 +1390,6 @@ class WalletRpcApi:
async def nft_get_wallet_did(self, request) -> Dict:
wallet_id = uint32(request["wallet_id"])
assert self.service.wallet_state_manager is not None
nft_wallet: NFTWallet = self.service.wallet_state_manager.wallets[wallet_id]
if nft_wallet is not None:
if nft_wallet.type() != WalletType.NFT.value:
@ -1454,7 +1402,6 @@ class WalletRpcApi:
return {"success": False, "error": f"Wallet {wallet_id} not found"}
async def nft_get_wallets_with_dids(self, request) -> Dict:
assert self.service.wallet_state_manager is not None
all_wallets = self.service.wallet_state_manager.wallets.values()
did_wallets_by_did_id: Dict[bytes32, uint32] = {
wallet.did_info.origin_coin.name(): wallet.id()
@ -1494,7 +1441,6 @@ class WalletRpcApi:
return {"success": False, "error": f"Cannot change the status of the NFT.{e}"}
async def nft_transfer_nft(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = uint32(request["wallet_id"])
address = request["target_address"]
if isinstance(address, str):
@ -1511,7 +1457,7 @@ class WalletRpcApi:
nft_coin_info = nft_wallet.get_nft_coin_by_id(nft_coin_id)
fee = uint64(request.get("fee", 0))
txs = await nft_wallet.generate_signed_transaction(
[nft_coin_info.coin.amount],
[uint64(nft_coin_info.coin.amount)],
[puzzle_hash],
coins={nft_coin_info.coin},
fee=fee,
@ -1529,8 +1475,7 @@ class WalletRpcApi:
log.exception(f"Failed to transfer NFT: {e}")
return {"success": False, "error": str(e)}
async def nft_get_info(self, request: Dict):
assert self.service.wallet_state_manager is not None
async def nft_get_info(self, request: Dict) -> Dict[str, object]:
if "coin_id" not in request:
return {"success": False, "error": "Coin ID is required."}
coin_id = request["coin_id"]
@ -1623,7 +1568,6 @@ class WalletRpcApi:
return {"success": True, "nft_info": nft_info}
async def nft_add_uri(self, request) -> Dict:
assert self.service.wallet_state_manager is not None
wallet_id = uint32(request["wallet_id"])
# Note metadata updater can only add one uri for one field per spend.
# If you want to add multiple uris for one field, you need to spend multiple times.
@ -1649,8 +1593,6 @@ class WalletRpcApi:
##########################################################################################
async def rl_set_user_info(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = uint32(int(request["wallet_id"]))
rl_user = self.service.wallet_state_manager.wallets[wallet_id]
origin = request["origin"]
@ -1666,8 +1608,6 @@ class WalletRpcApi:
return {}
async def send_clawback_transaction(self, request):
assert self.service.wallet_state_manager is not None
wallet_id = int(request["wallet_id"])
wallet: RLWallet = self.service.wallet_state_manager.wallets[wallet_id]
@ -1723,7 +1663,6 @@ class WalletRpcApi:
}
async def create_signed_transaction(self, request, hold_lock=True) -> Dict:
assert self.service.wallet_state_manager is not None
if "additions" not in request or len(request["additions"]) < 1:
raise ValueError("Specify additions list")
@ -1820,8 +1759,6 @@ class WalletRpcApi:
# Pool Wallet
##########################################################################################
async def pw_join_pool(self, request) -> Dict:
if self.service.wallet_state_manager is None:
return {"success": False, "error": "not_initialized"}
fee = uint64(request.get("fee", 0))
wallet_id = uint32(request["wallet_id"])
wallet: PoolWallet = self.service.wallet_state_manager.wallets[wallet_id]
@ -1850,8 +1787,6 @@ class WalletRpcApi:
return {"total_fee": total_fee, "transaction": tx, "fee_transaction": fee_tx}
async def pw_self_pool(self, request) -> Dict:
if self.service.wallet_state_manager is None:
return {"success": False, "error": "not_initialized"}
# Leaving a pool requires two state transitions.
# First we transition to PoolSingletonState.LEAVING_POOL
# Then we transition to FARMING_TO_POOL or SELF_POOLING
@ -1870,8 +1805,6 @@ class WalletRpcApi:
async def pw_absorb_rewards(self, request) -> Dict:
"""Perform a sweep of the p2_singleton rewards controlled by the pool wallet singleton"""
if self.service.wallet_state_manager is None:
return {"success": False, "error": "not_initialized"}
if await self.service.wallet_state_manager.synced() is False:
raise ValueError("Wallet needs to be fully synced before collecting rewards")
fee = uint64(request.get("fee", 0))
@ -1888,8 +1821,6 @@ class WalletRpcApi:
async def pw_status(self, request) -> Dict:
"""Return the complete state of the Pool wallet with id `request["wallet_id"]`"""
if self.service.wallet_state_manager is None:
return {"success": False, "error": "not_initialized"}
wallet_id = uint32(request["wallet_id"])
wallet: PoolWallet = self.service.wallet_state_manager.wallets[wallet_id]

View File

@ -1,19 +1,20 @@
from typing import Any, Dict, Optional
from typing import Dict
from chia.rpc.full_node_rpc_api import FullNodeRpcApi
from chia.rpc.rpc_server import Endpoint
from chia.simulator.simulator_protocol import FarmNewBlockProtocol
from chia.util.bech32m import decode_puzzle_hash
class SimulatorFullNodeRpcApi(FullNodeRpcApi):
def get_routes(self) -> Dict[str, Any]:
def get_routes(self) -> Dict[str, Endpoint]:
routes = super().get_routes()
routes["/farm_tx_block"] = self.farm_tx_block
return routes
async def farm_tx_block(self, _request: Dict[str, str]) -> Optional[Dict[str, str]]:
request_address = _request["address"]
async def farm_tx_block(self, _request: Dict[str, object]) -> Dict[str, object]:
request_address = str(_request["address"])
ph = decode_puzzle_hash(request_address)
req = FarmNewBlockProtocol(ph)
await self.service.server.api.farm_new_transaction_block(req)
return None
return {}

View File

@ -54,7 +54,7 @@ def lock_config(root_path: Path, filename: Union[str, Path]) -> Iterator[None]:
# should probably be removed and this function made private.
config_path = config_path_for_filename(root_path, filename)
lock_path: Path = config_path.with_name(config_path.name + ".lock")
with FileLock(lock_path):
with FileLock(lock_path): # pylint: disable=E0110
yield

View File

@ -508,6 +508,10 @@ wallet:
wallet_peers_path: wallet/db/wallet_peers.sqlite
wallet_peers_file_path: wallet/db/wallet_peers.dat
# this is a debug and profiling facility that logs all SQLite commands to a
# separate log file (under logging/wallet_sql.log).
log_sqlite_cmds: False
logging: *logging
network_overrides: *network_overrides
selected_network: *selected_network

View File

@ -74,6 +74,7 @@ FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Tuple[Field, ...]] = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {}
POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {}
def create_fields_cache(cls: Type[object]) -> Tuple[Field, ...]:
@ -239,6 +240,41 @@ def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:
return lambda item: convert_primitive(f_type, item)
def post_init_process_item(f_type: Type[Any], item: Any) -> object:
if not isinstance(item, f_type):
try:
item = f_type(item)
except (TypeError, AttributeError, ValueError):
if hasattr(f_type, "from_bytes_unchecked"):
from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked
else:
from_bytes_method = f_type.from_bytes
try:
item = from_bytes_method(item)
except Exception:
item = from_bytes_method(bytes(item))
if not isinstance(item, f_type):
raise ValueError(f"Wrong type for {f_type}")
return item
def function_to_post_init_process_one_item(f_type: Type[object]) -> ConvertFunctionType:
if is_type_SpecificOptional(f_type):
process_inner_func = function_to_post_init_process_one_item(get_args(f_type)[0])
return lambda item: convert_optional(process_inner_func, item)
if is_type_Tuple(f_type):
args = get_args(f_type)
process_inner_tuple_funcs = []
for arg in args:
process_inner_tuple_funcs.append(function_to_post_init_process_one_item(arg))
return lambda items: convert_tuple(process_inner_tuple_funcs, items) # type: ignore[arg-type]
if is_type_List(f_type):
inner_type = get_args(f_type)[0]
process_inner_func = function_to_post_init_process_one_item(inner_type)
return lambda items: convert_list(process_inner_func, items) # type: ignore[arg-type]
return lambda item: post_init_process_item(f_type, item)
def recurse_jsonify(d: Any) -> Any:
"""
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
@ -506,6 +542,7 @@ def streamable(cls: Type[_T_Streamable]) -> Type[_T_Streamable]:
stream_functions = []
parse_functions = []
convert_functions = []
post_init_functions = []
fields = create_fields_cache(cls)
FIELDS_FOR_STREAMABLE_CLASS[cls] = fields
@ -514,10 +551,12 @@ def streamable(cls: Type[_T_Streamable]) -> Type[_T_Streamable]:
stream_functions.append(function_to_stream_one_item(field.type))
parse_functions.append(function_to_parse_one_item(field.type))
convert_functions.append(function_to_convert_one_item(field.type))
post_init_functions.append(function_to_post_init_process_one_item(field.type))
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions
POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = post_init_functions
return cls
@ -566,64 +605,16 @@ class Streamable:
Make sure to use the streamable decorator when inheriting from the Streamable class to prepare the streaming caches.
"""
def post_init_parse(self, item: Any, f_name: str, f_type: Type[Any]) -> Any:
if is_type_List(f_type):
collected_list: List[Any] = []
inner_type: Type[Any] = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
if not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a list.")
for el in item:
collected_list.append(self.post_init_parse(el, f_name, inner_type))
return collected_list
if is_type_SpecificOptional(f_type):
if item is None:
return None
else:
inner_type: Type = get_args(f_type)[0] # type: ignore
return self.post_init_parse(item, f_name, inner_type)
if is_type_Tuple(f_type):
collected_list = []
if not is_type_Tuple(type(item)) and not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a tuple.")
if len(item) != len(get_args(f_type)):
raise ValueError(f"Wrong number of elements in tuple {f_name}.")
for i in range(len(item)):
inner_type = get_args(f_type)[i]
tuple_item = item[i]
collected_list.append(self.post_init_parse(tuple_item, f_name, inner_type))
return tuple(collected_list)
if not isinstance(item, f_type):
try:
item = f_type(item)
except (TypeError, AttributeError, ValueError):
if hasattr(f_type, "from_bytes_unchecked"):
from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked
else:
from_bytes_method = f_type.from_bytes
try:
item = from_bytes_method(item)
except Exception:
item = from_bytes_method(bytes(item))
if not isinstance(item, f_type):
raise ValueError(f"Wrong type for {f_name}")
return item
def __post_init__(self) -> None:
try:
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
except Exception:
fields = ()
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
process_funcs = POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[type(self)]
data = self.__dict__
for field in fields:
for field, process_func in zip(fields, process_funcs):
if field.name not in data:
raise ValueError(f"Field {field.name} not present")
try:
if not isinstance(data[field.name], field.type):
object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type))
except TypeError:
# Throws a TypeError because we cannot call isinstance for subscripted generics like Optional[int]
object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type))
object.__setattr__(self, field.name, process_func(data[field.name]))
@classmethod
def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:

View File

@ -93,11 +93,7 @@ class CATWallet:
if name is None:
name = "CAT WALLET"
new_wallet_info = await wallet_state_manager.user_store.create_wallet(name, WalletType.CAT, info_as_string)
if new_wallet_info is None:
raise ValueError("Internal Error")
self.wallet_info = new_wallet_info
self.wallet_info = await wallet_state_manager.user_store.create_wallet(name, WalletType.CAT, info_as_string)
try:
chia_tx, spend_bundle = await ALL_LIMITATIONS_PROGRAMS[
@ -194,12 +190,9 @@ class CATWallet:
limitations_program_hash = bytes32(hexstr_to_bytes(limitations_program_hash_hex))
self.cat_info = CATInfo(limitations_program_hash, None)
info_as_string = bytes(self.cat_info).hex()
new_wallet_info = await wallet_state_manager.user_store.create_wallet(
self.wallet_info = await wallet_state_manager.user_store.create_wallet(
name, WalletType.CAT, info_as_string, in_transaction=in_transaction
)
if new_wallet_info is None:
raise Exception("wallet_info is None")
self.wallet_info = new_wallet_info
self.lineage_store = await CATLineageStore.create(
self.wallet_state_manager.db_wrapper, self.get_asset_id(), in_transaction=in_transaction

View File

@ -97,8 +97,6 @@ class DIDWallet:
self.wallet_info = await wallet_state_manager.user_store.create_wallet(
name, WalletType.DECENTRALIZED_ID.value, info_as_string
)
if self.wallet_info is None:
raise ValueError("Internal Error")
self.wallet_id = self.wallet_info.id
std_wallet_id = self.standard_wallet.wallet_id
bal = await wallet_state_manager.get_confirmed_balance_for_wallet(std_wallet_id)

View File

@ -78,13 +78,13 @@ def get_nft_info_from_puzzle(nft_coin_info: NFTCoinInfo) -> NFTInfo:
uncurried_nft: UncurriedNFT = UncurriedNFT.uncurry(nft_coin_info.full_puzzle)
data_uris: List[str] = []
for uri in uncurried_nft.data_uris.as_python():
for uri in uncurried_nft.data_uris.as_python(): # pylint: disable=E1133
data_uris.append(str(uri, "utf-8"))
meta_uris: List[str] = []
for uri in uncurried_nft.meta_uris.as_python():
for uri in uncurried_nft.meta_uris.as_python(): # pylint: disable=E1133
meta_uris.append(str(uri, "utf-8"))
license_uris: List[str] = []
for uri in uncurried_nft.license_uris.as_python():
for uri in uncurried_nft.license_uris.as_python(): # pylint: disable=E1133
license_uris.append(str(uri, "utf-8"))
nft_info = NFTInfo(

View File

@ -71,11 +71,9 @@ class RLWallet:
rl_info = RLInfo("admin", bytes(pubkey), None, None, None, None, None, None, False)
info_as_string = json.dumps(rl_info.to_json_dict())
wallet_info: Optional[WalletInfo] = await wallet_state_manager.user_store.create_wallet(
wallet_info: WalletInfo = await wallet_state_manager.user_store.create_wallet(
"RL Admin", WalletType.RATE_LIMITED, info_as_string
)
if wallet_info is None:
raise Exception("wallet_info is None")
await wallet_state_manager.puzzle_store.add_derivation_paths(
[
@ -112,10 +110,11 @@ class RLWallet:
rl_info = RLInfo("user", None, bytes(pubkey), None, None, None, None, None, False)
info_as_string = json.dumps(rl_info.to_json_dict())
await wallet_state_manager.user_store.create_wallet("RL User", WalletType.RATE_LIMITED, info_as_string)
wallet_info = await wallet_state_manager.user_store.get_last_wallet()
if wallet_info is None:
raise Exception("wallet_info is None")
wallet_info = await wallet_state_manager.user_store.create_wallet(
"RL User",
WalletType.RATE_LIMITED,
info_as_string,
)
self = await RLWallet.create(wallet_state_manager, wallet_info)

View File

@ -23,6 +23,8 @@ from chia.protocols.wallet_protocol import (
RespondToCoinUpdates,
RespondHeaderBlocks,
RequestHeaderBlocks,
RejectHeaderBlocks,
RejectBlockHeaders,
)
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.coin import hash_coin_ids, Coin
@ -325,7 +327,7 @@ async def request_header_blocks(
response = await peer.request_block_headers(RequestBlockHeaders(start_height, end_height, False))
else:
response = await peer.request_header_blocks(RequestHeaderBlocks(start_height, end_height))
if response is None:
if response is None or isinstance(response, RejectBlockHeaders) or isinstance(response, RejectHeaderBlocks):
return None
return response.header_blocks

View File

@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Set
from typing import List, Optional, Set
import aiosqlite
import sqlite3
@ -17,10 +17,6 @@ class WalletCoinStore:
"""
db_connection: aiosqlite.Connection
# coin_record_cache keeps ALL coin records in memory. [record_name: record]
coin_record_cache: Dict[bytes32, WalletCoinRecord]
# unspent_coin_wallet_cache keeps ALL unspent coin records for wallet in memory [wallet_id: [record_name: record]]
unspent_coin_wallet_cache: Dict[int, Dict[bytes32, WalletCoinRecord]]
db_wrapper: DBWrapper
@classmethod
@ -61,9 +57,6 @@ class WalletCoinStore:
await self.db_connection.execute("CREATE INDEX IF NOT EXISTS wallet_id on coin_record(wallet_id)")
await self.db_connection.commit()
self.coin_record_cache = {}
self.unspent_coin_wallet_cache = {}
await self.rebuild_wallet_cache()
return self
async def _clear_database(self):
@ -71,49 +64,23 @@ class WalletCoinStore:
await cursor.close()
await self.db_connection.commit()
async def rebuild_wallet_cache(self):
# First update all coins that were reorged, then re-add coin_records
all_coins = await self.get_all_coins()
self.unspent_coin_wallet_cache = {}
self.coin_record_cache = {}
for coin_record in all_coins:
name = coin_record.name()
self.coin_record_cache[name] = coin_record
if coin_record.spent is False:
if coin_record.wallet_id not in self.unspent_coin_wallet_cache:
self.unspent_coin_wallet_cache[coin_record.wallet_id] = {}
self.unspent_coin_wallet_cache[coin_record.wallet_id][name] = coin_record
async def get_multiple_coin_records(self, coin_names: List[bytes32]) -> List[WalletCoinRecord]:
"""Return WalletCoinRecord(s) that have a coin name in the specified list"""
if set(coin_names).issubset(set(self.coin_record_cache.keys())):
return list(filter(lambda cr: cr.coin.name() in coin_names, self.coin_record_cache.values()))
else:
as_hexes = [cn.hex() for cn in coin_names]
cursor = await self.db_connection.execute(
f'SELECT * from coin_record WHERE coin_name in ({"?," * (len(as_hexes) - 1)}?)', tuple(as_hexes)
)
rows = await cursor.fetchall()
await cursor.close()
if len(coin_names) == 0:
return []
return [self.coin_record_from_row(row) for row in rows]
as_hexes = [cn.hex() for cn in coin_names]
rows = await self.db_connection.execute_fetchall(
f'SELECT * from coin_record WHERE coin_name in ({"?," * (len(as_hexes) - 1)}?)', tuple(as_hexes)
)
return [self.coin_record_from_row(row) for row in rows]
# Store CoinRecord in DB and ram cache
async def add_coin_record(self, record: WalletCoinRecord, name: Optional[bytes32] = None) -> None:
# update wallet cache
if name is None:
name = record.name()
self.coin_record_cache[name] = record
if record.wallet_id in self.unspent_coin_wallet_cache:
if record.spent and name in self.unspent_coin_wallet_cache[record.wallet_id]:
self.unspent_coin_wallet_cache[record.wallet_id].pop(name)
if not record.spent:
self.unspent_coin_wallet_cache[record.wallet_id][name] = record
else:
if not record.spent:
self.unspent_coin_wallet_cache[record.wallet_id] = {}
self.unspent_coin_wallet_cache[record.wallet_id][name] = record
assert record.spent == (record.spent_block_height != 0)
cursor = await self.db_connection.execute(
"INSERT OR REPLACE INTO coin_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
@ -133,13 +100,6 @@ class WalletCoinStore:
# Sometimes we realize that a coin is actually not interesting to us so we need to delete it
async def delete_coin_record(self, coin_name: bytes32) -> None:
if coin_name in self.coin_record_cache:
coin_record = self.coin_record_cache.pop(coin_name)
if coin_record.wallet_id in self.unspent_coin_wallet_cache:
coin_cache = self.unspent_coin_wallet_cache[coin_record.wallet_id]
if coin_name in coin_cache:
coin_cache.pop(coin_record.coin.name())
c = await self.db_connection.execute("DELETE FROM coin_record WHERE coin_name=?", (coin_name.hex(),))
await c.close()
@ -170,130 +130,65 @@ class WalletCoinStore:
async def get_coin_record(self, coin_name: bytes32) -> Optional[WalletCoinRecord]:
"""Returns CoinRecord with specified coin id."""
if coin_name in self.coin_record_cache:
return self.coin_record_cache[coin_name]
cursor = await self.db_connection.execute("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
row = await cursor.fetchone()
await cursor.close()
rows = list(
await self.db_connection.execute_fetchall("SELECT * from coin_record WHERE coin_name=?", (coin_name.hex(),))
)
if row is None:
if len(rows) == 0:
return None
return self.coin_record_from_row(row)
return self.coin_record_from_row(rows[0])
async def get_first_coin_height(self) -> Optional[uint32]:
"""Returns height of first confirmed coin"""
cursor = await self.db_connection.execute("SELECT MIN(confirmed_height) FROM coin_record;")
row = await cursor.fetchone()
await cursor.close()
rows = list(await self.db_connection.execute_fetchall("SELECT MIN(confirmed_height) FROM coin_record"))
if row is not None and row[0] is not None:
return uint32(row[0])
if len(rows) != 0 and rows[0][0] is not None:
return uint32(rows[0][0])
return None
async def get_unspent_coins_at_height(self, height: Optional[uint32] = None) -> Set[WalletCoinRecord]:
"""
Returns set of CoinRecords that have not been spent yet. If a height is specified,
We can also return coins that were unspent at this height (but maybe spent later).
Finally, the coins must be confirmed at the height or less.
"""
if height is None:
all_unspent = set()
for name, coin_record in self.coin_record_cache.items():
if coin_record.spent is False:
all_unspent.add(coin_record)
return all_unspent
else:
all_unspent = set()
for name, coin_record in self.coin_record_cache.items():
if (
coin_record.spent is False
or coin_record.spent_block_height > height >= coin_record.confirmed_block_height
):
all_unspent.add(coin_record)
return all_unspent
async def get_unspent_coins_for_wallet(self, wallet_id: int) -> Set[WalletCoinRecord]:
"""Returns set of CoinRecords that have not been spent yet for a wallet."""
if wallet_id in self.unspent_coin_wallet_cache:
wallet_coins: Dict[bytes32, WalletCoinRecord] = self.unspent_coin_wallet_cache[wallet_id]
return set(wallet_coins.values())
else:
return set()
async def get_all_coins(self) -> Set[WalletCoinRecord]:
"""Returns set of all CoinRecords."""
cursor = await self.db_connection.execute("SELECT * from coin_record")
rows = await cursor.fetchall()
await cursor.close()
rows = await self.db_connection.execute_fetchall(
"SELECT * FROM coin_record WHERE wallet_id=? AND spent_height=0", (wallet_id,)
)
return set(self.coin_record_from_row(row) for row in rows)
async def get_coins_to_check(self, check_height) -> Set[WalletCoinRecord]:
"""Returns set of all CoinRecords."""
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from coin_record where spent_height=0 or spent_height>? or confirmed_height>?",
(
check_height,
check_height,
),
)
rows = await cursor.fetchall()
await cursor.close()
return set(self.coin_record_from_row(row) for row in rows)
# Checks DB and DiffStores for CoinRecords with puzzle_hash and returns them
async def get_coin_records_by_puzzle_hash(self, puzzle_hash: bytes32) -> List[WalletCoinRecord]:
"""Returns a list of all coin records with the given puzzle hash"""
cursor = await self.db_connection.execute("SELECT * from coin_record WHERE puzzle_hash=?", (puzzle_hash.hex(),))
rows = await cursor.fetchall()
await cursor.close()
rows = await self.db_connection.execute_fetchall(
"SELECT * from coin_record WHERE puzzle_hash=?", (puzzle_hash.hex(),)
)
return [self.coin_record_from_row(row) for row in rows]
# Checks DB and DiffStores for CoinRecords with parent_coin_info and returns them
async def get_coin_records_by_parent_id(self, parent_coin_info: bytes32) -> List[WalletCoinRecord]:
"""Returns a list of all coin records with the given parent id"""
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from coin_record WHERE coin_parent=?", (parent_coin_info.hex(),)
)
rows = await cursor.fetchall()
await cursor.close()
return [self.coin_record_from_row(row) for row in rows]
async def rollback_to_block(self, height: int):
"""
Rolls back the blockchain to block_index. All blocks confirmed after this point
are removed from the LCA. All coins confirmed after this point are removed.
Rolls back the blockchain to block_index. All coins confirmed after this point are removed.
All coins spent after this point are set to unspent. Can be -1 (rollback all)
"""
# Delete from storage
delete_queue: List[WalletCoinRecord] = []
for coin_name, coin_record in self.coin_record_cache.items():
if coin_record.spent_block_height > height:
new_record = WalletCoinRecord(
coin_record.coin,
coin_record.confirmed_block_height,
uint32(0),
False,
coin_record.coinbase,
coin_record.wallet_type,
coin_record.wallet_id,
)
self.coin_record_cache[coin_record.coin.name()] = new_record
if coin_record.wallet_id in self.unspent_coin_wallet_cache:
self.unspent_coin_wallet_cache[coin_record.wallet_id][coin_record.coin.name()] = new_record
if coin_record.confirmed_block_height > height:
delete_queue.append(coin_record)
for coin_record in delete_queue:
self.coin_record_cache.pop(coin_record.coin.name())
if coin_record.wallet_id in self.unspent_coin_wallet_cache:
coin_cache = self.unspent_coin_wallet_cache[coin_record.wallet_id]
if coin_record.coin.name() in coin_cache:
coin_cache.pop(coin_record.coin.name())
c1 = await self.db_connection.execute("DELETE FROM coin_record WHERE confirmed_height>?", (height,))
await c1.close()

View File

@ -87,14 +87,14 @@ class WalletNode:
# Sync data
proof_hashes: List = dataclasses.field(default_factory=list)
state_changed_callback: Optional[Callable] = None
wallet_state_manager: Optional[WalletStateManager] = None
server: Optional[ChiaServer] = None
_wallet_state_manager: Optional[WalletStateManager] = None
_server: Optional[ChiaServer] = None
wsm_close_task: Optional[asyncio.Task] = None
sync_task: Optional[asyncio.Task] = None
logged_in_fingerprint: Optional[int] = None
peer_task: Optional[asyncio.Task] = None
logged_in: bool = False
keychain_proxy: Optional[KeychainProxy] = None
_keychain_proxy: Optional[KeychainProxy] = None
height_to_time: Dict[uint32, uint64] = dataclasses.field(default_factory=dict)
# Peers that we have long synced to
synced_peers: Set[bytes32] = dataclasses.field(default_factory=set)
@ -112,7 +112,7 @@ class WalletNode:
last_wallet_tx_resend_time: int = 0
# Duration in seconds
wallet_tx_resend_timeout_secs: int = 1800
new_peak_queue: NewPeakQueue = dataclasses.field(default_factory=lambda: NewPeakQueue(asyncio.PriorityQueue()))
_new_peak_queue: Optional[NewPeakQueue] = None
full_node_peer: Optional[PeerInfo] = None
_shut_down: bool = False
@ -120,15 +120,51 @@ class WalletNode:
_primary_peer_sync_task: Optional[asyncio.Task] = None
_secondary_peer_sync_task: Optional[asyncio.Task] = None
@property
def keychain_proxy(self) -> KeychainProxy:
# This is a stop gap until the class usage is refactored such the values of
# integral attributes are known at creation of the instance.
if self._keychain_proxy is None:
raise RuntimeError("keychain proxy not assigned")
return self._keychain_proxy
@property
def wallet_state_manager(self) -> WalletStateManager:
# This is a stop gap until the class usage is refactored such the values of
# integral attributes are known at creation of the instance.
if self._wallet_state_manager is None:
raise RuntimeError("wallet state manager not assigned")
return self._wallet_state_manager
@property
def server(self) -> ChiaServer:
# This is a stop gap until the class usage is refactored such the values of
# integral attributes are known at creation of the instance.
if self._server is None:
raise RuntimeError("server not assigned")
return self._server
@property
def new_peak_queue(self) -> NewPeakQueue:
# This is a stop gap until the class usage is refactored such the values of
# integral attributes are known at creation of the instance.
if self._new_peak_queue is None:
raise RuntimeError("new peak queue not assigned")
return self._new_peak_queue
async def ensure_keychain_proxy(self) -> KeychainProxy:
if self.keychain_proxy is None:
if self._keychain_proxy is None:
if self.local_keychain:
self.keychain_proxy = wrap_local_keychain(self.local_keychain, log=self.log)
self._keychain_proxy = wrap_local_keychain(self.local_keychain, log=self.log)
else:
self.keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
if not self.keychain_proxy:
self._keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
if not self._keychain_proxy:
raise KeychainProxyConnectionFailure("Failed to connect to keychain service")
return self.keychain_proxy
return self._keychain_proxy
def get_cache_for_peer(self, peer) -> PeerRequestCache:
if peer.peer_node_id not in self.untrusted_caches:
@ -160,6 +196,11 @@ class WalletNode:
self,
fingerprint: Optional[int] = None,
) -> bool:
# Makes sure the coin_state_updates get higher priority than new_peak messages.
# Delayed instantiation until here to avoid errors.
# got Future <Future pending> attached to a different loop
self._new_peak_queue = NewPeakQueue(inner_queue=asyncio.PriorityQueue())
self.synced_peers = set()
private_key = await self.get_key_for_fingerprint(fingerprint)
if private_key is None:
@ -184,8 +225,7 @@ class WalletNode:
self.log.info(f"Copying wallet db from {standalone_path} to {path}")
path.write_bytes(standalone_path.read_bytes())
assert self.server is not None
self.wallet_state_manager = await WalletStateManager.create(
self._wallet_state_manager = await WalletStateManager.create(
private_key,
self.config,
path,
@ -195,7 +235,7 @@ class WalletNode:
self,
)
assert self.wallet_state_manager is not None
assert self._wallet_state_manager is not None
self.config["starting_height"] = 0
@ -241,16 +281,16 @@ class WalletNode:
async def _await_closed(self, shutting_down: bool = True):
self.log.info("self._await_closed")
if self.server is not None:
if self._server is not None:
await self.server.close_all_connections()
if self.wallet_peers is not None:
await self.wallet_peers.ensure_is_closed()
if self.wallet_state_manager is not None:
if self._wallet_state_manager is not None:
await self.wallet_state_manager._await_closed()
self.wallet_state_manager = None
if shutting_down and self.keychain_proxy is not None:
proxy = self.keychain_proxy
self.keychain_proxy = None
self._wallet_state_manager = None
if shutting_down and self._keychain_proxy is not None:
proxy = self._keychain_proxy
self._keychain_proxy = None
await proxy.close()
await asyncio.sleep(0.5) # https://docs.aiohttp.org/en/stable/client_advanced.html#graceful-shutdown
self.logged_in = False
@ -259,17 +299,17 @@ class WalletNode:
def _set_state_changed_callback(self, callback: Callable):
self.state_changed_callback = callback
if self.wallet_state_manager is not None:
if self._wallet_state_manager is not None:
self.wallet_state_manager.set_callback(self.state_changed_callback)
self.wallet_state_manager.set_pending_callback(self._pending_tx_handler)
def _pending_tx_handler(self):
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
return None
asyncio.create_task(self._resend_queue())
async def _action_messages(self) -> List[Message]:
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
return []
actions: List[WalletAction] = await self.wallet_state_manager.action_store.get_all_pending_actions()
result: List[Message] = []
@ -288,11 +328,11 @@ class WalletNode:
return result
async def _resend_queue(self):
if self._shut_down or self.server is None or self.wallet_state_manager is None:
if self._shut_down or self._server is None or self._wallet_state_manager is None:
return None
for msg, sent_peers in await self._messages_to_resend():
if self._shut_down or self.server is None or self.wallet_state_manager is None:
if self._shut_down or self._server is None or self._wallet_state_manager is None:
return None
full_nodes = self.server.get_full_node_connections()
for peer in full_nodes:
@ -302,12 +342,12 @@ class WalletNode:
await peer.send_message(msg)
for msg in await self._action_messages():
if self._shut_down or self.server is None or self.wallet_state_manager is None:
if self._shut_down or self._server is None or self._wallet_state_manager is None:
return None
await self.server.send_to_all([msg], NodeType.FULL_NODE)
async def _messages_to_resend(self) -> List[Tuple[Message, Set[bytes32]]]:
if self.wallet_state_manager is None or self._shut_down:
if self._wallet_state_manager is None or self._shut_down:
return []
messages: List[Tuple[Message, Set[bytes32]]] = []
@ -390,7 +430,7 @@ class WalletNode:
await peer.close(9999)
def set_server(self, server: ChiaServer):
self.server = server
self._server = server
self.initialize_wallet_peers()
def initialize_wallet_peers(self):
@ -433,7 +473,7 @@ class WalletNode:
self.node_peaks.pop(peer.peer_node_id)
async def on_connect(self, peer: WSChiaConnection):
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
return None
if Version(peer.protocol_version) < Version("0.0.33"):
@ -459,7 +499,6 @@ class WalletNode:
await self.wallet_peers.on_connect(peer)
async def perform_atomic_rollback(self, fork_height: int, cache: Optional[PeerRequestCache] = None):
assert self.wallet_state_manager is not None
self.log.info(f"perform_atomic_rollback to {fork_height}")
async with self.wallet_state_manager.db_wrapper.lock:
try:
@ -475,8 +514,6 @@ class WalletNode:
tb = traceback.format_exc()
self.log.error(f"Exception while perform_atomic_rollback: {e} {tb}")
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.coin_store.rebuild_wallet_cache()
await self.wallet_state_manager.tx_store.rebuild_tx_cache()
await self.wallet_state_manager.pool_store.rebuild_cache()
raise
else:
@ -514,7 +551,6 @@ class WalletNode:
trusted: bool = self.is_trusted(full_node)
self.log.info(f"Starting sync trusted: {trusted} to peer {full_node.peer_host}")
assert self.wallet_state_manager is not None
start_time = time.time()
if rollback:
@ -600,7 +636,7 @@ class WalletNode:
# Adds the state to the wallet state manager. If the peer is trusted, we do not validate. If the peer is
# untrusted we do, but we might not add the state, since we need to receive the new_peak message as well.
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
return False
trusted = self.is_trusted(peer)
# Validate states in parallel, apply serial
@ -632,7 +668,7 @@ class WalletNode:
items = sorted(items_input, key=last_change_height_cs)
async def receive_and_validate(inner_states: List[CoinState], inner_idx_start: int, cs_heights: List[uint32]):
assert self.wallet_state_manager is not None
assert self._wallet_state_manager is not None
try:
assert self.validation_semaphore is not None
async with self.validation_semaphore:
@ -652,7 +688,7 @@ class WalletNode:
f"new coin state received ({inner_idx_start}-"
f"{inner_idx_start + len(inner_states) - 1}/ {len(items)})"
)
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
return
try:
await self.wallet_state_manager.db_wrapper.begin_transaction()
@ -679,8 +715,6 @@ class WalletNode:
tb = traceback.format_exc()
self.log.error(f"Exception while adding state: {e} {tb}")
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.coin_store.rebuild_wallet_cache()
await self.wallet_state_manager.tx_store.rebuild_tx_cache()
await self.wallet_state_manager.pool_store.rebuild_cache()
else:
await self.wallet_state_manager.blockchain.clean_block_records()
@ -696,7 +730,7 @@ class WalletNode:
# Untrusted has a smaller batch size since validation has to happen which takes a while
chunk_size: int = 900 if trusted else 20
for states in chunks(items, chunk_size):
if self.server is None:
if self._server is None:
self.log.error("No server")
await asyncio.gather(*all_tasks)
return False
@ -716,8 +750,6 @@ class WalletNode:
await self.wallet_state_manager.db_wrapper.commit_transaction()
except Exception as e:
await self.wallet_state_manager.db_wrapper.rollback_transaction()
await self.wallet_state_manager.coin_store.rebuild_wallet_cache()
await self.wallet_state_manager.tx_store.rebuild_tx_cache()
await self.wallet_state_manager.pool_store.rebuild_cache()
tb = traceback.format_exc()
self.log.error(f"Error adding states.. {e} {tb}")
@ -736,14 +768,12 @@ class WalletNode:
all_tasks.append(asyncio.create_task(receive_and_validate(states, idx, concurrent_tasks_cs_heights)))
idx += len(states)
still_connected = self.server is not None and peer.peer_node_id in self.server.all_connections
still_connected = self._server is not None and peer.peer_node_id in self.server.all_connections
await asyncio.gather(*all_tasks)
await self.update_ui()
return still_connected and self.server is not None and peer.peer_node_id in self.server.all_connections
return still_connected and self._server is not None and peer.peer_node_id in self.server.all_connections
async def get_coins_with_puzzle_hash(self, puzzle_hash) -> List[CoinState]:
assert self.wallet_state_manager is not None
assert self.server is not None
# TODO Use trusted peer, otherwise try untrusted
all_nodes = self.server.connection_by_type[NodeType.FULL_NODE]
if len(all_nodes.keys()) == 0:
@ -771,7 +801,6 @@ class WalletNode:
return latest_timestamp
def is_trusted(self, peer) -> bool:
assert self.server is not None
return self.server.is_trusted_peer(peer, self.config["trusted_peers"])
def add_state_to_race_cache(self, header_hash: bytes32, height: uint32, coin_state: CoinState) -> None:
@ -793,8 +822,6 @@ class WalletNode:
# that is of interest to this wallet. It is not guaranteed to come for every height. This message is guaranteed
# to come before the corresponding new_peak for each height. We handle this differently for trusted and
# untrusted peers. For trusted, we always process the state, and we process reorgs as well.
assert self.wallet_state_manager is not None
assert self.server is not None
for coin in request.items:
self.log.info(f"request coin: {coin.coin.name()}{coin}")
@ -808,7 +835,7 @@ class WalletNode:
)
def get_full_node_peer(self) -> Optional[WSChiaConnection]:
if self.server is None:
if self._server is None:
return None
nodes = self.server.get_full_node_connections()
@ -818,7 +845,7 @@ class WalletNode:
return None
async def disconnect_and_stop_wpeers(self) -> None:
if self.server is None:
if self._server is None:
return
# Close connection of non-trusted peers
@ -832,7 +859,7 @@ class WalletNode:
self.wallet_peers = None
async def check_for_synced_trusted_peer(self, header_block: HeaderBlock, request_time: uint64) -> bool:
if self.server is None:
if self._server is None:
return False
for peer in self.server.get_full_node_connections():
if self.is_trusted(peer) and await self.is_peer_synced(peer, header_block, request_time):
@ -864,10 +891,9 @@ class WalletNode:
return last_tx_block.foliage_transaction_block.timestamp
async def new_peak_wallet(self, new_peak: wallet_protocol.NewPeakWallet, peer: WSChiaConnection):
if self.wallet_state_manager is None:
if self._wallet_state_manager is None:
# When logging out of wallet
return
assert self.server is not None
request_time = uint64(int(time.time()))
trusted: bool = self.is_trusted(peer)
peak_hb: Optional[HeaderBlock] = await self.wallet_state_manager.blockchain.get_peak_block()
@ -1055,7 +1081,6 @@ class WalletNode:
await self.wallet_state_manager.new_peak(new_peak)
async def wallet_short_sync_backtrack(self, header_block: HeaderBlock, peer: WSChiaConnection) -> int:
assert self.wallet_state_manager is not None
peak: Optional[HeaderBlock] = await self.wallet_state_manager.blockchain.get_peak_block()
top = header_block
@ -1103,7 +1128,6 @@ class WalletNode:
async def fetch_and_validate_the_weight_proof(
self, peer: WSChiaConnection, peak: HeaderBlock
) -> Tuple[bool, Optional[WeightProof], List[SubEpochSummary], List[BlockRecord]]:
assert self.wallet_state_manager is not None
assert self.wallet_state_manager.weight_proof_handler is not None
weight_request = RequestProofOfWeight(peak.height, peak.header_hash)
@ -1142,7 +1166,6 @@ class WalletNode:
return valid, weight_proof, summaries, block_records
async def get_puzzle_hashes_to_subscribe(self) -> List[bytes32]:
assert self.wallet_state_manager is not None
all_puzzle_hashes = list(await self.wallet_state_manager.puzzle_store.get_all_puzzle_hashes())
# Get all phs from interested store
interested_puzzle_hashes = [
@ -1152,7 +1175,6 @@ class WalletNode:
return all_puzzle_hashes
async def get_coin_ids_to_subscribe(self, min_height: int) -> List[bytes32]:
assert self.wallet_state_manager is not None
all_coins: Set[WalletCoinRecord] = await self.wallet_state_manager.coin_store.get_coins_to_check(min_height)
all_coin_names: Set[bytes32] = {coin_record.name() for coin_record in all_coins}
removed_dict = await self.wallet_state_manager.trade_manager.get_coins_of_interest()
@ -1171,8 +1193,6 @@ class WalletNode:
Returns all state that is valid and included in the blockchain proved by the weight proof. If return_old_states
is False, only new states that are not in the coin_store are returned.
"""
assert self.wallet_state_manager is not None
# Only use the cache if we are talking about states before the fork point. If we are evaluating something
# in a reorg, we cannot use the cache, since we don't know if it's actually in the new chain after the reorg.
if await can_use_peer_request_cache(coin_state, peer_request_cache, fork_height):
@ -1308,8 +1328,6 @@ class WalletNode:
async def validate_block_inclusion(
self, block: HeaderBlock, peer: WSChiaConnection, peer_request_cache: PeerRequestCache
) -> bool:
assert self.wallet_state_manager is not None
assert self.server is not None
if self.wallet_state_manager.blockchain.contains_height(block.height):
stored_hash = self.wallet_state_manager.blockchain.height_to_hash(block.height)
stored_record = self.wallet_state_manager.blockchain.try_block_record(stored_hash)
@ -1469,7 +1487,6 @@ class WalletNode:
async def get_coin_state(
self, coin_names: List[bytes32], fork_height: Optional[uint32] = None, peer: Optional[WSChiaConnection] = None
) -> List[CoinState]:
assert self.server is not None
all_nodes = self.server.connection_by_type[NodeType.FULL_NODE]
if len(all_nodes.keys()) == 0:
raise ValueError("Not connected to the full node")

View File

@ -79,8 +79,13 @@ class WalletNodeAPI:
assert peer.peer_node_id is not None
name = peer.peer_node_id.hex()
status = MempoolInclusionStatus(ack.status)
if self.wallet_node.wallet_state_manager is None:
return None
try:
wallet_state_manager = self.wallet_node.wallet_state_manager
except RuntimeError as e:
if "not assigned" in str(e):
return None
raise
if status == MempoolInclusionStatus.SUCCESS:
self.wallet_node.log.info(f"SpendBundle has been received and accepted to mempool by the FullNode. {ack}")
elif status == MempoolInclusionStatus.PENDING:
@ -92,9 +97,9 @@ class WalletNodeAPI:
return
self.wallet_node.log.warning(f"SpendBundle has been rejected by the FullNode. {ack}")
if ack.error is not None:
await self.wallet_node.wallet_state_manager.remove_from_queue(ack.txid, name, status, Err[ack.error])
await wallet_state_manager.remove_from_queue(ack.txid, name, status, Err[ack.error])
else:
await self.wallet_node.wallet_state_manager.remove_from_queue(ack.txid, name, status, None)
await wallet_state_manager.remove_from_queue(ack.txid, name, status, None)
@peer_required
@api_request
@ -120,9 +125,12 @@ class WalletNodeAPI:
@api_request
async def respond_puzzle_solution(self, request: wallet_protocol.RespondPuzzleSolution):
if self.wallet_node.wallet_state_manager is None:
return None
await self.wallet_node.wallet_state_manager.puzzle_solution_received(request)
try:
await self.wallet_node.wallet_state_manager.puzzle_solution_received(request)
except RuntimeError as e:
if "not assigned" in str(e):
return None
raise
@api_request
async def reject_puzzle_solution(self, request: wallet_protocol.RejectPuzzleSolution):
@ -140,6 +148,10 @@ class WalletNodeAPI:
async def reject_header_blocks(self, request: wallet_protocol.RejectHeaderBlocks):
self.log.warning(f"Reject header blocks: {request}")
@api_request
async def reject_block_headers(self, request: wallet_protocol.RejectBlockHeaders):
pass
@execute_task
@peer_required
@api_request

View File

@ -4,6 +4,7 @@ import logging
import multiprocessing
import multiprocessing.context
import time
from datetime import datetime
from collections import defaultdict
from pathlib import Path
from secrets import token_bytes
@ -35,6 +36,7 @@ from chia.util.config import process_config_start_method
from chia.util.db_synchronous import db_synchronous_on
from chia.util.db_wrapper import DBWrapper
from chia.util.errors import Err
from chia.util.path import path_from_root
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.lru_cache import LRUCache
from chia.wallet.cat_wallet.cat_constants import DEFAULT_CATS
@ -156,6 +158,17 @@ class WalletStateManager:
"pragma synchronous={}".format(db_synchronous_on(self.config.get("db_sync", "auto"), db_path))
)
if self.config.get("log_sqlite_cmds", False):
sql_log_path = path_from_root(self.root_path, "log/wallet_sql.log")
self.log.info(f"logging SQL commands to {sql_log_path}")
def sql_trace_callback(req: str):
timestamp = datetime.now().strftime("%H:%M:%S.%f")
with open(sql_log_path, "a") as log:
log.write(timestamp + " " + req + "\n")
await self.db_connection.set_trace_callback(sql_trace_callback)
self.db_wrapper = DBWrapper(self.db_connection)
self.coin_store = await WalletCoinStore.create(self.db_wrapper)
self.tx_store = await WalletTransactionStore.create(self.db_wrapper)

View File

@ -30,9 +30,7 @@ class WalletTransactionStore:
db_connection: aiosqlite.Connection
db_wrapper: DBWrapper
tx_record_cache: Dict[bytes32, TransactionRecord]
tx_submitted: Dict[bytes32, Tuple[int, int]] # tx_id: [time submitted: count]
unconfirmed_for_wallet: Dict[int, Dict[bytes32, TransactionRecord]]
last_wallet_tx_resend_time: int # Epoch time in seconds
@classmethod
@ -87,26 +85,10 @@ class WalletTransactionStore:
)
await self.db_connection.commit()
self.tx_record_cache = {}
self.tx_submitted = {}
self.unconfirmed_for_wallet = {}
self.last_wallet_tx_resend_time = int(time.time())
await self.rebuild_tx_cache()
return self
async def rebuild_tx_cache(self):
# init cache here
all_records = await self.get_all_transactions()
self.tx_record_cache = {}
self.unconfirmed_for_wallet = {}
for record in all_records:
self.tx_record_cache[record.name] = record
if record.wallet_id not in self.unconfirmed_for_wallet:
self.unconfirmed_for_wallet[record.wallet_id] = {}
if not record.confirmed:
self.unconfirmed_for_wallet[record.wallet_id][record.name] = record
async def _clear_database(self):
cursor = await self.db_connection.execute("DELETE FROM transaction_record")
await cursor.close()
@ -116,19 +98,10 @@ class WalletTransactionStore:
"""
Store TransactionRecord in DB and Cache.
"""
self.tx_record_cache[record.name] = record
if record.wallet_id not in self.unconfirmed_for_wallet:
self.unconfirmed_for_wallet[record.wallet_id] = {}
unconfirmed_dict = self.unconfirmed_for_wallet[record.wallet_id]
if record.confirmed and record.name in unconfirmed_dict:
unconfirmed_dict.pop(record.name)
if not record.confirmed:
unconfirmed_dict[record.name] = record
if not in_transaction:
await self.db_wrapper.lock.acquire()
try:
cursor = await self.db_connection.execute(
await self.db_connection.execute_insert(
"INSERT OR REPLACE INTO transaction_record VALUES(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(
bytes(record),
@ -145,25 +118,13 @@ class WalletTransactionStore:
record.type,
),
)
await cursor.close()
if not in_transaction:
await self.db_connection.commit()
except BaseException:
if not in_transaction:
await self.rebuild_tx_cache()
raise
finally:
if not in_transaction:
self.db_wrapper.lock.release()
async def delete_transaction_record(self, tx_id: bytes32) -> None:
if tx_id in self.tx_record_cache:
tx_record = self.tx_record_cache.pop(tx_id)
if tx_record.wallet_id in self.unconfirmed_for_wallet:
tx_cache = self.unconfirmed_for_wallet[tx_record.wallet_id]
if tx_id in tx_cache:
tx_cache.pop(tx_id)
c = await self.db_connection.execute("DELETE FROM transaction_record WHERE bundle_id=?", (tx_id,))
await c.close()
@ -277,16 +238,12 @@ class WalletTransactionStore:
"""
Checks DB and cache for TransactionRecord with id: id and returns it.
"""
if tx_id in self.tx_record_cache:
return self.tx_record_cache[tx_id]
# NOTE: bundle_id is being stored as bytes, not hex
cursor = await self.db_connection.execute("SELECT * from transaction_record WHERE bundle_id=?", (tx_id,))
row = await cursor.fetchone()
await cursor.close()
if row is not None:
record = TransactionRecord.from_bytes(row[0])
return record
rows = list(
await self.db_connection.execute_fetchall("SELECT * from transaction_record WHERE bundle_id=?", (tx_id,))
)
if len(rows) > 0:
return TransactionRecord.from_bytes(rows[0][0])
return None
async def get_not_sent(self, *, include_accepted_txs=False) -> List[TransactionRecord]:
@ -294,12 +251,10 @@ class WalletTransactionStore:
Returns the list of transactions that have not been received by full node yet.
"""
current_time = int(time.time())
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record WHERE confirmed=?",
(0,),
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
@ -331,11 +286,9 @@ class WalletTransactionStore:
"""
fee_int = TransactionType.FEE_REWARD.value
pool_int = TransactionType.COINBASE_REWARD.value
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record WHERE confirmed=? and (type=? or type=?)", (1, fee_int, pool_int)
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
@ -349,9 +302,7 @@ class WalletTransactionStore:
Returns the list of all transaction that have not yet been confirmed.
"""
cursor = await self.db_connection.execute("SELECT * from transaction_record WHERE confirmed=?", (0,))
rows = await cursor.fetchall()
await cursor.close()
rows = await self.db_connection.execute_fetchall("SELECT * from transaction_record WHERE confirmed=?", (0,))
records = []
for row in rows:
@ -364,10 +315,10 @@ class WalletTransactionStore:
"""
Returns the list of transaction that have not yet been confirmed.
"""
if wallet_id in self.unconfirmed_for_wallet:
return list(self.unconfirmed_for_wallet[wallet_id].values())
else:
return []
rows = await self.db_connection.execute_fetchall(
"SELECT transaction_record from transaction_record WHERE confirmed=0 AND wallet_id=?", (wallet_id,)
)
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def get_transactions_between(
self, wallet_id: int, start, end, sort_key=None, reverse=False, to_puzzle_hash: Optional[bytes32] = None
@ -392,113 +343,64 @@ class WalletTransactionStore:
else:
query_str = SortKey[sort_key].ascending()
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
f"SELECT * from transaction_record where wallet_id=?{puzz_hash_where}"
f" {query_str}, rowid"
f" LIMIT {start}, {limit}",
(wallet_id,),
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
record = TransactionRecord.from_bytes(row[0])
records.append(record)
return records
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def get_transaction_count_for_wallet(self, wallet_id) -> int:
cursor = await self.db_connection.execute(
"SELECT COUNT(*) FROM transaction_record where wallet_id=?", (wallet_id,)
rows = list(
await self.db_connection.execute_fetchall(
"SELECT COUNT(*) FROM transaction_record where wallet_id=?", (wallet_id,)
)
)
count_result = await cursor.fetchone()
if count_result is not None:
count = count_result[0]
else:
count = 0
await cursor.close()
return count
return 0 if len(rows) == 0 else rows[0][0]
async def get_all_transactions_for_wallet(self, wallet_id: int, type: int = None) -> List[TransactionRecord]:
"""
Returns all stored transactions.
"""
if type is None:
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record where wallet_id=?", (wallet_id,)
)
else:
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record where wallet_id=? and type=?",
(
wallet_id,
type,
),
)
rows = await cursor.fetchall()
await cursor.close()
records = []
cache_set = set()
for row in rows:
record = TransactionRecord.from_bytes(row[0])
records.append(record)
cache_set.add(record.name)
return records
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def get_all_transactions(self) -> List[TransactionRecord]:
"""
Returns all stored transactions.
"""
cursor = await self.db_connection.execute("SELECT * from transaction_record")
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
record = TransactionRecord.from_bytes(row[0])
records.append(record)
return records
rows = await self.db_connection.execute_fetchall("SELECT * from transaction_record")
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def get_transaction_above(self, height: int) -> List[TransactionRecord]:
# Can be -1 (get all tx)
cursor = await self.db_connection.execute(
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record WHERE confirmed_at_height>?", (height,)
)
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
record = TransactionRecord.from_bytes(row[0])
records.append(record)
return records
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def get_transactions_by_trade_id(self, trade_id: bytes32) -> List[TransactionRecord]:
cursor = await self.db_connection.execute("SELECT * from transaction_record WHERE trade_id=?", (trade_id,))
rows = await cursor.fetchall()
await cursor.close()
records = []
for row in rows:
record = TransactionRecord.from_bytes(row[0])
records.append(record)
return records
rows = await self.db_connection.execute_fetchall(
"SELECT * from transaction_record WHERE trade_id=?", (trade_id,)
)
return [TransactionRecord.from_bytes(row[0]) for row in rows]
async def rollback_to_block(self, height: int):
# Delete from storage
to_delete = []
for tx in self.tx_record_cache.values():
if tx.confirmed_at_height > height:
to_delete.append(tx)
for tx in to_delete:
self.tx_record_cache.pop(tx.name)
self.tx_submitted = {}
c1 = await self.db_connection.execute("DELETE FROM transaction_record WHERE confirmed_at_height>?", (height,))
await c1.close()

View File

@ -55,7 +55,7 @@ class WalletUserStore:
async def create_wallet(
self, name: str, wallet_type: int, data: str, id: Optional[int] = None, in_transaction=False
) -> Optional[WalletInfo]:
) -> WalletInfo:
if not in_transaction:
await self.db_wrapper.lock.acquire()
@ -65,12 +65,15 @@ class WalletUserStore:
(id, name, wallet_type, data),
)
await cursor.close()
wallet = await self.get_last_wallet()
if wallet is None:
raise ValueError("Failed to get the just-created wallet")
finally:
if not in_transaction:
await self.db_connection.commit()
self.db_wrapper.lock.release()
return await self.get_last_wallet()
return wallet
async def delete_wallet(self, id: int, in_transaction: bool):
if not in_transaction:

File diff suppressed because one or more lines are too long

View File

@ -182,12 +182,14 @@ ignored-modules=blspy,
chiabip158,
chiapos,
chiavdf,
chia_rs,
cryptography,
aiohttp,
keyring,
keyrings.cryptfile,
bitstring,
clvm_tools,
clvm_tools_rs,
setproctitle,
clvm,
colorlog,

View File

@ -28,7 +28,7 @@ dependencies = [
"sortedcontainers==2.4.0", # For maintaining sorted mempools
# TODO: when moving to click 8 remove the pinning of black noted below
"click==7.1.2", # For the CLI
"dnspythonchia==2.2.0", # Query DNS seeds
"dnspython==2.2.0", # Query DNS seeds
"watchdog==2.1.9", # Filesystem event watching - watches keyring.yaml
"dnslib==0.9.17", # dns lib
"typing-extensions==4.0.1", # typing backports like Protocol and TypedDict
@ -44,6 +44,8 @@ dev_dependencies = [
"build",
"coverage",
"pre-commit",
"py3createtorrent",
"pylint",
"pytest",
"pytest-asyncio>=0.18.1", # require attribute 'fixture'
"pytest-monitor; sys_platform == 'linux'",

View File

@ -69,7 +69,7 @@ async def one_wallet_node_and_rpc(bt: BlockTools) -> AsyncIterator[nodes_with_po
hostname,
daemon_port,
wallet_node_0.server._port,
lambda x: None,
lambda: None,
bt.root_path,
config,
connect_to_daemon=False,
@ -85,7 +85,6 @@ async def test_create_insert_get(one_wallet_node_and_rpc: nodes_with_port, bt: B
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -154,7 +153,6 @@ async def test_upsert(one_wallet_node_and_rpc: nodes_with_port, bt: BlockTools,
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -200,7 +198,6 @@ async def test_create_double_insert(one_wallet_node_and_rpc: nodes_with_port, bt
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -262,7 +259,6 @@ async def test_keys_values_ancestors(one_wallet_node_and_rpc: nodes_with_port, b
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -347,7 +343,6 @@ async def test_get_roots(one_wallet_node_and_rpc: nodes_with_port, bt: BlockTool
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -422,7 +417,6 @@ async def test_get_root_history(one_wallet_node_and_rpc: nodes_with_port, bt: Bl
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -500,7 +494,6 @@ async def test_get_kv_diff(one_wallet_node_and_rpc: nodes_with_port, bt: BlockTo
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -593,7 +586,6 @@ async def test_batch_update_matches_single_operations(
num_blocks = 15
assert wallet_node.server
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -704,7 +696,6 @@ async def test_get_owned_stores(one_wallet_node_and_rpc: nodes_with_port, bt: Bl
num_blocks = 4
assert wallet_node.server is not None
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))
@ -741,7 +732,6 @@ async def test_subscriptions(one_wallet_node_and_rpc: nodes_with_port, bt: Block
num_blocks = 4
assert wallet_node.server is not None
await wallet_node.server.start_client(PeerInfo("localhost", uint16(full_node_api.server._port)), None)
assert wallet_node.wallet_state_manager is not None
ph = await wallet_node.wallet_state_manager.main_wallet.get_new_puzzlehash()
for i in range(0, num_blocks):
await full_node_api.farm_new_transaction_block(FarmNewBlockProtocol(ph))

View File

@ -22,7 +22,6 @@ def wallet_height_at_least(wallet_node, h):
async def wallet_balance_at_least(wallet_node: WalletNode, balance):
assert wallet_node.wallet_state_manager is not None
b = await wallet_node.wallet_state_manager.get_confirmed_balance_for_wallet(1)
if b >= balance:
return True

View File

@ -9,7 +9,8 @@ from chia.full_node.signage_point import SignagePoint
from chia.protocols import full_node_protocol
from chia.rpc.full_node_rpc_api import FullNodeRpcApi
from chia.rpc.full_node_rpc_client import FullNodeRpcClient
from chia.rpc.rpc_server import NodeType, start_rpc_server
from chia.rpc.rpc_server import start_rpc_server
from chia.server.outbound_message import NodeType
from chia.simulator.simulator_protocol import FarmNewBlockProtocol, ReorgProtocol
from chia.types.full_block import FullBlock
from chia.types.spend_bundle import SpendBundle

View File

@ -400,13 +400,13 @@ def test_post_init_valid(test_class: Type[Any], args: Tuple[Any, ...]) -> None:
(PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 31, G1Element()), ValueError),
(PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 32, b"\12" * 10), ValueError),
(PostInitTestClassBad, (1, 2), TypeError),
(PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), ValueError),
(PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), ValueError),
(PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), ValueError),
(PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), TypeError),
(PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), TypeError),
(PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), TypeError),
(PostInitTestClassTuple, ((1,), ((200, "test_2"), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test", 1), ((200, "test_2"), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), ({"a": 2}, b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), TypeError),
(PostInitTestClassOptional, ([], None, None, None), ValueError),
],
)

View File

@ -161,7 +161,7 @@ class TestCompression(TestCase):
def test_spend_byndle_coin_spend(self):
for i in range(0, 10):
sb: SpendBundle = make_spend_bundle(i)
cs1 = SExp.to(spend_bundle_to_coin_spend_entry_list(sb)).as_bin()
cs1 = SExp.to(spend_bundle_to_coin_spend_entry_list(sb)).as_bin() # pylint: disable=E1101
cs2 = spend_bundle_to_serialized_coin_spend_entry_list(sb)
assert cs1 == cs2

View File

@ -81,7 +81,6 @@ class TemporaryPoolPlot:
async def wallet_is_synced(wallet_node: WalletNode, full_node_api) -> bool:
assert wallet_node.wallet_state_manager is not None
wallet_height = await wallet_node.wallet_state_manager.blockchain.get_finished_sync_up_to()
full_node_height = full_node_api.full_node.blockchain.get_peak_height()
return wallet_height == full_node_height
@ -111,7 +110,7 @@ async def one_wallet_node_and_rpc(bt, self_hostname) -> AsyncGenerator[Tuple[Wal
self_hostname,
daemon_port,
uint16(0),
lambda x: None,
lambda: None,
bt.root_path,
config,
connect_to_daemon=False,

View File

@ -1,5 +1,6 @@
from pathlib import Path
from chia.util.db_wrapper import DBWrapper2
from chia.util.db_wrapper import DBWrapper
import tempfile
import aiosqlite
@ -33,3 +34,18 @@ class DBConnection:
async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
await self._db_wrapper.close()
self.db_path.unlink()
# This is just here until all DBWrappers have been upgraded to DBWrapper2
class DBConnection1:
async def __aenter__(self) -> DBWrapper:
self.db_path = Path(tempfile.NamedTemporaryFile().name)
if self.db_path.exists():
self.db_path.unlink()
self._db_connection = await aiosqlite.connect(self.db_path)
self._db_wrapper = DBWrapper(self._db_connection)
return self._db_wrapper
async def __aexit__(self, exc_t, exc_v, exc_tb) -> None:
await self._db_connection.close()
self.db_path.unlink()

View File

@ -124,4 +124,4 @@ async def test_graftroot(setup_sim: Tuple[SpendSim, SimClient]) -> None:
with pytest.raises(ValueError, match="clvm raise"):
graftroot_puzzle.run(graftroot_spend.solution.to_program())
finally:
await sim.close() # type: ignore[no-untyped-call]
await sim.close()

View File

@ -71,7 +71,6 @@ class TestDLWallet:
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
assert wallet_node_0.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
@ -120,7 +119,6 @@ class TestDLWallet:
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
assert wallet_node_0.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
@ -174,8 +172,6 @@ class TestDLWallet:
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_node_1, server_1 = wallets[1]
assert wallet_node_0.wallet_state_manager is not None
assert wallet_node_1.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
@ -250,7 +246,6 @@ class TestDLWallet:
full_node_api = full_nodes[0]
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
assert wallet_node_0.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
if trusted:
@ -333,8 +328,6 @@ class TestDLWallet:
full_node_server = full_node_api.server
wallet_node_0, server_0 = wallets[0]
wallet_node_1, server_1 = wallets[1]
assert wallet_node_0.wallet_state_manager is not None
assert wallet_node_1.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet

View File

@ -131,7 +131,7 @@ async def test_state_layer(setup_sim: Tuple[SpendSim, SimClient], metadata_updat
await sim.farm_block()
state_layer_puzzle = create_nft_layer_puzzle_with_curry_params(metadata, METADATA_UPDATER_PUZZLE_HASH, ACS)
finally:
await sim.close() # type: ignore
await sim.close()
@pytest.mark.asyncio()
@ -238,7 +238,7 @@ async def test_ownership_layer(setup_sim: Tuple[SpendSim, SimClient]) -> None:
ACS,
).get_tree_hash()
finally:
await sim.close() # type: ignore
await sim.close()
@pytest.mark.asyncio()
@ -362,4 +362,4 @@ async def test_default_transfer_program(setup_sim: Tuple[SpendSim, SimClient]) -
assert result == (MempoolInclusionStatus.SUCCESS, None)
await sim.farm_block()
finally:
await sim.close() # type: ignore
await sim.close()

View File

@ -50,7 +50,6 @@ class TestWalletRpc:
full_node_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
wallet_node_2, server_3 = wallets[1]
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()

View File

@ -607,7 +607,6 @@ async def test_cat_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment):
await farm_transaction(full_node_api, wallet_node, spend_bundle)
# Test unacknowledged CAT
assert wallet_node.wallet_state_manager is not None
await wallet_node.wallet_state_manager.interested_store.add_unacknowledged_token(
asset_id, "Unknown", uint32(10000), bytes32(b"\00" * 32)
)
@ -814,8 +813,6 @@ async def test_did_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment):
for _ in range(3):
await farm_transaction_block(full_node_api, wallet_1_node)
assert wallet_2_node.wallet_state_manager is not None
did_wallets = list(
filter(
lambda w: (w.type == WalletType.DECENTRALIZED_ID),
@ -857,8 +854,6 @@ async def test_nft_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment):
for _ in range(3):
await farm_transaction_block(full_node_api, wallet_1_node)
assert wallet_1_node.wallet_state_manager is not None
nft_wallet: NFTWallet = wallet_1_node.wallet_state_manager.wallets[nft_wallet_id]
# Test with the hex version of nft_id
nft_id = nft_wallet.get_current_nfts()[0].coin.name().hex()
@ -876,8 +871,6 @@ async def test_nft_endpoints(wallet_rpc_environment: WalletRpcTestEnvironment):
for _ in range(3):
await farm_transaction_block(full_node_api, wallet_1_node)
assert wallet_2_node.wallet_state_manager is not None
nft_wallet_id_1 = (
await wallet_2_node.wallet_state_manager.get_all_wallet_info_entries(wallet_type=WalletType.NFT)
)[0].id

View File

@ -43,7 +43,6 @@ class TestWalletSimulator:
server_1: ChiaServer = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
if trusted:
@ -65,7 +64,6 @@ class TestWalletSimulator:
)
async def check_tx_are_pool_farm_rewards() -> bool:
assert wallet_node.wallet_state_manager is not None
wsm: WalletStateManager = wallet_node.wallet_state_manager
all_txs = await wsm.get_all_transactions(1)
expected_count = (num_blocks + 1) * 2
@ -105,9 +103,7 @@ class TestWalletSimulator:
full_node_api = full_nodes[0]
server_1 = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
assert wallet_node_2.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
if trusted:
@ -169,7 +165,6 @@ class TestWalletSimulator:
full_node_api = full_nodes[0]
fn_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
if trusted:
@ -215,7 +210,6 @@ class TestWalletSimulator:
full_nodes, wallets = three_sim_two_wallets
wallet_0, wallet_server_0 = wallets[0]
assert wallet_0.wallet_state_manager is not None
full_node_api_0 = full_nodes[0]
full_node_api_1 = full_nodes[1]
@ -294,9 +288,7 @@ class TestWalletSimulator:
server_0 = full_node_0.server
wallet_node_0, wallet_0_server = wallets[0]
assert wallet_node_0.wallet_state_manager is not None
wallet_node_1, wallet_1_server = wallets[1]
assert wallet_node_1.wallet_state_manager is not None
wallet_0 = wallet_node_0.wallet_state_manager.main_wallet
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
@ -420,9 +412,7 @@ class TestWalletSimulator:
full_node_1 = full_nodes[0]
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
assert wallet_node_2.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
@ -497,7 +487,6 @@ class TestWalletSimulator:
full_node_1 = full_nodes[0]
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
wallet = wallet_node.wallet_state_manager.main_wallet
@ -602,9 +591,7 @@ class TestWalletSimulator:
full_node_1 = full_nodes[0]
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
assert wallet_node_2.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()
@ -698,9 +685,7 @@ class TestWalletSimulator:
fn_server = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
assert wallet_node_2.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
wallet_2 = wallet_node_2.wallet_state_manager.main_wallet
@ -805,7 +790,6 @@ class TestWalletSimulator:
full_node_api = full_nodes[0]
server_1: ChiaServer = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
if trusted:
wallet_node.config["trusted_peers"] = {server_1.node_id.hex(): server_1.node_id.hex()}
else:
@ -856,9 +840,7 @@ class TestWalletSimulator:
server_1 = full_node_api.full_node.server
wallet_node, server_2 = wallets[0]
assert wallet_node.wallet_state_manager is not None
wallet_node_2, server_3 = wallets[1]
assert wallet_node_2.wallet_state_manager is not None
wallet = wallet_node.wallet_state_manager.main_wallet
ph = await wallet.get_new_puzzlehash()

View File

@ -0,0 +1,387 @@
from secrets import token_bytes
import pytest
from chia.types.blockchain_format.coin import Coin
from chia.util.ints import uint32, uint64
from chia.wallet.util.wallet_types import WalletType
from chia.wallet.wallet_coin_record import WalletCoinRecord
from chia.wallet.wallet_coin_store import WalletCoinStore
from tests.util.db_connection import DBConnection1
coin_1 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_2 = Coin(coin_1.parent_coin_info, token_bytes(32), uint64(12311))
coin_3 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_4 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_5 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
coin_6 = Coin(token_bytes(32), coin_4.puzzle_hash, uint64(12312))
coin_7 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
record_replaced = WalletCoinRecord(coin_1, uint32(8), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_1 = WalletCoinRecord(coin_1, uint32(4), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_2 = WalletCoinRecord(coin_2, uint32(5), uint32(0), False, True, WalletType.STANDARD_WALLET, 0)
record_3 = WalletCoinRecord(
coin_3,
uint32(5),
uint32(10),
True,
False,
WalletType.STANDARD_WALLET,
0,
)
record_4 = WalletCoinRecord(
coin_4,
uint32(5),
uint32(15),
True,
False,
WalletType.STANDARD_WALLET,
0,
)
record_5 = WalletCoinRecord(
coin_5,
uint32(5),
uint32(0),
False,
False,
WalletType.STANDARD_WALLET,
1,
)
record_6 = WalletCoinRecord(
coin_6,
uint32(5),
uint32(15),
True,
False,
WalletType.STANDARD_WALLET,
2,
)
record_7 = WalletCoinRecord(
coin_7,
uint32(5),
uint32(0),
False,
False,
WalletType.POOLING_WALLET,
2,
)
@pytest.mark.asyncio
async def test_add_replace_get() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
assert await store.get_coin_record(coin_1.name()) is None
await store.add_coin_record(record_replaced)
await store.add_coin_record(record_1)
await store.add_coin_record(record_2)
await store.add_coin_record(record_3)
await store.add_coin_record(record_4)
assert await store.get_coin_record(coin_1.name()) == record_1
@pytest.mark.asyncio
async def test_persistance() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_1)
store = await WalletCoinStore.create(db_wrapper)
assert await store.get_coin_record(coin_1.name()) == record_1
@pytest.mark.asyncio
async def test_set_spent() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_1)
assert not (await store.get_coin_record(coin_1.name())).spent
await store.set_spent(coin_1.name(), uint32(12))
assert (await store.get_coin_record(coin_1.name())).spent
assert (await store.get_coin_record(coin_1.name())).spent_block_height == 12
@pytest.mark.asyncio
async def test_get_records_by_puzzle_hash() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_4)
await store.add_coin_record(record_5)
await store.add_coin_record(record_5)
await store.add_coin_record(record_6)
assert len(await store.get_coin_records_by_puzzle_hash(record_6.coin.puzzle_hash)) == 2 # 4 and 6
assert len(await store.get_coin_records_by_puzzle_hash(token_bytes(32))) == 0
assert await store.get_coin_record(coin_6.name()) == record_6
assert await store.get_coin_record(token_bytes(32)) is None
@pytest.mark.asyncio
async def test_get_unspent_coins_for_wallet() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
assert await store.get_unspent_coins_for_wallet(1) == set()
await store.add_coin_record(record_4) # this is spent and wallet 0
await store.add_coin_record(record_5) # wallet 1
await store.add_coin_record(record_6) # this is spent and wallet 2
await store.add_coin_record(record_7) # wallet 2
assert await store.get_unspent_coins_for_wallet(1) == set([record_5])
assert await store.get_unspent_coins_for_wallet(2) == set([record_7])
assert await store.get_unspent_coins_for_wallet(3) == set()
await store.set_spent(coin_4.name(), uint32(12))
assert await store.get_unspent_coins_for_wallet(1) == set([record_5])
assert await store.get_unspent_coins_for_wallet(2) == set([record_7])
assert await store.get_unspent_coins_for_wallet(3) == set()
await store.set_spent(coin_7.name(), uint32(12))
assert await store.get_unspent_coins_for_wallet(1) == set([record_5])
assert await store.get_unspent_coins_for_wallet(2) == set()
assert await store.get_unspent_coins_for_wallet(3) == set()
await store.set_spent(coin_5.name(), uint32(12))
assert await store.get_unspent_coins_for_wallet(1) == set()
assert await store.get_unspent_coins_for_wallet(2) == set()
assert await store.get_unspent_coins_for_wallet(3) == set()
@pytest.mark.asyncio
async def test_get_records_by_parent_id() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_1)
await store.add_coin_record(record_2)
await store.add_coin_record(record_3)
await store.add_coin_record(record_4)
await store.add_coin_record(record_5)
await store.add_coin_record(record_6)
await store.add_coin_record(record_7)
assert set(await store.get_coin_records_by_parent_id(coin_1.parent_coin_info)) == set([record_1, record_2])
assert set(await store.get_coin_records_by_parent_id(coin_2.parent_coin_info)) == set([record_1, record_2])
assert await store.get_coin_records_by_parent_id(coin_3.parent_coin_info) == [record_3]
assert await store.get_coin_records_by_parent_id(coin_4.parent_coin_info) == [record_4]
assert await store.get_coin_records_by_parent_id(coin_5.parent_coin_info) == [record_5]
assert await store.get_coin_records_by_parent_id(coin_6.parent_coin_info) == [record_6]
assert await store.get_coin_records_by_parent_id(coin_7.parent_coin_info) == [record_7]
@pytest.mark.asyncio
async def test_get_multiple_coin_records() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_1)
await store.add_coin_record(record_2)
await store.add_coin_record(record_3)
await store.add_coin_record(record_4)
await store.add_coin_record(record_5)
await store.add_coin_record(record_6)
await store.add_coin_record(record_7)
assert set(await store.get_multiple_coin_records([coin_1.name(), coin_2.name(), coin_3.name()])) == set(
[record_1, record_2, record_3]
)
assert set(await store.get_multiple_coin_records([coin_5.name(), coin_6.name(), coin_7.name()])) == set(
[record_5, record_6, record_7]
)
assert (
set(
await store.get_multiple_coin_records(
[
coin_1.name(),
coin_2.name(),
coin_3.name(),
coin_4.name(),
coin_5.name(),
coin_6.name(),
coin_7.name(),
]
)
)
== set([record_1, record_2, record_3, record_4, record_5, record_6, record_7])
)
@pytest.mark.asyncio
async def test_delete_coin_record() -> None:
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(record_1)
await store.add_coin_record(record_2)
await store.add_coin_record(record_3)
await store.add_coin_record(record_4)
await store.add_coin_record(record_5)
await store.add_coin_record(record_6)
await store.add_coin_record(record_7)
assert (
set(
await store.get_multiple_coin_records(
[
coin_1.name(),
coin_2.name(),
coin_3.name(),
coin_4.name(),
coin_5.name(),
coin_6.name(),
coin_7.name(),
]
)
)
== set([record_1, record_2, record_3, record_4, record_5, record_6, record_7])
)
assert await store.get_coin_record(coin_1.name()) == record_1
await store.delete_coin_record(coin_1.name())
assert await store.get_coin_record(coin_1.name()) is None
assert set(
await store.get_multiple_coin_records(
[coin_2.name(), coin_3.name(), coin_4.name(), coin_5.name(), coin_6.name(), coin_7.name()]
)
) == set([record_2, record_3, record_4, record_5, record_6, record_7])
def record(c: Coin, *, confirmed: int, spent: int) -> WalletCoinRecord:
return WalletCoinRecord(c, uint32(confirmed), uint32(spent), spent != 0, False, WalletType.STANDARD_WALLET, 0)
@pytest.mark.asyncio
async def test_get_coins_to_check() -> None:
r1 = record(coin_1, confirmed=1, spent=0)
r2 = record(coin_2, confirmed=2, spent=4)
r3 = record(coin_3, confirmed=3, spent=5)
r4 = record(coin_4, confirmed=4, spent=6)
r5 = record(coin_5, confirmed=5, spent=7)
# these spent heights violate the invariant
r6 = record(coin_6, confirmed=6, spent=1)
r7 = record(coin_7, confirmed=7, spent=2)
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(r1)
await store.add_coin_record(r2)
await store.add_coin_record(r3)
await store.add_coin_record(r4)
await store.add_coin_record(r5)
await store.add_coin_record(r6)
await store.add_coin_record(r7)
for i in range(10):
coins = await store.get_coins_to_check(i)
# r1 is unspent and should always be included, regardless of height
assert r1 in coins
# r2 was spent at height 4
assert (r2 in coins) == (i < 4)
# r3 was spent at height 5
assert (r3 in coins) == (i < 5)
# r4 was spent at height 6
assert (r4 in coins) == (i < 6)
# r5 was spent at height 7
assert (r5 in coins) == (i < 7)
# r6 was confirmed at height 6
assert (r6 in coins) == (i < 6)
# r7 was confirmed at height 7
assert (r7 in coins) == (i < 7)
@pytest.mark.asyncio
async def test_get_first_coin_height() -> None:
r1 = record(coin_1, confirmed=1, spent=0)
r2 = record(coin_2, confirmed=2, spent=4)
r3 = record(coin_3, confirmed=3, spent=5)
r4 = record(coin_4, confirmed=4, spent=6)
r5 = record(coin_5, confirmed=5, spent=7)
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
assert await store.get_first_coin_height() is None
await store.add_coin_record(r5)
assert await store.get_first_coin_height() == 5
await store.add_coin_record(r4)
assert await store.get_first_coin_height() == 4
await store.add_coin_record(r3)
assert await store.get_first_coin_height() == 3
await store.add_coin_record(r2)
assert await store.get_first_coin_height() == 2
await store.add_coin_record(r1)
assert await store.get_first_coin_height() == 1
@pytest.mark.asyncio
async def test_rollback_to_block() -> None:
r1 = record(coin_1, confirmed=1, spent=0)
r2 = record(coin_2, confirmed=2, spent=4)
r3 = record(coin_3, confirmed=3, spent=5)
r4 = record(coin_4, confirmed=4, spent=6)
r5 = record(coin_5, confirmed=5, spent=7)
async with DBConnection1() as db_wrapper:
store = await WalletCoinStore.create(db_wrapper)
await store.add_coin_record(r1)
await store.add_coin_record(r2)
await store.add_coin_record(r3)
await store.add_coin_record(r4)
await store.add_coin_record(r5)
assert set(
await store.get_multiple_coin_records(
[
coin_1.name(),
coin_2.name(),
coin_3.name(),
coin_4.name(),
coin_5.name(),
]
)
) == set(
[
r1,
r2,
r3,
r4,
r5,
]
)
assert await store.get_coin_record(coin_5.name()) == r5
await store.rollback_to_block(6)
new_r5 = await store.get_coin_record(coin_5.name())
assert not new_r5.spent
assert new_r5.spent_block_height == 0
assert new_r5 != r5
assert await store.get_coin_record(coin_4.name()) == r4
await store.rollback_to_block(4)
assert await store.get_coin_record(coin_5.name()) is None
new_r4 = await store.get_coin_record(coin_4.name())
assert not new_r4.spent
assert new_r4.spent_block_height == 0
assert new_r4 != r4

View File

@ -43,7 +43,6 @@ async def test_wallet_tx_retry(
wallet_node_1: WalletNode = wallets[0][0]
wallet_node_1.config["tx_resend_timeout_secs"] = 5
wallet_server_1 = wallets[0][1]
assert wallet_node_1.wallet_state_manager is not None
wallet_1 = wallet_node_1.wallet_state_manager.main_wallet
reward_ph = await wallet_1.get_new_puzzlehash()
@ -77,7 +76,6 @@ async def test_wallet_tx_retry(
await time_out_assert(wait_secs, wallet_is_synced, True, wallet_node_1, full_node_1)
async def check_transaction_in_mempool_or_confirmed(transaction: TransactionRecord) -> bool:
assert wallet_node_1.wallet_state_manager is not None
txn = await wallet_node_1.wallet_state_manager.get_transaction(transaction.name)
assert txn is not None
sb = txn.spend_bundle

View File

@ -1,230 +0,0 @@
# TODO: write tests for other stores
# import asyncio
# from pathlib import Path
# from secrets import token_bytes
# import aiosqlite
# import pytest
# from chia.util.ints import uint32, uint64, uint128
# from chia.wallet.wallet_coin_record import WalletCoinRecord
# from chia.wallet.util.wallet_types import WalletType
# from chia.types.coin import Coin
#
#
# @pytest.fixture(scope="module")
# def event_loop():
# loop = asyncio.get_event_loop()
# yield loop
#
#
# class TestWalletStore:
# @pytest.mark.asyncio
# async def test_store(self):
# db_filename = Path("blockchain_wallet_store_test.db")
#
# if db_filename.exists():
# db_filename.unlink()
#
# db_connection = await aiosqlite.connect(db_filename)
# store = await WalletStore.create(db_connection)
# try:
# coin_1 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_2 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_3 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_4 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# record_replaced = WalletCoinRecord(coin_1, uint32(8), uint32(0),
# False, True, WalletType.STANDARD_WALLET, 0)
# record_1 = WalletCoinRecord(coin_1, uint32(4), uint32(0), False,
# True, WalletType.STANDARD_WALLET, 0)
# record_2 = WalletCoinRecord(coin_2, uint32(5), uint32(0),
# False, True, WalletType.STANDARD_WALLET, 0)
# record_3 = WalletCoinRecord(
# coin_3,
# uint32(5),
# uint32(10),
# True,
# False,
# WalletType.STANDARD_WALLET,
# 0,
# )
# record_4 = WalletCoinRecord(
# coin_4,
# uint32(5),
# uint32(15),
# True,
# False,
# WalletType.STANDARD_WALLET,
# 0,
# )
#
# # Test add (replace) and get
# assert await store.get_coin_record(coin_1.name()) is None
# await store.add_coin_record(record_replaced)
# await store.add_coin_record(record_1)
# await store.add_coin_record(record_2)
# await store.add_coin_record(record_3)
# await store.add_coin_record(record_4)
# assert await store.get_coin_record(coin_1.name()) == record_1
#
# # Test persistance
# await db_connection.close()
# db_connection = await aiosqlite.connect(db_filename)
# store = await WalletStore.create(db_connection)
# assert await store.get_coin_record(coin_1.name()) == record_1
#
# # Test set spent
# await store.set_spent(coin_1.name(), uint32(12))
# assert (await store.get_coin_record(coin_1.name())).spent
# assert (await store.get_coin_record(coin_1.name())).spent_block_index == 12
#
# # No coins at height 3
# assert len(await store.get_unspent_coins_at_height(3)) == 0
# assert len(await store.get_unspent_coins_at_height(4)) == 1
# assert len(await store.get_unspent_coins_at_height(5)) == 4
# assert len(await store.get_unspent_coins_at_height(11)) == 3
# assert len(await store.get_unspent_coins_at_height(12)) == 2
# assert len(await store.get_unspent_coins_at_height(15)) == 1
# assert len(await store.get_unspent_coins_at_height(16)) == 1
# assert len(await store.get_unspent_coins_at_height()) == 1
#
# assert len(await store.get_unspent_coins_for_wallet(0)) == 1
# assert len(await store.get_unspent_coins_for_wallet(1)) == 0
#
# coin_5 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# record_5 = WalletCoinRecord(
# coin_5,
# uint32(5),
# uint32(15),
# False,
# False,
# WalletType.STANDARD_WALLET,
# 1,
# )
# await store.add_coin_record(record_5)
# assert len(await store.get_unspent_coins_for_wallet(1)) == 1
#
# assert len(await store.get_spendable_for_index(100, 1)) == 1
# assert len(await store.get_spendable_for_index(100, 0)) == 1
# assert len(await store.get_spendable_for_index(0, 0)) == 0
#
# coin_6 = Coin(token_bytes(32), coin_4.puzzle_hash, uint64(12312))
# await store.add_coin_record(record_5)
# record_6 = WalletCoinRecord(
# coin_6,
# uint32(5),
# uint32(15),
# True,
# False,
# WalletType.STANDARD_WALLET,
# 2,
# )
# await store.add_coin_record(record_6)
# assert len(await store.get_coin_records_by_puzzle_hash(record_6.coin.puzzle_hash)) == 2 # 4 and 6
# assert len(await store.get_coin_records_by_puzzle_hash(token_bytes(32))) == 0
#
# assert await store.get_coin_record_by_coin_id(coin_6.name()) == record_6
# assert await store.get_coin_record_by_coin_id(token_bytes(32)) is None
#
# # BLOCKS
# assert len(await store.get_lca_path()) == 0
#
# # NOT lca block
# br_1 = BlockRecord(
# token_bytes(32),
# token_bytes(32),
# uint32(0),
# uint128(100),
# None,
# None,
# None,
# None,
# uint64(0),
# )
# assert await store.get_block_record(br_1.header_hash) is None
# await store.add_block_record(br_1, False)
# assert len(await store.get_lca_path()) == 0
# assert await store.get_block_record(br_1.header_hash) == br_1
#
# # LCA genesis
# await store.add_block_record(br_1, True)
# assert await store.get_block_record(br_1.header_hash) == br_1
# assert len(await store.get_lca_path()) == 1
# assert (await store.get_lca_path())[br_1.header_hash] == br_1
#
# br_2 = BlockRecord(
# token_bytes(32),
# token_bytes(32),
# uint32(1),
# uint128(100),
# None,
# None,
# None,
# None,
# uint64(0),
# )
# await store.add_block_record(br_2, False)
# assert len(await store.get_lca_path()) == 1
# await store.add_block_to_path(br_2.header_hash)
# assert len(await store.get_lca_path()) == 2
# assert (await store.get_lca_path())[br_2.header_hash] == br_2
#
# br_3 = BlockRecord(
# token_bytes(32),
# token_bytes(32),
# uint32(2),
# uint128(100),
# None,
# None,
# None,
# None,
# uint64(0),
# )
# await store.add_block_record(br_3, True)
# assert len(await store.get_lca_path()) == 3
# await store.remove_block_records_from_path(1)
# assert len(await store.get_lca_path()) == 2
#
# await store.rollback_lca_to_block(0)
# assert len(await store.get_unspent_coins_at_height()) == 0
#
# coin_7 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_8 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_9 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# coin_10 = Coin(token_bytes(32), token_bytes(32), uint64(12312))
# record_7 = WalletCoinRecord(coin_7, uint32(0), uint32(1), True, False, WalletType.STANDARD_WALLET, 1)
# record_8 = WalletCoinRecord(coin_8, uint32(1), uint32(2), True, False, WalletType.STANDARD_WALLET, 1)
# record_9 = WalletCoinRecord(coin_9, uint32(2), uint32(3), True, False, WalletType.STANDARD_WALLET, 1)
# record_10 = WalletCoinRecord(
# coin_10,
# uint32(3),
# uint32(4),
# True,
# False,
# WalletType.STANDARD_WALLET,
# 1,
# )
#
# await store.add_coin_record(record_7)
# await store.add_coin_record(record_8)
# await store.add_coin_record(record_9)
# await store.add_coin_record(record_10)
# assert len(await store.get_unspent_coins_at_height(0)) == 1
# assert len(await store.get_unspent_coins_at_height(1)) == 1
# assert len(await store.get_unspent_coins_at_height(2)) == 1
# assert len(await store.get_unspent_coins_at_height(3)) == 1
# assert len(await store.get_unspent_coins_at_height(4)) == 0
#
# await store.add_block_record(br_2, True)
# await store.add_block_record(br_3, True)
#
# await store.rollback_lca_to_block(1)
#
# assert len(await store.get_unspent_coins_at_height(0)) == 1
# assert len(await store.get_unspent_coins_at_height(1)) == 1
# assert len(await store.get_unspent_coins_at_height(2)) == 1
# assert len(await store.get_unspent_coins_at_height(3)) == 1
# assert len(await store.get_unspent_coins_at_height(4)) == 1
#
# except AssertionError:
# await db_connection.close()
# raise
# await db_connection.close()