Merge branch 'main' into ak.lint-files

This commit is contained in:
Kyle Altendorf 2023-04-10 08:59:14 -04:00
commit 90d0641a9f
No known key found for this signature in database
325 changed files with 10274 additions and 5437 deletions

View File

@ -11,6 +11,8 @@ updates:
schedule:
interval: "weekly"
day: "tuesday"
labels:
- "Changed"
target-branch: "main"
pull-request-branch-name:
# Separate sections of the branch name with a hyphen
@ -24,6 +26,8 @@ updates:
schedule:
interval: "weekly"
day: "tuesday"
labels:
- "Changed"
target-branch: "main"
pull-request-branch-name:
# Separate sections of the branch name with a hyphen

View File

@ -2,6 +2,8 @@ name: ⚡️ Benchmarks
on:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -9,6 +11,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'

View File

@ -3,6 +3,8 @@ name: 📦🚀 Build Installer - Linux DEB ARM64
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -10,6 +12,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
@ -48,18 +52,9 @@ jobs:
run: bash build_scripts/clean-runner.sh || true
- name: Set Env
if: github.event_name == 'release' && github.event.action == 'published'
run: |
PRE_RELEASE=$(jq -r '.release.prerelease' "$GITHUB_EVENT_PATH")
RELEASE_TAG=$(jq -r '.release.tag_name' "$GITHUB_EVENT_PATH")
echo "RELEASE=true" >>$GITHUB_ENV
echo "PRE_RELEASE=$PRE_RELEASE" >>$GITHUB_ENV
echo "RELEASE_TAG=$RELEASE_TAG" >>$GITHUB_ENV
if [ $PRE_RELEASE = false ]; then
echo "FULL_RELEASE=true" >>$GITHUB_ENV
else
echo "FULL_RELEASE=false" >>$GITHUB_ENV
fi
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Create our own venv outside of the git directory JUST for getting the ACTUAL version so that install can't break it
- name: Get version number
@ -87,47 +82,25 @@ jobs:
AWS_SECRET: "${{ secrets.INSTALLER_UPLOAD_KEY }}"
GLUE_ACCESS_TOKEN: "${{ secrets.GLUE_ACCESS_TOKEN }}"
# Get the most recent release from chia-plotter-madmax
- uses: actions/github-script@v6
id: 'latest-madmax'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'chia-plotter-madmax',
});
return release.data.tag_name;
- name: Get latest madmax plotter
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_MADMAX=$(gh api repos/Chia-Network/chia-plotter-madmax/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/madmax"
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot-${{ steps.latest-madmax.outputs.result }}-arm64
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot_k34" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot_k34-${{ steps.latest-madmax.outputs.result }}-arm64
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot-*-arm64' -O $GITHUB_WORKSPACE/madmax/chia_plot
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot_k34-*-arm64' -O $GITHUB_WORKSPACE/madmax/chia_plot_k34
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot"
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot_k34"
# Get the most recent release from bladebit
- uses: actions/github-script@v6
if: '!github.event.release.prerelease'
id: 'latest-bladebit'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'bladebit',
});
return release.data.tag_name;
- name: Get latest bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_BLADEBIT=$(gh api repos/Chia-Network/bladebit/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz https://github.com/Chia-Network/bladebit/releases/download/${{ steps.latest-bladebit.outputs.result }}/bladebit-${{ steps.latest-bladebit.outputs.result }}-ubuntu-arm64.tar.gz
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_BLADEBIT -p '*-ubuntu-arm64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- name: Get latest prerelease bladebit plotter
@ -135,10 +108,9 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-arm64.tar.gz")).browser_download_url')
LATEST_PRERELEASE=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_PRERELEASE -p '*ubuntu-arm64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- uses: ./.github/actions/install

View File

@ -3,6 +3,8 @@ name: 📦🚀 Build Installer - Linux DEB AMD64
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -10,6 +12,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
@ -48,18 +52,9 @@ jobs:
run: bash build_scripts/clean-runner.sh || true
- name: Set Env
if: github.event_name == 'release' && github.event.action == 'published'
run: |
PRE_RELEASE=$(jq -r '.release.prerelease' "$GITHUB_EVENT_PATH")
RELEASE_TAG=$(jq -r '.release.tag_name' "$GITHUB_EVENT_PATH")
echo "RELEASE=true" >>$GITHUB_ENV
echo "PRE_RELEASE=$PRE_RELEASE" >>$GITHUB_ENV
echo "RELEASE_TAG=$RELEASE_TAG" >>$GITHUB_ENV
if [ $PRE_RELEASE = false ]; then
echo "FULL_RELEASE=true" >>$GITHUB_ENV
else
echo "FULL_RELEASE=false" >>$GITHUB_ENV
fi
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
# Create our own venv outside of the git directory JUST for getting the ACTUAL version so that install can't break it
- name: Get version number
@ -87,47 +82,25 @@ jobs:
AWS_SECRET: "${{ secrets.INSTALLER_UPLOAD_KEY }}"
GLUE_ACCESS_TOKEN: "${{ secrets.GLUE_ACCESS_TOKEN }}"
# Get the most recent release from chia-plotter-madmax
- uses: actions/github-script@v6
id: 'latest-madmax'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'chia-plotter-madmax',
});
return release.data.tag_name;
- name: Get latest madmax plotter
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_MADMAX=$(gh api repos/Chia-Network/chia-plotter-madmax/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/madmax"
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot-${{ steps.latest-madmax.outputs.result }}-x86-64
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot_k34" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot_k34-${{ steps.latest-madmax.outputs.result }}-x86-64
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot-*-x86-64' -O $GITHUB_WORKSPACE/madmax/chia_plot
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot_k34-*-x86-64' -O $GITHUB_WORKSPACE/madmax/chia_plot_k34
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot"
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot_k34"
# Get the most recent release from bladebit
- uses: actions/github-script@v6
if: '!github.event.release.prerelease'
id: 'latest-bladebit'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'bladebit',
});
return release.data.tag_name;
- name: Get latest bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_BLADEBIT=$(gh api repos/Chia-Network/bladebit/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz https://github.com/Chia-Network/bladebit/releases/download/${{ steps.latest-bladebit.outputs.result }}/bladebit-${{ steps.latest-bladebit.outputs.result }}-ubuntu-x86-64.tar.gz
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_BLADEBIT -p '*-ubuntu-x86-64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- name: Get latest prerelease bladebit plotter
@ -135,10 +108,9 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("ubuntu-x86-64.tar.gz")).browser_download_url')
LATEST_PRERELEASE=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_PRERELEASE -p '*ubuntu-x86-64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- uses: ./.github/actions/install

View File

@ -3,6 +3,8 @@ name: 📦🚀 Build Installer - Linux RPM AMD64
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -10,6 +12,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
@ -44,18 +48,9 @@ jobs:
run: bash build_scripts/clean-runner.sh || true
- name: Set Env
if: github.event_name == 'release' && github.event.action == 'published'
run: |
PRE_RELEASE=$(jq -r '.release.prerelease' "$GITHUB_EVENT_PATH")
RELEASE_TAG=$(jq -r '.release.tag_name' "$GITHUB_EVENT_PATH")
echo "RELEASE=true" >>$GITHUB_ENV
echo "PRE_RELEASE=$PRE_RELEASE" >>$GITHUB_ENV
echo "RELEASE_TAG=$RELEASE_TAG" >>$GITHUB_ENV
if [ $PRE_RELEASE = false ]; then
echo "FULL_RELEASE=true" >>$GITHUB_ENV
else
echo "FULL_RELEASE=false" >>$GITHUB_ENV
fi
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- uses: Chia-Network/actions/enforce-semver@main
if: env.FULL_RELEASE == 'true'
@ -86,47 +81,25 @@ jobs:
AWS_SECRET: "${{ secrets.INSTALLER_UPLOAD_KEY }}"
GLUE_ACCESS_TOKEN: "${{ secrets.GLUE_ACCESS_TOKEN }}"
# Get the most recent release from chia-plotter-madmax
- uses: actions/github-script@v6
id: 'latest-madmax'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'chia-plotter-madmax',
});
return release.data.tag_name;
- name: Get latest madmax plotter
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_MADMAX=$(gh api repos/Chia-Network/chia-plotter-madmax/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/madmax"
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot-${{ steps.latest-madmax.outputs.result }}-x86-64
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot_k34" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot_k34-${{ steps.latest-madmax.outputs.result }}-x86-64
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot-*-x86-64' -O $GITHUB_WORKSPACE/madmax/chia_plot
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot_k34-*-x86-64' -O $GITHUB_WORKSPACE/madmax/chia_plot_k34
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot"
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot_k34"
# Get the most recent release from bladebit
- uses: actions/github-script@v6
if: '!github.event.release.prerelease'
id: 'latest-bladebit'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'bladebit',
});
return release.data.tag_name;
- name: Get latest bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_BLADEBIT=$(gh api repos/Chia-Network/bladebit/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz https://github.com/Chia-Network/bladebit/releases/download/${{ steps.latest-bladebit.outputs.result }}/bladebit-${{ steps.latest-bladebit.outputs.result }}-centos-x86-64.tar.gz
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_BLADEBIT -p '*-centos-x86-64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- name: Get latest prerelease bladebit plotter
@ -134,10 +107,9 @@ jobs:
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("centos-x86-64.tar.gz")).browser_download_url')
LATEST_PRERELEASE=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_PRERELEASE -p '*centos-x86-64.tar.gz' -O - | tar -xz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- uses: ./.github/actions/install

View File

@ -3,6 +3,8 @@ name: 📦🚀 Build Installers - MacOS
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -10,6 +12,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
@ -57,18 +61,9 @@ jobs:
run: bash build_scripts/clean-runner.sh || true
- name: Set Env
if: github.event_name == 'release' && github.event.action == 'published'
run: |
PRE_RELEASE=$(jq -r '.release.prerelease' "$GITHUB_EVENT_PATH")
RELEASE_TAG=$(jq -r '.release.tag_name' "$GITHUB_EVENT_PATH")
echo "RELEASE=true" >>$GITHUB_ENV
echo "PRE_RELEASE=$PRE_RELEASE" >>$GITHUB_ENV
echo "RELEASE_TAG=$RELEASE_TAG" >>$GITHUB_ENV
if [ $PRE_RELEASE = false ]; then
echo "FULL_RELEASE=true" >>$GITHUB_ENV
else
echo "FULL_RELEASE=false" >>$GITHUB_ENV
fi
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Test for secrets access
id: check_secrets
@ -119,47 +114,39 @@ jobs:
p12-file-base64: ${{ secrets.APPLE_DEV_ID_APP }}
p12-password: ${{ secrets.APPLE_DEV_ID_APP_PASS }}
# Get the most recent release from chia-plotter-madmax
- uses: actions/github-script@v6
id: 'latest-madmax'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'chia-plotter-madmax',
});
return release.data.tag_name;
- name: Get latest madmax plotter
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_MADMAX=$(gh api repos/Chia-Network/chia-plotter-madmax/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/madmax"
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot-${{ steps.latest-madmax.outputs.result }}-macos-${{ matrix.os.name }}
wget -O "$GITHUB_WORKSPACE/madmax/chia_plot_k34" https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot_k34-${{ steps.latest-madmax.outputs.result }}-macos-${{ matrix.os.name }}
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot-'$LATEST_MADMAX'-macos-${{ matrix.os.name }}'
mv chia_plot-$LATEST_MADMAX-macos-${{ matrix.os.name }} $GITHUB_WORKSPACE/madmax/chia_plot
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot_k34-'$LATEST_MADMAX'-macos-${{ matrix.os.name }}'
mv chia_plot_k34-$LATEST_MADMAX-macos-${{ matrix.os.name }} $GITHUB_WORKSPACE/madmax/chia_plot_k34
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot"
chmod +x "$GITHUB_WORKSPACE/madmax/chia_plot_k34"
- name: Get latest bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
LATEST_BLADEBIT=$(gh api repos/Chia-Network/bladebit/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
gh release download -R Chia-Network/bladebit $LATEST_BLADEBIT -p '*${{ matrix.os.bladebit-suffix }}'
tar -xz -C $GITHUB_WORKSPACE/bladebit -f *${{ matrix.os.bladebit-suffix }}
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- name: Get latest prerelease bladebit plotter
if: env.PRE_RELEASE == 'true'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
LATEST_PRERELEASE=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first | .tag_name')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $PRERELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- name: Get latest full release bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
FULLRELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease | not)) | first.assets[] | select(.browser_download_url | endswith("${{ matrix.os.bladebit-suffix }}")).browser_download_url')
mkdir "$GITHUB_WORKSPACE/bladebit"
wget -O /tmp/bladebit.tar.gz $FULLRELEASE_URL
tar -xvzf /tmp/bladebit.tar.gz -C $GITHUB_WORKSPACE/bladebit
gh release download -R Chia-Network/bladebit $LATEST_PRERELEASE -p '*${{ matrix.os.bladebit-suffix }}'
tar -xz -C $GITHUB_WORKSPACE/bladebit -f *${{ matrix.os.bladebit-suffix }}
chmod +x "$GITHUB_WORKSPACE/bladebit/bladebit"
- uses: ./.github/actions/install

View File

@ -3,6 +3,8 @@ name: 📦🚀 Build Installer - Windows 10
on:
workflow_dispatch:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -10,6 +12,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
@ -36,19 +40,9 @@ jobs:
submodules: recursive
- name: Set Env
if: github.event_name == 'release' && github.event.action == 'published'
shell: bash
run: |
PRE_RELEASE=$(jq -r '.release.prerelease' "$GITHUB_EVENT_PATH")
RELEASE_TAG=$(jq -r '.release.tag_name' "$GITHUB_EVENT_PATH")
echo "RELEASE=true" >>$GITHUB_ENV
echo "PRE_RELEASE=$PRE_RELEASE" >>$GITHUB_ENV
echo "RELEASE_TAG=$RELEASE_TAG" >>$GITHUB_ENV
if [ $PRE_RELEASE = false ]; then
echo "FULL_RELEASE=true" >>$GITHUB_ENV
else
echo "FULL_RELEASE=false" >>$GITHUB_ENV
fi
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Set git urls to https instead of ssh
run: |
@ -133,46 +127,27 @@ jobs:
echo "CHIA_INSTALLER_VERSION=$CHIA_INSTALLER_VERSION" >>$GITHUB_OUTPUT
deactivate
# Get the most recent release from chia-plotter-madmax
- uses: actions/github-script@v6
id: 'latest-madmax'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'chia-plotter-madmax',
});
return release.data.tag_name;
- name: Get latest madmax plotter
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
shell: bash
run: |
mkdir "$env:GITHUB_WORKSPACE\madmax"
Invoke-WebRequest https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot-${{ steps.latest-madmax.outputs.result }}.exe -OutFile "$env:GITHUB_WORKSPACE\madmax\chia_plot.exe"
Invoke-WebRequest https://github.com/Chia-Network/chia-plotter-madmax/releases/download/${{ steps.latest-madmax.outputs.result }}/chia_plot_k34-${{ steps.latest-madmax.outputs.result }}.exe -OutFile "$env:GITHUB_WORKSPACE\madmax\chia_plot_k34.exe"
# Get the most recent release from bladebit
- uses: actions/github-script@v6
if: '!github.event.release.prerelease'
id: 'latest-bladebit'
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
result-encoding: string
script: |
const release = await github.rest.repos.getLatestRelease({
owner: 'Chia-Network',
repo: 'bladebit',
});
return release.data.tag_name;
LATEST_MADMAX=$(gh api repos/Chia-Network/chia-plotter-madmax/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir $GITHUB_WORKSPACE\\madmax
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot-*.exe' -O $GITHUB_WORKSPACE\\madmax\\chia_plot.exe
gh release download -R Chia-Network/chia-plotter-madmax $LATEST_MADMAX -p 'chia_plot_k34-*.exe' -O $GITHUB_WORKSPACE\\madmax\\chia_plot_k34.exe
- name: Get latest bladebit plotter
if: '!github.event.release.prerelease'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
shell: bash
run: |
mkdir "$env:GITHUB_WORKSPACE\bladebit"
Invoke-WebRequest https://github.com/Chia-Network/bladebit/releases/download/${{ steps.latest-bladebit.outputs.result }}/bladebit-${{ steps.latest-bladebit.outputs.result }}-windows-x86-64.zip -OutFile "$env:GITHUB_WORKSPACE\bladebit\bladebit.zip"
Expand-Archive -Path "$env:GITHUB_WORKSPACE\bladebit\bladebit.zip" -DestinationPath "$env:GITHUB_WORKSPACE\bladebit\"
rm "$env:GITHUB_WORKSPACE\bladebit\bladebit.zip"
LATEST_BLADEBIT=$(gh api repos/Chia-Network/bladebit/releases/latest --jq 'select(.prerelease == false) | .tag_name')
mkdir $GITHUB_WORKSPACE\\bladebit
gh release download -R Chia-Network/bladebit $LATEST_BLADEBIT -p '*windows-x86-64.zip' -O $GITHUB_WORKSPACE\\bladebit\\bladebit.zip
unzip $GITHUB_WORKSPACE\\bladebit\\bladebit.zip -d $GITHUB_WORKSPACE\\bladebit\\
rm $GITHUB_WORKSPACE\\bladebit\\bladebit.zip
- name: Get latest prerelease bladebit plotter
if: env.PRE_RELEASE == 'true'
@ -180,11 +155,9 @@ jobs:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
shell: bash
run: |
PRERELEASE_URL=$(gh api /repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first.assets[] | select(.browser_download_url | endswith("windows-x86-64.zip")).browser_download_url')
LATEST_PRERELEASE=$(gh api repos/Chia-Network/bladebit/releases --jq 'map(select(.prerelease)) | first | .tag_name')
mkdir $GITHUB_WORKSPACE\\bladebit
ls
echo $PRERELEASE_URL
curl -L "$PRERELEASE_URL" --output $GITHUB_WORKSPACE\\bladebit\\bladebit.zip
gh release download -R Chia-Network/bladebit $LATEST_PRERELEASE -p '*windows-x86-64.zip' -O $GITHUB_WORKSPACE\\bladebit\\bladebit.zip
unzip $GITHUB_WORKSPACE\\bladebit\\bladebit.zip -d $GITHUB_WORKSPACE\\bladebit\\
rm $GITHUB_WORKSPACE\\bladebit\\bladebit.zip

View File

@ -1,6 +1,8 @@
name: 🚨 Snyk Python Scan
on:
push:
paths-ignore:
- '**.md'
branches:
- long_lived/**
- main
@ -8,6 +10,8 @@ on:
tags:
- '**'
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
workflow_dispatch:

View File

@ -2,6 +2,8 @@ name: 🏗️ Test Install Scripts
on:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -9,6 +11,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'

View File

@ -63,6 +63,7 @@ jobs:
matrix: '3.7'
exclude_from:
limited: True
main: True
- name: '3.8'
file_name: '3.8'
action: '3.8'
@ -71,6 +72,7 @@ jobs:
matrix: '3.8'
exclude_from:
limited: True
main: True
- name: '3.9'
file_name: '3.9'
action: '3.9'
@ -85,6 +87,7 @@ jobs:
matrix: '3.10'
exclude_from:
limited: True
main: True
- name: '3.11'
file_name: '3.11'
action: '3.11'
@ -93,6 +96,7 @@ jobs:
matrix: '3.11'
exclude_from:
limited: True
main: True
exclude:
- os:
matrix: macos
@ -129,6 +133,11 @@ jobs:
with:
fetch-depth: 0
- name: Set Env
uses: Chia-Network/actions/setjobenv@main
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Setup Python environment
uses: Chia-Network/actions/setup-python@main
with:
@ -184,15 +193,20 @@ jobs:
- name: Checkout test blocks and plots (macOS, Ubuntu)
if: steps.test-blocks-plots.outputs.cache-hit != 'true' && (matrix.os.matrix == 'ubuntu' || matrix.os.matrix == 'macos')
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
wget -qO- https://github.com/Chia-Network/test-cache/archive/refs/tags/${{ env.BLOCKS_AND_PLOTS_VERSION }}.tar.gz | tar xzf -
gh release download -R Chia-Network/test-cache 0.29.0 --archive=tar.gz -O - | tar xzf -
mkdir ${{ github.workspace }}/.chia
mv ${{ github.workspace }}/test-cache-${{ env.BLOCKS_AND_PLOTS_VERSION }}/* ${{ github.workspace }}/.chia
- name: Checkout test blocks and plots (Windows)
if: steps.test-blocks-plots.outputs.cache-hit != 'true' && matrix.os.matrix == 'windows'
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: |
Invoke-WebRequest -OutFile blocks_and_plots.zip https://github.com/Chia-Network/test-cache/archive/refs/tags/${{ env.BLOCKS_AND_PLOTS_VERSION }}.zip; Expand-Archive blocks_and_plots.zip -DestinationPath .
gh release download -R Chia-Network/test-cache ${{ env.BLOCKS_AND_PLOTS_VERSION }} --archive=zip -O blocks_and_plots.zip
Expand-Archive blocks_and_plots.zip -DestinationPath .
mkdir ${{ github.workspace }}/.chia
mv ${{ github.workspace }}/test-cache-${{ env.BLOCKS_AND_PLOTS_VERSION }}/* ${{ github.workspace }}/.chia

View File

@ -2,6 +2,8 @@ name: 🧪 test
on:
push:
paths-ignore:
- '**.md'
branches:
- long_lived/**
- main
@ -9,6 +11,8 @@ on:
tags:
- '**'
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'
workflow_dispatch: null
@ -37,7 +41,7 @@ jobs:
python tests/build-job-matrix.py --per directory --verbose > matrix.json
cat matrix.json
echo configuration=$(cat matrix.json) >> $GITHUB_OUTPUT
echo matrix_mode=${{ ( github.repository_owner == 'Chia-Network' && github.repository != 'Chia-Network/chia-blockchain' ) && 'limited' || 'all' }} >> $GITHUB_OUTPUT
echo matrix_mode=${{ ( github.event_name == 'workflow_dispatch' ) && 'all' || ( github.repository_owner == 'Chia-Network' && github.repository != 'Chia-Network/chia-blockchain' ) && 'limited' || ( github.repository_owner == 'Chia-Network' && github.repository == 'Chia-Network/chia-blockchain' && github.ref == 'refs/heads/main' ) && 'main' || ( github.repository_owner == 'Chia-Network' && github.repository == 'Chia-Network/chia-blockchain' && startsWith(github.ref, 'refs/heads/release/') ) && 'all' || 'main' }} >> $GITHUB_OUTPUT
outputs:
configuration: ${{ steps.configure.outputs.configuration }}

View File

@ -2,10 +2,14 @@ name: 📦🚀 Trigger Dev Docker Build
on:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- 'release/**'
pull_request:
paths-ignore:
- '**.md'
concurrency:
# SHA is added to the end if on `main` to let all main workflows run

View File

@ -2,6 +2,8 @@ name: 📦🚀 Trigger Main Docker Build
on:
push:
paths-ignore:
- '**.md'
branches:
- main

View File

@ -2,6 +2,8 @@ name: 🚨🚀 Lint and upload source distribution
on:
push:
paths-ignore:
- '**.md'
branches:
- 'long_lived/**'
- main
@ -9,6 +11,8 @@ on:
release:
types: [published]
pull_request:
paths-ignore:
- '**.md'
branches:
- '**'

View File

@ -15,7 +15,7 @@ repos:
pass_filenames: false
additional_dependencies: [click~=7.1]
- repo: https://github.com/psf/black
rev: 22.10.0
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks

View File

@ -6,6 +6,107 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project does not yet adhere to [Semantic Versioning](https://semver.org/spec/v2.0.0.html)
for setuptools_scm/PEP 440 reasons.
## 1.7.1 Chia blockchain 2023-03-22
### Added
- `get_transaction_memo` wallet RPC
- `set_wallet_resync_on_startup` wallet RPC to reset wallet sync data on wallet restart
- `nft_count_nfts` wallet RPC - counts NFTs per wallet or for all wallets
- Community DNS introducers to initial (default) config.yaml
- additional metrics for `state_changed` events (used by chia-exporter)
- Python 3.11 support
- `chia wallet check` CLI command
- `reuse_public_key_for_change` config.yaml option to allow address reuse for change
- `nft_id` added to the JSON output of all NFT RPCs
- `curry` Chialisp library replaces `curry-and-treehash`
### Changed
- `chia show -f` changed to output proper JSON
- `Rate limiting` log messages are themselves rate limited
- Notified GUI when wallets are removed
- Optimized counting of NFTs during removal by leveraging SQLite
- Offer CLI command help now shows `--fee` units as XCH
- Optimized offer code by limiting `additions` recomputation
- `chia_rs` updated to 0.2.4
- Improve the help text for the `chia peer` and `chia peer -a` commands
- Remove redundant checks for minting coin and reserve fee
- `nft_get_nfts` performance improvements by shifting paging to SQLite
- `did_find_lost_did` improved
- Extend the sign APIs to support hex string
- Changed mempool backend to use an in-memory SQLite DB
### Fixed
- Quieted wallet log output for `Record: ... not in mempool` (fixes #14452)
- Quieted log output for `AttributeError: 'NoneType' object has no attribute '_get_extra_info`
- Reduced log output for `Using previous generator for height`
- Fixed error message when the `coin_id` argument for `nft_get_info` cannot be decoded
- Reduced cases where wallet claims to be synced while still syncing
- Resolved unnecessary error logging caused by unhinted coins (see #14757)
- Avoid `Unclosed client session` errors and associated tracebacks when using Ctrl-c during CLI commands
- Avoid `AttributeError` when shutting down plotting
- Added `anyio` as a dependency
- Resolved issues when offers involve small amounts and royalties round down to zero (fixes #14744)
- Limit retries to 3 for submitting offer transactions to the mempool and improved handling of offer status (fixes #14714)
- Resolved issues with offers stuck as pending when multiple people accept the same offer (fixes #14621)
- Improved the accuracy of the wallet sync status indication
### Deprecated
- `curry-and-treehash` Chialisp library replaced by new `curry` library
## 1.7.0 Chia blockchain 2023-02-15
### Added
- New `chia wallet coins` CLI and RPCs for listing, splitting, and combining coins
- New on-chain notification for offers, specifically designed for NFT offers
- New full node dynamic fee estimator (`chia show -f` and `get_fee_estimate` full node RPC)
- Implementation of soft fork at block 3630000 - see the 1.7.0 blog post for more details
- Add gzip support to DataLayer download client (Thanks, @Chida82!)
- Add proxy support to DataLayer download client (Thanks again, @Chida82!)
- Add `get_timestamp_for_height` Wallet RPC for converting heights to timestamps
- Add `tools/legacy_keyring.py` to allow migration from the removed old key storage format. Available only from source installations.
- Add Arch Linux to install-gui.sh script (Thanks, @DaOneLuna!)
- Add a `daemon_heartbeat` setting to config.yaml
- add `trusted_max_subscribe_items` and `wallet:trusted_peers` to config.yaml
- NFT bulk transfer and DID assignment wallet RPCs
- Add the expected offer ID to some RPCs that take offer blobs
### Changed
- bump `chia_rs` dependency to `0.2.0`
- Update version of `clvm_tools_rs` to `0.1.30`
- Use better check that we are on mainnet when deciding to use default Chia DNS server
- Remove conflicting TXs before adding SpendBundle to Mempool in `add_spend_bundle`
- Try each Chia DNS Server in list before trying introducers
- optimize mempool's potential cache
- Display complete exception info in log file for validation, consensus, and protocol errors
- Enable setting time between blocks in full node sim
- Limit rate of log messages when farmer is disconnected from pool
- Add SigningMode and update `verify_signature` RPC to work with `sign_message_by_*` RPCs
### Fixed
- Offer security updates: Offers that are generated with this version cannot be accepted with older versions of Chia - see blog post for details
- server: Fix invalid attribute accesses in `WSChiaConnection`
- header validation time logging severity reduced from warning to info when time is less than two seconds
- replacing transactions in the mempool is normal behavior, not a warning
- don't throw unnecessary exception on peer connect
- Return existing CAT wallet instead of raising
- Resolve peers in harvester and timelord startup (fixes #14158)
- bump default bladebit version to `2.0.1` in `install-plotter.sh`
- disallow empty SpendBundles in the mempool
- avoid an exception in some rare cases when requesting the pool login link
- provide a clear error when the `wallet_id` value is missing in a call to the `nft_set_bulk_nft_did` wallet rpc (Thanks, @steppsr!)
- allow cancellation of offers when there is no spendable balance
- track all transactions of an NFT bulk mint instead of just the first
- Make the `--id` flag on cancel_offer required
- corrected a target address vs. metadata mismatch when bulk minting and airdropping NFTs
- Fixed wallet DB issues resulting when there are unexpected failures during syncing
### Deprecated
- Python 3.7 support is deprecated and will be removed in a future version
## 1.6.2 Chia blockchain 2023-01-03
### Added

View File

@ -1,6 +1,6 @@
# chia-blockchain
![Alt text](https://www.chia.net/wp-content/uploads/2022/09/chia-logo.svg)
[![Chia Network logo](https://www.chia.net/wp-content/uploads/2022/09/chia-logo.svg "Chia logo")](https://www.chia.net/)
| Current Release/main | Development Branch/dev |
| :---: | :---: |
@ -19,8 +19,7 @@ Chia is a modern cryptocurrency built from scratch, designed to be efficient, de
* Support for light clients with fast, objective syncing
* A growing community of farmers and developers around the world
Please check out the [wiki](https://github.com/Chia-Network/chia-blockchain/wiki)
and [FAQ](https://github.com/Chia-Network/chia-blockchain/wiki/FAQ) for
Please check out the [Chia website](https://www.chia.net/), the [wiki](https://github.com/Chia-Network/chia-blockchain/wiki), and [FAQ](https://github.com/Chia-Network/chia-blockchain/wiki/FAQ) for
information on this project.
Python 3.7+ is required. Make sure your default python version is >=3.7

View File

@ -48,8 +48,7 @@ def random_refs() -> List[uint32]:
REPETITIONS = 100
async def main(db_path: Path):
async def main(db_path: Path) -> None:
random.seed(0x213FB154)
async with aiosqlite.connect(db_path) as connection:
@ -92,7 +91,7 @@ async def main(db_path: Path):
@click.command()
@click.argument("db-path", type=click.Path())
def entry_point(db_path: Path):
def entry_point(db_path: Path) -> None:
asyncio.run(main(Path(db_path)))

View File

@ -39,8 +39,7 @@ NUM_ITERS = 20000
random.seed(123456789)
async def run_add_block_benchmark(version: int):
async def run_add_block_benchmark(version: int) -> None:
verbose: bool = "--verbose" in sys.argv
db_wrapper: DBWrapper2 = await setup_db("block-store-benchmark.db", version)
@ -73,7 +72,6 @@ async def run_add_block_benchmark(version: int):
print("profiling add_full_block", end="")
for height in range(block_height, block_height + NUM_ITERS):
is_transaction = transaction_block_counter == 0
fees = uint64(random.randint(0, 150000))
farmer_coin, pool_coin = rewards(uint32(height))

View File

@ -37,8 +37,7 @@ def make_coins(num: int) -> Tuple[List[Coin], List[bytes32]]:
return additions, hashes
async def run_new_block_benchmark(version: int):
async def run_new_block_benchmark(version: int) -> None:
verbose: bool = "--verbose" in sys.argv
db_wrapper: DBWrapper2 = await setup_db("coin-store-benchmark.db", version)
@ -56,7 +55,6 @@ async def run_new_block_benchmark(version: int):
print("Building database ", end="")
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
@ -94,7 +92,6 @@ async def run_new_block_benchmark(version: int):
if verbose:
print("Profiling mostly additions ", end="")
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
total_add += 2000
@ -193,7 +190,6 @@ async def run_new_block_benchmark(version: int):
total_remove = 0
total_time = 0
for height in range(block_height, block_height + NUM_ITERS):
# add some new coins
additions, hashes = make_coins(2000)
total_add += 2000

View File

@ -0,0 +1,153 @@
from __future__ import annotations
import asyncio
from dataclasses import dataclass
from time import monotonic
from typing import Dict, Optional
from blspy import G2Element
from clvm.casts import int_to_bytes
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.mempool_manager import MempoolManager
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.coin_spend import CoinSpend
from chia.types.condition_opcodes import ConditionOpcode
from chia.types.spend_bundle import SpendBundle
from chia.types.spend_bundle_conditions import Spend, SpendBundleConditions
from chia.util.ints import uint32, uint64
# this is one week worth of blocks
NUM_ITERS = 32256
def make_hash(height: int) -> bytes32:
return bytes32(height.to_bytes(32, byteorder="big"))
@dataclass(frozen=True)
class BenchBlockRecord:
"""
This is a subset of BlockRecord that the mempool manager uses for peak.
"""
header_hash: bytes32
height: uint32
timestamp: Optional[uint64]
prev_transaction_block_height: uint32
prev_transaction_block_hash: Optional[bytes32]
@property
def is_transaction_block(self) -> bool:
return self.timestamp is not None
IDENTITY_PUZZLE = Program.to(1)
IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash()
def make_spend_bundle(coin: Coin, height: int) -> SpendBundle:
# the fees we pay will go up over time (by subtracting height * 10)
conditions = [
[
ConditionOpcode.CREATE_COIN,
make_hash(height + coin.amount - 1),
int_to_bytes(coin.amount // 2 - height * 10),
],
[
ConditionOpcode.CREATE_COIN,
make_hash(height + coin.amount + 1),
int_to_bytes(coin.amount // 2 - height * 10),
],
]
spend = CoinSpend(coin, IDENTITY_PUZZLE, Program.to(conditions))
return SpendBundle([spend], G2Element())
def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockRecord:
this_hash = make_hash(block_height)
prev_hash = make_hash(block_height - 1)
return BenchBlockRecord(
header_hash=this_hash,
height=block_height,
timestamp=timestamp,
prev_transaction_block_height=uint32(block_height - 1),
prev_transaction_block_hash=prev_hash,
)
async def run_mempool_benchmark() -> None:
coin_records: Dict[bytes32, CoinRecord] = {}
async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return coin_records.get(coin_id)
timestamp = uint64(1631794488)
mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=True)
print("\nrunning add_spend_bundle() + new_peak()")
start = monotonic()
most_recent_coin_id = make_hash(100)
for height in range(1, NUM_ITERS):
timestamp = uint64(timestamp + 19)
rec = fake_block_record(uint32(height), timestamp)
# the new block spends on coind, the most recently added one
# most_recent_coin_id
npc_result = NPCResult(
None,
SpendBundleConditions(
[Spend(most_recent_coin_id, bytes32(b" " * 32), None, 0, None, None, None, None, [], [], 0)],
0,
0,
0,
None,
None,
[],
0,
0,
0,
),
uint64(1000000000),
)
await mempool.new_peak(rec, npc_result)
# add 10 transactions to the mempool
for i in range(10):
coin = Coin(make_hash(height * 10 + i), IDENTITY_PUZZLE_HASH, height * 100000 + i * 100)
sb = make_spend_bundle(coin, height)
# make this coin available via get_coin_record, which is called
# by mempool_manager
coin_records = {
coin.name(): CoinRecord(coin, uint32(height // 2), uint32(0), False, uint64(timestamp // 2))
}
spend_bundle_id = sb.name()
npc = await mempool.pre_validate_spendbundle(sb, None, spend_bundle_id)
assert npc is not None
await mempool.add_spend_bundle(sb, npc, spend_bundle_id, uint32(height))
if height % 100 == 0:
print(
"height: ", height, " size: ", mempool.mempool.size(), " cost: ", mempool.mempool.total_mempool_cost()
)
# this coin will be spent in the next block
most_recent_coin_id = coin.name()
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per block: {(stop - start) / height * 1000:0.2f}ms")
if __name__ == "__main__":
import logging
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.WARNING)
asyncio.run(run_mempool_benchmark())

View File

@ -3,27 +3,25 @@ from __future__ import annotations
import asyncio
import cProfile
from contextlib import contextmanager
from dataclasses import dataclass
from subprocess import check_call
from time import monotonic
from typing import Iterator, List
from typing import Dict, Iterator, List, Optional, Tuple
from utils import setup_db
from chia.consensus.block_record import BlockRecord
from chia.consensus.coinbase import create_farmer_coin, create_pool_coin
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.coin_store import CoinStore
from chia.full_node.mempool_manager import MempoolManager
from chia.simulator.wallet_tools import WalletTool
from chia.types.blockchain_format.classgroup import ClassgroupElement
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32, bytes100
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.spend_bundle import SpendBundle
from chia.util.db_wrapper import DBWrapper2
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.types.spend_bundle_conditions import Spend, SpendBundleConditions
from chia.util.ints import uint32, uint64
NUM_ITERS = 100
NUM_ITERS = 200
NUM_PEERS = 5
@ -42,105 +40,109 @@ def enable_profiler(profile: bool, name: str) -> Iterator[None]:
check_call(["gprof2dot", "-f", "pstats", "-o", output_file + ".dot", output_file + ".profile"])
with open(output_file + ".png", "w+") as f:
check_call(["dot", "-T", "png", output_file + ".dot"], stdout=f)
print("output written to: %s.png" % output_file)
print(" output written to: %s.png" % output_file)
def fake_block_record(block_height: uint32, timestamp: uint64) -> BlockRecord:
return BlockRecord(
bytes32(b"a" * 32), # header_hash
bytes32(b"b" * 32), # prev_hash
block_height, # height
uint128(0), # weight
uint128(0), # total_iters
uint8(0), # signage_point_index
ClassgroupElement(bytes100(b"1" * 100)), # challenge_vdf_output
None, # infused_challenge_vdf_output
bytes32(b"f" * 32), # reward_infusion_new_challenge
bytes32(b"c" * 32), # challenge_block_info_hash
uint64(0), # sub_slot_iters
bytes32(b"d" * 32), # pool_puzzle_hash
bytes32(b"e" * 32), # farmer_puzzle_hash
uint64(0), # required_iters
uint8(0), # deficit
False, # overflow
uint32(block_height - 1), # prev_transaction_block_height
timestamp, # timestamp
None, # prev_transaction_block_hash
uint64(0), # fees
None, # reward_claims_incorporated
None, # finished_challenge_slot_hashes
None, # finished_infused_challenge_slot_hashes
None, # finished_reward_slot_hashes
None, # sub_epoch_summary_included
def make_hash(height: int) -> bytes32:
return bytes32(height.to_bytes(32, byteorder="big"))
@dataclass(frozen=True)
class BenchBlockRecord:
"""
This is a subset of BlockRecord that the mempool manager uses for peak.
"""
header_hash: bytes32
height: uint32
timestamp: Optional[uint64]
prev_transaction_block_height: uint32
prev_transaction_block_hash: Optional[bytes32]
@property
def is_transaction_block(self) -> bool:
return self.timestamp is not None
def fake_block_record(block_height: uint32, timestamp: uint64) -> BenchBlockRecord:
this_hash = make_hash(block_height)
prev_hash = make_hash(block_height - 1)
return BenchBlockRecord(
header_hash=this_hash,
height=block_height,
timestamp=timestamp,
prev_transaction_block_height=uint32(block_height - 1),
prev_transaction_block_hash=prev_hash,
)
async def run_mempool_benchmark(single_threaded: bool) -> None:
async def run_mempool_benchmark() -> None:
all_coins: Dict[bytes32, CoinRecord] = {}
suffix = "st" if single_threaded else "mt"
db_wrapper: DBWrapper2 = await setup_db(f"mempool-benchmark-coins-{suffix}.db", 2)
async def get_coin_record(coin_id: bytes32) -> Optional[CoinRecord]:
return all_coins.get(coin_id)
try:
coin_store = await CoinStore.create(db_wrapper)
mempool = MempoolManager(coin_store.get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
wt = WalletTool(DEFAULT_CONSTANTS)
wt = WalletTool(DEFAULT_CONSTANTS)
spend_bundles: List[List[SpendBundle]] = []
spend_bundles: List[List[SpendBundle]] = []
# these spend the same coins as spend_bundles but with a higher fee
replacement_spend_bundles: List[List[SpendBundle]] = []
timestamp = uint64(1631794488)
timestamp = uint64(1631794488)
height = uint32(1)
height = uint32(1)
print("Building SpendBundles")
for peer in range(NUM_PEERS):
print(f" peer {peer}")
print(" reward coins")
unspent: List[Coin] = []
for idx in range(NUM_ITERS):
height = uint32(height + 1)
# farm rewards
farmer_coin = create_farmer_coin(
height, wt.get_new_puzzlehash(), uint64(250000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
pool_coin = create_pool_coin(
height, wt.get_new_puzzlehash(), uint64(1750000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
unspent.extend([farmer_coin, pool_coin])
await coin_store.new_block(
height,
timestamp,
set([pool_coin, farmer_coin]),
[],
[],
)
bundles: List[SpendBundle] = []
print(" spend bundles")
for coin in unspent:
tx: SpendBundle = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin
)
bundles.append(tx)
spend_bundles.append(bundles)
print("Building SpendBundles")
for peer in range(NUM_PEERS):
print(f" peer {peer}")
print(" reward coins")
unspent: List[Coin] = []
for idx in range(NUM_ITERS):
height = uint32(height + 1)
# 19 seconds per block
timestamp = uint64(timestamp + 19)
if single_threaded:
print("Single-threaded")
else:
print("Multi-threaded")
print("Profiling add_spendbundle()")
# farm rewards
farmer_coin = create_farmer_coin(
height, wt.get_new_puzzlehash(), uint64(250000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
pool_coin = create_pool_coin(
height, wt.get_new_puzzlehash(), uint64(1750000000), DEFAULT_CONSTANTS.GENESIS_CHALLENGE
)
all_coins[farmer_coin.name()] = CoinRecord(farmer_coin, height, uint32(0), True, timestamp)
all_coins[pool_coin.name()] = CoinRecord(pool_coin, height, uint32(0), True, timestamp)
unspent.extend([farmer_coin, pool_coin])
# the mempool only looks at:
# timestamp
# height
# is_transaction_block
# header_hash
print("initialize MempoolManager")
print(" spend bundles")
bundles: List[SpendBundle] = []
for coin in unspent:
tx: SpendBundle = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin, fee=peer + idx
)
bundles.append(tx)
spend_bundles.append(bundles)
bundles = []
print(" replacement spend bundles")
for coin in unspent:
tx = wt.generate_signed_transaction(
uint64(coin.amount // 2), wt.get_new_puzzlehash(), coin, fee=peer + idx + 10000000
)
bundles.append(tx)
replacement_spend_bundles.append(bundles)
start_height = height
for single_threaded in [False, True]:
if single_threaded:
print("\n== Single-threaded")
else:
print("\n== Multi-threaded")
mempool = MempoolManager(get_coin_record, DEFAULT_CONSTANTS, single_threaded=single_threaded)
height = start_height
rec = fake_block_record(height, timestamp)
await mempool.new_peak(rec, None)
@ -153,6 +155,9 @@ async def run_mempool_benchmark(single_threaded: bool) -> None:
assert status == MempoolInclusionStatus.SUCCESS
assert error is None
suffix = "st" if single_threaded else "mt"
print("\nProfiling add_spend_bundle()")
total_bundles = 0
tasks = []
with enable_profiler(True, f"add-{suffix}"):
@ -162,20 +167,94 @@ async def run_mempool_benchmark(single_threaded: bool) -> None:
tasks.append(asyncio.create_task(add_spend_bundles(spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f"add_spendbundle time: {stop - start:0.4f}s")
print(f"{(stop - start) / total_bundles * 1000:0.2f}ms per add_spendbundle() call")
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms")
print("\nProfiling add_spend_bundle() with replace-by-fee")
total_bundles = 0
tasks = []
with enable_profiler(True, f"replace-{suffix}"):
start = monotonic()
for peer in range(NUM_PEERS):
total_bundles += len(replacement_spend_bundles[peer])
tasks.append(asyncio.create_task(add_spend_bundles(replacement_spend_bundles[peer])))
await asyncio.gather(*tasks)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / total_bundles * 1000:0.2f}ms")
print("\nProfiling create_bundle_from_mempool()")
with enable_profiler(True, f"create-{suffix}"):
start = monotonic()
for _ in range(2000):
mempool.create_bundle_from_mempool(bytes32(b"a" * 32))
for _ in range(500):
mempool.create_bundle_from_mempool(rec.header_hash)
stop = monotonic()
print(f"create_bundle_from_mempool time: {stop - start:0.4f}s")
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / 500 * 1000:0.2f}ms")
# TODO: add benchmark for new_peak()
print("\nProfiling new_peak() (optimized)")
blocks: List[Tuple[BenchBlockRecord, NPCResult]] = []
for coin_id in all_coins.keys():
height = uint32(height + 1)
timestamp = uint64(timestamp + 19)
rec = fake_block_record(height, timestamp)
npc_result = NPCResult(
None,
SpendBundleConditions(
[Spend(coin_id, bytes32(b" " * 32), None, None, None, None, None, None, [], [], 0)],
0,
0,
0,
None,
None,
[],
0,
0,
0,
),
uint64(1000000000),
)
blocks.append((rec, npc_result))
finally:
await db_wrapper.close()
with enable_profiler(True, f"new-peak-{suffix}"):
start = monotonic()
for rec, npc_result in blocks:
await mempool.new_peak(rec, npc_result)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / len(blocks) * 1000:0.2f}ms")
print("\nProfiling new_peak() (reorg)")
blocks = []
for coin_id in all_coins.keys():
height = uint32(height + 2)
timestamp = uint64(timestamp + 28)
rec = fake_block_record(height, timestamp)
npc_result = NPCResult(
None,
SpendBundleConditions(
[Spend(coin_id, bytes32(b" " * 32), None, None, None, None, None, None, [], [], 0)],
0,
0,
0,
None,
None,
[],
0,
0,
0,
),
uint64(1000000000),
)
blocks.append((rec, npc_result))
with enable_profiler(True, f"new-peak-reorg-{suffix}"):
start = monotonic()
for rec, npc_result in blocks:
await mempool.new_peak(rec, npc_result)
stop = monotonic()
print(f" time: {stop - start:0.4f}s")
print(f" per call: {(stop - start) / len(blocks) * 1000:0.2f}ms")
if __name__ == "__main__":
@ -184,5 +263,4 @@ if __name__ == "__main__":
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
logger.setLevel(logging.WARNING)
asyncio.run(run_mempool_benchmark(True))
asyncio.run(run_mempool_benchmark(False))
asyncio.run(run_mempool_benchmark())

View File

@ -1,12 +1,13 @@
from __future__ import annotations
import enum
import os
import random
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from typing import Tuple, Union
from typing import Any, Generic, Optional, Tuple, Type, TypeVar, Union
import aiosqlite
import click
@ -34,13 +35,16 @@ with open(Path(os.path.realpath(__file__)).parent / "clvm_generator.bin", "rb")
clvm_generator = f.read()
_T_Enum = TypeVar("_T_Enum", bound=enum.Enum)
# Workaround to allow `Enum` with click.Choice: https://github.com/pallets/click/issues/605#issuecomment-901099036
class EnumType(click.Choice):
def __init__(self, enum, case_sensitive=False):
class EnumType(click.Choice, Generic[_T_Enum]):
def __init__(self, enum: Type[_T_Enum], case_sensitive: bool = False) -> None:
self.__enum = enum
super().__init__(choices=[item.value for item in enum], case_sensitive=case_sensitive)
def convert(self, value, param, ctx):
def convert(self, value: Any, param: Optional[click.Parameter], ctx: Optional[click.Context]) -> _T_Enum:
converted_str = super().convert(value, param, ctx)
return self.__enum(converted_str)
@ -51,7 +55,7 @@ def rewards(height: uint32) -> Tuple[Coin, Coin]:
return farmer_coin, pool_coin
def rand_bytes(num) -> bytes:
def rand_bytes(num: int) -> bytes:
ret = bytearray(num)
for i in range(num):
ret[i] = random.getrandbits(8)
@ -175,7 +179,7 @@ def rand_full_block() -> FullBlock:
return full_block
async def setup_db(name: Union[str, os.PathLike], db_version: int) -> DBWrapper2:
async def setup_db(name: Union[str, os.PathLike[str]], db_version: int) -> DBWrapper2:
db_filename = Path(name)
try:
os.unlink(db_filename)
@ -183,7 +187,7 @@ async def setup_db(name: Union[str, os.PathLike], db_version: int) -> DBWrapper2
pass
connection = await aiosqlite.connect(db_filename)
def sql_trace_callback(req: str):
def sql_trace_callback(req: str) -> None:
sql_log_path = "sql.log"
timestamp = datetime.now().strftime("%H:%M:%S.%f")
log = open(sql_log_path, "a")

View File

@ -4,8 +4,7 @@ from setuptools_scm import get_version
# example: 1.0b5.dev225
def main():
def main() -> None:
scm_full_version = get_version(root="..", relative_to=__file__)
# scm_full_version = "1.0.5.dev22"

@ -1 +1 @@
Subproject commit c6f5b7e93d22b3fb1e9d891dbe84d15cfc877761
Subproject commit 69b15a42328199eecbe714302070bfd05b098031

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import json
import random
from contextlib import asynccontextmanager
from dataclasses import dataclass
@ -13,9 +14,11 @@ from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.bundle_tools import simple_solution_generator
from chia.full_node.coin_store import CoinStore
from chia.full_node.mempool_check_conditions import get_puzzle_and_solution_for_coin
from chia.full_node.mempool import Mempool
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, get_puzzle_and_solution_for_coin
from chia.full_node.mempool_manager import MempoolManager
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import INFINITE_COST
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.coin_spend import CoinSpend
@ -54,6 +57,31 @@ async def sim_and_client(
await sim.close()
class CostLogger:
def __init__(self) -> None:
self.cost_dict: Dict[str, int] = {}
self.cost_dict_no_puzs: Dict[str, int] = {}
def add_cost(self, descriptor: str, spend_bundle: SpendBundle) -> SpendBundle:
program: BlockGenerator = simple_solution_generator(spend_bundle)
npc_result: NPCResult = get_name_puzzle_conditions(
program, INFINITE_COST, mempool_mode=True, height=DEFAULT_CONSTANTS.SOFT_FORK2_HEIGHT
)
self.cost_dict[descriptor] = npc_result.cost
cost_to_subtract: int = 0
for cs in spend_bundle.coin_spends:
cost_to_subtract += len(bytes(cs.puzzle_reveal)) * DEFAULT_CONSTANTS.COST_PER_BYTE
self.cost_dict_no_puzs[descriptor] = npc_result.cost - cost_to_subtract
return spend_bundle
def log_cost_statistics(self) -> str:
merged_dict = {
"standard cost": self.cost_dict,
"no puzzle reveals": self.cost_dict_no_puzs,
}
return json.dumps(merged_dict, indent=4)
@streamable
@dataclass(frozen=True)
class SimFullBlock(Streamable):
@ -101,7 +129,6 @@ _T_SpendSim = TypeVar("_T_SpendSim", bound="SpendSim")
class SpendSim:
db_wrapper: DBWrapper2
coin_store: CoinStore
mempool_manager: MempoolManager
@ -139,8 +166,7 @@ class SpendSim:
self.block_height = store_data.block_height
self.block_records = store_data.block_records
self.blocks = store_data.blocks
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
self.mempool_manager.peak = self.block_records[-1] # type: ignore[assignment]
self.mempool_manager.peak = self.block_records[-1]
else:
self.timestamp = uint64(1)
self.block_height = uint32(0)
@ -160,8 +186,7 @@ class SpendSim:
await self.db_wrapper.close()
async def new_peak(self) -> None:
# Create a protocol to make BlockRecord and SimBlockRecord interchangeable.
await self.mempool_manager.new_peak(self.block_records[-1], None) # type: ignore[arg-type]
await self.mempool_manager.new_peak(self.block_records[-1], None)
def new_coin_record(self, coin: Coin, coinbase: bool = False) -> CoinRecord:
return CoinRecord(
@ -194,13 +219,12 @@ class SpendSim:
async def farm_block(
self,
puzzle_hash: bytes32 = bytes32(b"0" * 32),
item_inclusion_filter: Optional[Callable[[MempoolManager, MempoolItem], bool]] = None,
item_inclusion_filter: Optional[Callable[[bytes32], bool]] = None,
) -> Tuple[List[Coin], List[Coin]]:
# Fees get calculated
fees = uint64(0)
if self.mempool_manager.mempool.spends:
for _, item in self.mempool_manager.mempool.spends.items():
fees = uint64(fees + item.spend_bundle.fees())
for item in self.mempool_manager.mempool.all_spends():
fees = uint64(fees + item.fee)
# Rewards get created
next_block_height: uint32 = uint32(self.block_height + 1) if len(self.block_records) > 0 else self.block_height
@ -224,7 +248,7 @@ class SpendSim:
generator_bundle: Optional[SpendBundle] = None
return_additions: List[Coin] = []
return_removals: List[Coin] = []
if (len(self.block_records) > 0) and (self.mempool_manager.mempool.spends):
if (len(self.block_records) > 0) and (self.mempool_manager.mempool.size() > 0):
peak = self.mempool_manager.peak
if peak is not None:
result = self.mempool_manager.create_bundle_from_mempool(peak.header_hash, item_inclusion_filter)
@ -273,7 +297,8 @@ class SpendSim:
self.block_records = new_br_list
self.blocks = new_block_list
await self.coin_store.rollback_to_block(block_height)
self.mempool_manager.mempool.spends = {}
old_pool = self.mempool_manager.mempool
self.mempool_manager.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.block_height = block_height
if new_br_list:
self.timestamp = new_br_list[-1].timestamp
@ -402,12 +427,12 @@ class SimClient:
return CoinSpend(coin_record.coin, puzzle, solution)
async def get_all_mempool_tx_ids(self) -> List[bytes32]:
return list(self.service.mempool_manager.mempool.spends.keys())
return self.service.mempool_manager.mempool.all_spend_ids()
async def get_all_mempool_items(self) -> Dict[bytes32, MempoolItem]:
spends = {}
for tx_id, item in self.service.mempool_manager.mempool.spends.items():
spends[tx_id] = item
for item in self.service.mempool_manager.mempool.all_spends():
spends[item.name] = item
return spends
async def get_mempool_item_by_tx_id(self, tx_id: bytes32) -> Optional[Dict[str, Any]]:

View File

@ -0,0 +1,415 @@
from __future__ import annotations
import asyncio
import sys
from collections import defaultdict
from pathlib import Path
from sqlite3 import Row
from typing import Any, Dict, Iterable, List, Optional, Set
from chia.util.collection import find_duplicates
from chia.util.db_synchronous import db_synchronous_on
from chia.util.db_wrapper import DBWrapper2, execute_fetchone
from chia.util.pprint import print_compact_ranges
from chia.wallet.util.wallet_types import WalletType
# TODO: Check for missing paired wallets (eg. No DID wallet for an NFT)
# TODO: Check for missing DID Wallets
help_text = """
\b
The purpose of this command is find potential issues in Chia wallet databases.
The core chia client currently uses sqlite to store the wallet databases, one database per key.
\b
Guide to warning diagnostics:
----------------------------
"Missing Wallet IDs": A wallet was created and later deleted. By itself, this is okay because
the wallet does not reuse wallet IDs. However, this information may be useful
in conjunction with other information.
\b
Guide to error diagnostics:
--------------------------
Diagnostics in the error section indicate an error in the database structure.
In general, this does not indicate an error in on-chain data, nor does it mean that you have lost coins.
\b
An example is "Missing DerivationPath indexes" - a derivation path is a sub-key of your master key. Missing
derivation paths could cause your wallet to not "know" about transactions that happened on the blockchain.
\b
"""
def _validate_args_addresses_used(wallet_id: int, last_index: int, last_hardened: int, dp: DerivationPath) -> None:
if last_hardened:
if last_hardened != dp.hardened:
raise ValueError(f"Invalid argument: Mix of hardened and unhardened columns wallet_id={wallet_id}")
if last_index:
if last_index != dp.derivation_index:
raise ValueError(f"Invalid argument: noncontiguous derivation_index at {last_index} wallet_id={wallet_id}")
def check_addresses_used_contiguous(derivation_paths: List[DerivationPath]) -> List[str]:
"""
The used column for addresses in the derivation_paths table should be a
zero or greater run of 1's, followed by a zero or greater run of 0's.
There should be no used derivations after seeing a used derivation.
"""
errors: List[str] = []
for wallet_id, dps in dp_by_wallet_id(derivation_paths).items():
saw_unused = False
bad_used_values: Set[int] = set()
ordering_errors: List[str] = []
# last_index = None
# last_hardened = None
for dp in dps:
# _validate_args_addresses_used(wallet_id, last_index, last_hardened, dp)
if saw_unused and dp.used == 1 and ordering_errors == []:
ordering_errors.append(
f"Wallet {dp.wallet_id}: "
f"Used address after unused address at derivation index {dp.derivation_index}"
)
if dp.used == 1:
pass
elif dp.used == 0:
saw_unused = True
else:
bad_used_values.add(dp.used)
# last_hardened = dp.hardened
# last_index = dp.derivation_index
if len(bad_used_values) > 0:
errors.append(f"Wallet {wallet_id}: Bad values in 'used' column: {bad_used_values}")
if ordering_errors != []:
errors.extend(ordering_errors)
return errors
def check_for_gaps(array: List[int], start: int, end: int, *, data_type_plural: str = "Elements") -> List[str]:
"""
Check for compact sequence:
Check that every value from start to end is present in array, and no more.
start and end are values, not indexes
start and end should be included in array
array can be unsorted
"""
if start > end:
raise ValueError(f"{__name__} called with incorrect arguments: start={start} end={end} (start > end)")
errors: List[str] = []
if start == end and len(array) == 1:
return errors
expected_set = set(range(start, end + 1))
actual_set = set(array)
missing = expected_set.difference(actual_set)
extras = actual_set.difference(expected_set)
duplicates = find_duplicates(array)
if len(missing) > 0:
errors.append(f"Missing {data_type_plural}: {print_compact_ranges(list(missing))}")
if len(extras) > 0:
errors.append(f"Unexpected {data_type_plural}: {extras}")
if len(duplicates) > 0:
errors.append(f"Duplicate {data_type_plural}: {duplicates}")
return errors
class FromDB:
def __init__(self, row: Iterable[Any], fields: List[str]) -> None:
self.fields = fields
for field, value in zip(fields, row):
setattr(self, field, value)
def __repr__(self) -> str:
s = ""
for f in self.fields:
s += f"{f}={getattr(self, f)} "
return s
def wallet_type_name(
wallet_type: int,
) -> str:
if wallet_type in set(wt.value for wt in WalletType):
return f"{WalletType(wallet_type).name} ({wallet_type})"
else:
return f"INVALID_WALLET_TYPE ({wallet_type})"
def _cwr(row: Row) -> List[Any]:
r = []
for i, v in enumerate(row):
if i == 2:
r.append(wallet_type_name(v))
else:
r.append(v)
return r
# wallet_types_that_dont_need_derivations: See require_derivation_paths for each wallet type
wallet_types_that_dont_need_derivations = {WalletType.POOLING_WALLET, WalletType.NFT}
class DerivationPath(FromDB):
derivation_index: int
pubkey: str
puzzle_hash: str
wallet_type: WalletType
wallet_id: int
used: int # 1 or 0
hardened: int # 1 or 0
class Wallet(FromDB):
id: int # id >= 1
name: str
wallet_type: WalletType
data: str
def dp_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[DerivationPath]]:
d = defaultdict(list)
for derivation_path in derivation_paths:
d[derivation_path.wallet_id].append(derivation_path)
for k, v in d.items():
d[k] = sorted(v, key=lambda dp: dp.derivation_index)
return d
def derivation_indices_by_wallet_id(derivation_paths: List[DerivationPath]) -> Dict[int, List[int]]:
d = dp_by_wallet_id(derivation_paths)
di = {}
for k, v in d.items():
di[k] = [dp.derivation_index for dp in v]
return di
def print_min_max_derivation_for_wallets(derivation_paths: List[DerivationPath]) -> None:
d = derivation_indices_by_wallet_id(derivation_paths)
print("Min, Max, Count of derivations for each wallet:")
for wallet_id, derivation_index_list in d.items():
# TODO: Fix count by separating hardened and unhardened
print(
f"Wallet ID {wallet_id:2} derivation index min: {derivation_index_list[0]} "
f"max: {derivation_index_list[-1]} count: {len(derivation_index_list)}"
)
class WalletDBReader:
db_wrapper: DBWrapper2 # TODO: Remove db_wrapper member
config = {"db_readers": 1}
sql_log_path = None
verbose = False
async def get_all_wallets(self) -> List[Wallet]:
wallet_fields = ["id", "name", "wallet_type", "data"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(wallet_fields)} FROM users_wallets""")
rows = await cursor.fetchall()
return [Wallet(r, wallet_fields) for r in rows]
async def get_derivation_paths(self) -> List[DerivationPath]:
fields = ["derivation_index", "pubkey", "puzzle_hash", "wallet_type", "wallet_id", "used", "hardened"]
async with self.db_wrapper.reader_no_transaction() as reader:
# TODO: if table doesn't exist
cursor = await reader.execute(f"""SELECT {", ".join(fields)} FROM derivation_paths;""")
rows = await cursor.fetchall()
return [DerivationPath(row, fields) for row in rows]
async def show_tables(self) -> List[str]:
async with self.db_wrapper.reader_no_transaction() as reader:
cursor = await reader.execute("""SELECT name FROM sqlite_master WHERE type='table';""")
print("\nWallet DB Tables:")
print(*([r[0] for r in await cursor.fetchall()]), sep=",\n")
print("\nWallet Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('users_wallets')")).fetchall()), sep=",\n")
print("\nDerivationPath Schema:")
print(*(await (await cursor.execute("PRAGMA table_info('derivation_paths')")).fetchall()), sep=",\n")
print()
return []
async def check_wallets(self) -> List[str]:
# id, name, wallet_type, data
# TODO: Move this SQL up a level
async with self.db_wrapper.reader_no_transaction() as reader:
errors = []
try:
main_wallet_id = 1
main_wallet_type = WalletType.STANDARD_WALLET
row = await execute_fetchone(reader, "SELECT * FROM users_wallets WHERE id=?", (main_wallet_id,))
if row is None:
errors.append(f"There is no wallet with ID {main_wallet_id} in table users_wallets")
elif row[2] != main_wallet_type:
errors.append(
f"We expect wallet {main_wallet_id} to have type {wallet_type_name(main_wallet_type)}, "
f"but it has {wallet_type_name(row[2])}"
)
except Exception as e:
errors.append(f"Exception while trying to access wallet {main_wallet_id} from users_wallets: {e}")
max_id_row = await execute_fetchone(reader, "SELECT MAX(id) FROM users_wallets")
if max_id_row is None:
errors.append("Error fetching max wallet ID from table users_wallets. No wallets ?!?")
else:
cursor = await reader.execute("""SELECT * FROM users_wallets""")
rows = await cursor.fetchall()
max_id = max_id_row[0]
errors.extend(check_for_gaps([r[0] for r in rows], main_wallet_id, max_id, data_type_plural="Wallet IDs"))
if self.verbose:
print("\nWallets:")
print(*[_cwr(r) for r in rows], sep=",\n")
# Check for invalid wallet types in users_wallets
invalid_wallet_types = set()
for row in rows:
if row[2] not in set(wt.value for wt in WalletType):
invalid_wallet_types.add(row[2])
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid Wallet Types found in table users_wallets: {invalid_wallet_types}")
return errors
def check_wallets_missing_derivations(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
p = []
d = derivation_indices_by_wallet_id(derivation_paths) # TODO: calc this once, pass in
for w in wallets:
if w.wallet_type not in wallet_types_that_dont_need_derivations and w.id not in d:
p.append(w.id)
if len(p) > 0:
return [f"Wallet IDs with no derivations that require them: {p}"]
return []
def check_derivations_are_compact(self, wallets: List[Wallet], derivation_paths: List[DerivationPath]) -> List[str]:
errors = []
"""
Gaps in derivation index
Missing hardened or unhardened derivations
TODO: Gaps in used derivations
"""
for wallet_id in [w.id for w in wallets]:
for hardened in [0, 1]:
dps = list(filter(lambda x: x.wallet_id == wallet_id and x.hardened == hardened, derivation_paths))
if len(dps) < 1:
continue
dpi = [x.derivation_index for x in dps]
dpi.sort()
max_id = dpi[-1]
h = [" hardened", "unhardened"][hardened]
errors.extend(
check_for_gaps(
dpi, 0, max_id, data_type_plural=f"DerivationPath indexes for {h} wallet_id={wallet_id}"
)
)
return errors
def check_unexpected_derivation_entries(
self, wallets: List[Wallet], derivation_paths: List[DerivationPath]
) -> List[str]:
"""
Check for unexpected derivation path entries
Invalid Wallet Type
Wallet IDs not in table 'users_wallets'
Wallet ID with different wallet_type
"""
errors = []
wallet_id_to_type = {w.id: w.wallet_type for w in wallets}
invalid_wallet_types = []
missing_wallet_ids = []
wrong_type = defaultdict(list)
for d in derivation_paths:
if d.wallet_type not in set(wt.value for wt in WalletType):
invalid_wallet_types.append(d.wallet_type)
if d.wallet_id not in wallet_id_to_type:
missing_wallet_ids.append(d.wallet_id)
elif d.wallet_type != wallet_id_to_type[d.wallet_id]:
wrong_type[(d.hardened, d.wallet_id, d.wallet_type, wallet_id_to_type[d.wallet_id])].append(
d.derivation_index
)
if len(invalid_wallet_types) > 0:
errors.append(f"Invalid wallet_types in derivation_paths table: {invalid_wallet_types}")
if len(missing_wallet_ids) > 0:
errors.append(
f"Wallet IDs found in derivation_paths table, but not in users_wallets table: {missing_wallet_ids}"
)
for k, v in wrong_type.items():
errors.append(
f"""{[" ", "un"][int(k[0])]}hardened Wallet ID {k[1]} uses type {wallet_type_name(k[2])} in """
f"derivation_paths, but type {wallet_type_name(k[3])} in wallet table at these derivation indices: {v}"
)
return errors
async def scan(self, db_path: Path) -> int:
"""Returns number of lines of error output (not warnings)"""
self.db_wrapper = await DBWrapper2.create(
database=db_path,
reader_count=self.config.get("db_readers", 4),
log_path=self.sql_log_path,
synchronous=db_synchronous_on("auto"),
)
# TODO: Pass down db_wrapper
wallets = await self.get_all_wallets()
derivation_paths = await self.get_derivation_paths()
errors = []
warnings = []
try:
if self.verbose:
await self.show_tables()
print_min_max_derivation_for_wallets(derivation_paths)
warnings.extend(await self.check_wallets())
errors.extend(self.check_wallets_missing_derivations(wallets, derivation_paths))
errors.extend(self.check_unexpected_derivation_entries(wallets, derivation_paths))
errors.extend(self.check_derivations_are_compact(wallets, derivation_paths))
errors.extend(check_addresses_used_contiguous(derivation_paths))
if len(warnings) > 0:
print(f" ---- Warnings Found for {db_path.name} ----")
print("\n".join(warnings))
if len(errors) > 0:
print(f" ---- Errors Found for {db_path.name}----")
print("\n".join(errors))
finally:
await self.db_wrapper.close()
return len(errors)
async def scan(root_path: str, db_path: Optional[str] = None, *, verbose: bool = False) -> None:
if db_path is None:
wallet_db_path = Path(root_path) / "wallet" / "db"
wallet_db_paths = list(wallet_db_path.glob("blockchain_wallet_*.sqlite"))
else:
wallet_db_paths = [Path(db_path)]
num_errors = 0
for wallet_db_path in wallet_db_paths:
w = WalletDBReader()
w.verbose = verbose
print(f"Reading {wallet_db_path}")
num_errors += await w.scan(Path(wallet_db_path))
if num_errors > 0:
sys.exit(2)
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(scan("", sys.argv[1]))

View File

@ -7,6 +7,7 @@ import click
from chia import __version__
from chia.cmds.beta import beta_cmd
from chia.cmds.completion import completion
from chia.cmds.configure import configure_cmd
from chia.cmds.data import data_cmd
from chia.cmds.db import db_cmd
@ -125,6 +126,7 @@ cli.add_command(peer_cmd)
cli.add_command(data_cmd)
cli.add_command(passphrase_cmd)
cli.add_command(beta_cmd)
cli.add_command(completion)
def main() -> None:

View File

@ -4,7 +4,7 @@ import logging
import traceback
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple, Type
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from aiohttp import ClientConnectorError
@ -31,6 +31,17 @@ NODE_TYPES: Dict[str, Type[RpcClient]] = {
"data_layer": DataLayerRpcClient,
}
node_config_section_names: Dict[Type[RpcClient], str] = {
FarmerRpcClient: "farmer",
WalletRpcClient: "wallet",
FullNodeRpcClient: "full_node",
HarvesterRpcClient: "harvester",
DataLayerRpcClient: "data_layer",
}
_T_RpcClient = TypeVar("_T_RpcClient", bound=RpcClient)
def transaction_submitted_msg(tx: TransactionRecord) -> str:
sent_to = [MempoolSubmissionStatus(s[0], s[1], s[2]).to_json_dict_convenience() for s in tx.sent_to]
@ -49,7 +60,6 @@ async def validate_client_connection(
fingerprint: Optional[int],
login_to_wallet: bool,
) -> Optional[int]:
try:
await rpc_client.healthz()
if type(rpc_client) == WalletRpcClient and login_to_wallet:
@ -66,28 +76,29 @@ async def validate_client_connection(
@asynccontextmanager
async def get_any_service_client(
node_type: str,
client_type: Type[_T_RpcClient],
rpc_port: Optional[int] = None,
root_path: Path = DEFAULT_ROOT_PATH,
fingerprint: Optional[int] = None,
login_to_wallet: bool = True,
) -> AsyncIterator[Tuple[Optional[Any], Dict[str, Any], Optional[int]]]:
) -> AsyncIterator[Tuple[Optional[_T_RpcClient], Dict[str, Any], Optional[int]]]:
"""
Yields a tuple with a RpcClient for the applicable node type a dictionary of the node's configuration,
and a fingerprint if applicable. However, if connecting to the node fails then we will return None for
the RpcClient.
"""
if node_type not in NODE_TYPES.keys():
node_type = node_config_section_names.get(client_type)
if node_type is None:
# Click already checks this, so this should never happen
raise ValueError(f"Invalid node type: {node_type}")
raise ValueError(f"Invalid client type requested: {client_type.__name__}")
# load variables from config file
config = load_config(root_path, "config.yaml", fill_missing_services=node_type == "data_layer")
config = load_config(root_path, "config.yaml", fill_missing_services=issubclass(client_type, DataLayerRpcClient))
self_hostname = config["self_hostname"]
if rpc_port is None:
rpc_port = config[node_type]["rpc_port"]
# select node client type based on string
node_client = await NODE_TYPES[node_type].create(self_hostname, uint16(rpc_port), root_path, config)
node_client = await client_type.create(self_hostname, uint16(rpc_port), root_path, config)
try:
# check if we can connect to node, and if we can then validate
# fingerprint access, otherwise return fingerprint and shutdown client
@ -111,89 +122,90 @@ async def get_wallet(root_path: Path, wallet_client: WalletRpcClient, fingerprin
keychain_proxy: Optional[KeychainProxy] = None
all_keys: List[KeyData] = []
if fingerprint is not None:
selected_fingerprint = fingerprint
else:
keychain_proxy = await connect_to_keychain_and_validate(root_path, log=logging.getLogger(__name__))
if keychain_proxy is None:
raise RuntimeError("Failed to connect to keychain")
# we're only interested in the fingerprints and labels
all_keys = await keychain_proxy.get_keys(include_secrets=False)
# we don't immediately close the keychain proxy connection because it takes a noticeable amount of time
fingerprints = [key.fingerprint for key in all_keys]
if len(fingerprints) == 0:
print("No keys loaded. Run 'chia keys generate' or import a key")
elif len(fingerprints) == 1:
# if only a single key is available, select it automatically
selected_fingerprint = fingerprints[0]
try:
if fingerprint is not None:
selected_fingerprint = fingerprint
else:
keychain_proxy = await connect_to_keychain_and_validate(root_path, log=logging.getLogger(__name__))
if keychain_proxy is None:
raise RuntimeError("Failed to connect to keychain")
# we're only interested in the fingerprints and labels
all_keys = await keychain_proxy.get_keys(include_secrets=False)
# we don't immediately close the keychain proxy connection because it takes a noticeable amount of time
fingerprints = [key.fingerprint for key in all_keys]
if len(fingerprints) == 0:
print("No keys loaded. Run 'chia keys generate' or import a key")
elif len(fingerprints) == 1:
# if only a single key is available, select it automatically
selected_fingerprint = fingerprints[0]
if selected_fingerprint is None and len(all_keys) > 0:
logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint()
logged_in_key: Optional[KeyData] = None
if logged_in_fingerprint is not None:
logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None)
current_sync_status: str = ""
indent = " "
if logged_in_key is not None:
if await wallet_client.get_synced():
current_sync_status = "Synced"
elif await wallet_client.get_sync_status():
current_sync_status = "Syncing"
else:
current_sync_status = "Not Synced"
print()
print("Active Wallet Key (*):")
print(f"{indent}{'-Fingerprint:'.ljust(23)} {logged_in_key.fingerprint}")
if logged_in_key.label is not None:
print(f"{indent}{'-Label:'.ljust(23)} {logged_in_key.label}")
print(f"{indent}{'-Sync Status:'.ljust(23)} {current_sync_status}")
max_key_index_width = 5 # e.g. "12) *", "1) *", or "2) "
max_fingerprint_width = 10 # fingerprint is a 32-bit number
print()
print("Wallet Keys:")
for i, key in enumerate(all_keys):
key_index_str = f"{(str(i + 1) + ')'):<4}"
key_index_str += "*" if key.fingerprint == logged_in_fingerprint else " "
print(
f"{key_index_str:<{max_key_index_width}} "
f"{key.fingerprint:<{max_fingerprint_width}}"
f"{(indent + key.label) if key.label else ''}"
)
val = None
prompt: str = (
f"Choose a wallet key [1-{len(fingerprints)}] ('q' to quit, or Enter to use {logged_in_fingerprint}): "
)
while val is None:
val = input(prompt)
if val == "q":
break
elif val == "" and logged_in_fingerprint is not None:
fingerprint = logged_in_fingerprint
break
elif not val.isdigit():
val = None
else:
index = int(val) - 1
if index < 0 or index >= len(fingerprints):
print("Invalid value")
val = None
continue
if selected_fingerprint is None and len(all_keys) > 0:
logged_in_fingerprint: Optional[int] = await wallet_client.get_logged_in_fingerprint()
logged_in_key: Optional[KeyData] = None
if logged_in_fingerprint is not None:
logged_in_key = next((key for key in all_keys if key.fingerprint == logged_in_fingerprint), None)
current_sync_status: str = ""
indent = " "
if logged_in_key is not None:
if await wallet_client.get_synced():
current_sync_status = "Synced"
elif await wallet_client.get_sync_status():
current_sync_status = "Syncing"
else:
fingerprint = fingerprints[index]
current_sync_status = "Not Synced"
selected_fingerprint = fingerprint
print()
print("Active Wallet Key (*):")
print(f"{indent}{'-Fingerprint:'.ljust(23)} {logged_in_key.fingerprint}")
if logged_in_key.label is not None:
print(f"{indent}{'-Label:'.ljust(23)} {logged_in_key.label}")
print(f"{indent}{'-Sync Status:'.ljust(23)} {current_sync_status}")
max_key_index_width = 5 # e.g. "12) *", "1) *", or "2) "
max_fingerprint_width = 10 # fingerprint is a 32-bit number
print()
print("Wallet Keys:")
for i, key in enumerate(all_keys):
key_index_str = f"{(str(i + 1) + ')'):<4}"
key_index_str += "*" if key.fingerprint == logged_in_fingerprint else " "
print(
f"{key_index_str:<{max_key_index_width}} "
f"{key.fingerprint:<{max_fingerprint_width}}"
f"{(indent + key.label) if key.label else ''}"
)
val = None
prompt: str = (
f"Choose a wallet key [1-{len(fingerprints)}] ('q' to quit, or Enter to use {logged_in_fingerprint}): "
)
while val is None:
val = input(prompt)
if val == "q":
break
elif val == "" and logged_in_fingerprint is not None:
fingerprint = logged_in_fingerprint
break
elif not val.isdigit():
val = None
else:
index = int(val) - 1
if index < 0 or index >= len(fingerprints):
print("Invalid value")
val = None
continue
else:
fingerprint = fingerprints[index]
if selected_fingerprint is not None:
log_in_response = await wallet_client.log_in(selected_fingerprint)
selected_fingerprint = fingerprint
if log_in_response["success"] is False:
print(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}")
selected_fingerprint = None
if selected_fingerprint is not None:
log_in_response = await wallet_client.log_in(selected_fingerprint)
# Closing the keychain proxy takes a moment, so we wait until after the login is complete
if keychain_proxy is not None:
await keychain_proxy.close()
if log_in_response["success"] is False:
print(f"Login failed for fingerprint {selected_fingerprint}: {log_in_response}")
selected_fingerprint = None
finally:
# Closing the keychain proxy takes a moment, so we wait until after the login is complete
if keychain_proxy is not None:
await keychain_proxy.close()
return selected_fingerprint
@ -204,8 +216,11 @@ async def execute_with_wallet(
extra_params: Dict[str, Any],
function: Callable[[Dict[str, Any], WalletRpcClient, int], Awaitable[None]],
) -> None:
wallet_client: Optional[WalletRpcClient]
async with get_any_service_client("wallet", wallet_rpc_port, fingerprint=fingerprint) as (wallet_client, _, new_fp):
async with get_any_service_client(WalletRpcClient, wallet_rpc_port, fingerprint=fingerprint) as (
wallet_client,
_,
new_fp,
):
if wallet_client is not None:
assert new_fp is not None # wallet only sanity check
await function(extra_params, wallet_client, new_fp)

View File

@ -23,7 +23,7 @@ def coins_cmd(ctx: click.Context) -> None:
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option("-u", "--show-unconfirmed", help="Separately display unconfirmed coins.", is_flag=True)
@click.option(
@ -93,7 +93,7 @@ def list_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option(
"-a",
@ -187,7 +187,7 @@ def combine_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option(
"-n",

49
chia/cmds/completion.py Normal file
View File

@ -0,0 +1,49 @@
from __future__ import annotations
import os
import subprocess
from pathlib import Path
import click
SHELLS = ["bash", "zsh", "fish"]
shell = os.environ.get("SHELL")
if shell is not None:
shell = Path(shell).name
if shell not in SHELLS:
shell = None
@click.group(
short_help="Generate shell completion",
)
def completion() -> None:
pass
@completion.command(short_help="Generate shell completion code")
@click.option(
"-s",
"--shell",
type=click.Choice(SHELLS),
default=shell,
show_default=True,
required=shell is None,
help="Shell type to generate for",
)
def generate(shell: str) -> None:
"""
\b
Generate shell completion code for the current, or specified (-s)hell.
You will need to 'source' this code to enable shell completion.
You can source it directly (performs slower) by running:
\033[3;33meval "$(chia complete generate)"\033[0m
or you can save the output to a file:
\033[3;33mchia complete generate > ~/.chia-complete-bash\033[0m
and source that file with:
\033[3;33m. ~/.chia-complete-bash\033[0m
"""
# Could consider calling directly in the future.
# https://github.com/pallets/click/blob/ef11be6e49e19a055fe7e5a89f0f1f4062c68dba/src/click/shell_completion.py#L17
subprocess.run(["chia"], check=True, env={**os.environ, "_CHIA_COMPLETE": f"{shell}_source"})

View File

@ -24,7 +24,7 @@ def configure(
crawler_minimum_version_count: Optional[int],
seeder_domain_name: str,
seeder_nameserver: str,
):
) -> None:
config_yaml = "config.yaml"
with lock_and_load_config(root_path, config_yaml, fill_missing_services=True) as config:
config.update(load_defaults_for_missing_services(config=config, config_name=config_yaml))
@ -269,22 +269,22 @@ def configure(
)
@click.pass_context
def configure_cmd(
ctx,
set_farmer_peer,
set_node_introducer,
set_fullnode_port,
set_harvester_port,
set_log_level,
enable_upnp,
set_outbound_peer_count,
set_peer_count,
testnet,
set_peer_connect_timeout,
crawler_db_path,
crawler_minimum_version_count,
seeder_domain_name,
seeder_nameserver,
):
ctx: click.Context,
set_farmer_peer: str,
set_node_introducer: str,
set_fullnode_port: str,
set_harvester_port: str,
set_log_level: str,
enable_upnp: str,
set_outbound_peer_count: str,
set_peer_count: str,
testnet: str,
set_peer_connect_timeout: str,
crawler_db_path: str,
crawler_minimum_version_count: int,
seeder_domain_name: str,
seeder_nameserver: str,
) -> None:
configure(
ctx.obj["root_path"],
set_farmer_peer,

View File

@ -6,6 +6,7 @@ from typing import Dict, List, Optional
from chia.cmds.cmds_util import get_any_service_client
from chia.cmds.units import units
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
from chia.util.ints import uint64
@ -13,7 +14,7 @@ from chia.util.ints import uint64
async def create_data_store_cmd(rpc_port: Optional[int], fee: Optional[str]) -> None:
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.create_data_store(fee=final_fee)
print(res)
@ -23,7 +24,7 @@ async def get_value_cmd(rpc_port: Optional[int], store_id: str, key: str, root_h
store_id_bytes = bytes32.from_hexstr(store_id)
key_bytes = hexstr_to_bytes(key)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_value(store_id=store_id_bytes, key=key_bytes, root_hash=root_hash_bytes)
print(res)
@ -37,7 +38,7 @@ async def update_data_store_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.update_data_store(store_id=store_id_bytes, changelist=changelist, fee=final_fee)
print(res)
@ -50,7 +51,7 @@ async def get_keys_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_keys(store_id=store_id_bytes, root_hash=root_hash_bytes)
print(res)
@ -63,7 +64,7 @@ async def get_keys_values_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
root_hash_bytes = None if root_hash is None else bytes32.from_hexstr(root_hash)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_keys_values(store_id=store_id_bytes, root_hash=root_hash_bytes)
print(res)
@ -74,7 +75,7 @@ async def get_root_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_root(store_id=store_id_bytes)
print(res)
@ -86,7 +87,7 @@ async def subscribe_cmd(
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.subscribe(store_id=store_id_bytes, urls=urls)
print(res)
@ -97,7 +98,7 @@ async def unsubscribe_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.unsubscribe(store_id=store_id_bytes)
print(res)
@ -109,7 +110,7 @@ async def remove_subscriptions_cmd(
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.remove_subscriptions(store_id=store_id_bytes, urls=urls)
print(res)
@ -124,7 +125,7 @@ async def get_kv_diff_cmd(
store_id_bytes = bytes32.from_hexstr(store_id)
hash_1_bytes = bytes32.from_hexstr(hash_1)
hash_2_bytes = bytes32.from_hexstr(hash_2)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_kv_diff(store_id=store_id_bytes, hash_1=hash_1_bytes, hash_2=hash_2_bytes)
print(res)
@ -135,7 +136,7 @@ async def get_root_history_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_root_history(store_id=store_id_bytes)
print(res)
@ -144,7 +145,7 @@ async def get_root_history_cmd(
async def add_missing_files_cmd(
rpc_port: Optional[int], ids: Optional[List[str]], overwrite: bool, foldername: Optional[Path]
) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.add_missing_files(
store_ids=(None if ids is None else [bytes32.from_hexstr(id) for id in ids]),
@ -159,7 +160,7 @@ async def add_mirror_cmd(
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.add_mirror(
store_id=store_id_bytes,
@ -173,7 +174,7 @@ async def add_mirror_cmd(
async def delete_mirror_cmd(rpc_port: Optional[int], coin_id: str, fee: Optional[str]) -> None:
coin_id_bytes = bytes32.from_hexstr(coin_id)
final_fee = None if fee is None else uint64(int(Decimal(fee) * units["chia"]))
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.delete_mirror(
coin_id=coin_id_bytes,
@ -184,21 +185,21 @@ async def delete_mirror_cmd(rpc_port: Optional[int], coin_id: str, fee: Optional
async def get_mirrors_cmd(rpc_port: Optional[int], store_id: str) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_mirrors(store_id=store_id_bytes)
print(res)
async def get_subscriptions_cmd(rpc_port: Optional[int]) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_subscriptions()
print(res)
async def get_owned_stores_cmd(rpc_port: Optional[int]) -> None:
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_owned_stores()
print(res)
@ -209,7 +210,7 @@ async def get_sync_status_cmd(
store_id: str,
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
async with get_any_service_client("data_layer", rpc_port) as (client, config, _):
async with get_any_service_client(DataLayerRpcClient, rpc_port) as (client, _, _):
if client is not None:
res = await client.get_sync_status(store_id=store_id_bytes)
print(res)

View File

@ -1,6 +1,7 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
import click
@ -15,8 +16,8 @@ def db_cmd() -> None:
@db_cmd.command("upgrade", short_help="upgrade a v1 database to v2")
@click.option("--input", default=None, type=click.Path(), help="specify input database file")
@click.option("--output", default=None, type=click.Path(), help="specify output database file")
@click.option("--input", "in_db_path", default=None, type=click.Path(), help="specify input database file")
@click.option("--output", "out_db_path", default=None, type=click.Path(), help="specify output database file")
@click.option(
"--no-update-config",
default=False,
@ -31,11 +32,14 @@ def db_cmd() -> None:
help="force conversion despite warnings",
)
@click.pass_context
def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kwargs) -> None:
def db_upgrade_cmd(
ctx: click.Context,
in_db_path: Optional[str],
out_db_path: Optional[str],
no_update_config: bool,
force: bool,
) -> None:
try:
in_db_path = kwargs.get("input")
out_db_path = kwargs.get("output")
db_upgrade_func(
Path(ctx.obj["root_path"]),
None if in_db_path is None else Path(in_db_path),
@ -48,7 +52,7 @@ def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kw
@db_cmd.command("validate", short_help="validate the (v2) blockchain database. Does not verify proofs")
@click.option("--db", default=None, type=click.Path(), help="Specifies which database file to validate")
@click.option("--db", "in_db_path", default=None, type=click.Path(), help="Specifies which database file to validate")
@click.option(
"--validate-blocks",
default=False,
@ -56,9 +60,8 @@ def db_upgrade_cmd(ctx: click.Context, no_update_config: bool, force: bool, **kw
help="validate consistency of properties of the encoded blocks and block records",
)
@click.pass_context
def db_validate_cmd(ctx: click.Context, validate_blocks: bool, **kwargs) -> None:
def db_validate_cmd(ctx: click.Context, in_db_path: Optional[str], validate_blocks: bool) -> None:
try:
in_db_path = kwargs.get("db")
db_validate_func(
Path(ctx.obj["root_path"]),
None if in_db_path is None else Path(in_db_path),
@ -69,12 +72,11 @@ def db_validate_cmd(ctx: click.Context, validate_blocks: bool, **kwargs) -> None
@db_cmd.command("backup", short_help="backup the blockchain database using VACUUM INTO command")
@click.option("--backup_file", default=None, type=click.Path(), help="Specifies the backup file")
@click.option("--backup_file", "db_backup_file", default=None, type=click.Path(), help="Specifies the backup file")
@click.option("--no_indexes", default=False, is_flag=True, help="Create backup without indexes")
@click.pass_context
def db_backup_cmd(ctx: click.Context, no_indexes: bool, **kwargs) -> None:
def db_backup_cmd(ctx: click.Context, db_backup_file: Optional[str], no_indexes: bool) -> None:
try:
db_backup_file = kwargs.get("backup_file")
db_backup_func(
Path(ctx.obj["root_path"]),
None if db_backup_file is None else Path(db_backup_file),

View File

@ -7,7 +7,7 @@ import sys
import textwrap
from pathlib import Path
from time import time
from typing import Dict, Optional
from typing import Any, Dict, Optional
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.config import load_config, lock_and_load_config, save_config
@ -27,10 +27,9 @@ def db_upgrade_func(
no_update_config: bool = False,
force: bool = False,
) -> None:
update_config: bool = in_db_path is None and out_db_path is None and not no_update_config
config: Dict
config: Dict[str, Any]
selected_network: str
db_pattern: str
if in_db_path is None or out_db_path is None:
@ -81,7 +80,6 @@ def db_upgrade_func(
except RuntimeError as e:
print(f"conversion failed with error: {e}.")
except Exception as e:
print(
textwrap.dedent(
f"""\
@ -167,7 +165,7 @@ def convert_v1_to_v2(in_path: Path, out_path: Path) -> None:
"block_record blob)"
)
out_db.execute(
"CREATE TABLE sub_epoch_segments_v3(" "ses_block_hash blob PRIMARY KEY," "challenge_segments blob)"
"CREATE TABLE sub_epoch_segments_v3(ses_block_hash blob PRIMARY KEY, challenge_segments blob)"
)
out_db.execute("CREATE TABLE current_peak(key int PRIMARY KEY, hash blob)")
@ -202,10 +200,8 @@ def convert_v1_to_v2(in_path: Path, out_path: Path) -> None:
"SELECT header_hash, height, is_fully_compactified, block FROM full_blocks ORDER BY height DESC"
)
) as cursor_2:
out_db.execute("begin transaction")
for row in cursor:
header_hash = bytes.fromhex(row[0])
if header_hash != hh:
continue

View File

@ -41,7 +41,6 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
print(f"opening file for reading: {in_path}")
with closing(sqlite3.connect(in_path)) as in_db:
# read the database version
try:
with closing(in_db.execute("SELECT * FROM database_version")) as cursor:
@ -91,9 +90,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
"FROM full_blocks ORDER BY height DESC"
)
) as cursor:
for row in cursor:
hh = row[0]
prev = row[1]
height = row[2]
@ -111,7 +108,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
actual_prev_hash = block.prev_header_hash
if actual_header_hash != hh:
raise RuntimeError(
f"Block {hh.hex()} has a blob with mismatching " f"hash: {actual_header_hash.hex()}"
f"Block {hh.hex()} has a blob with mismatching hash: {actual_header_hash.hex()}"
)
if block_record.header_hash != hh:
raise RuntimeError(
@ -130,7 +127,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
)
if block.height != height:
raise RuntimeError(
f"Block {hh.hex()} has a mismatching " f"height: {block.height} expected {height}"
f"Block {hh.hex()} has a mismatching height: {block.height} expected {height}"
)
if height != current_height:
@ -146,7 +143,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
if hh == expect_hash:
if next_hash is not None:
raise RuntimeError(f"Database has multiple blocks with hash {hh.hex()}, " f"at height {height}")
raise RuntimeError(f"Database has multiple blocks with hash {hh.hex()}, at height {height}")
if not in_main_chain:
raise RuntimeError(
f"block {hh.hex()} (height: {height}) is part of the main chain, "
@ -168,9 +165,7 @@ def validate_v2(in_path: Path, *, validate_blocks: bool) -> None:
else:
if in_main_chain:
raise RuntimeError(
f"block {hh.hex()} (height: {height}) is orphaned, " "but in_main_chain is set"
)
raise RuntimeError(f"block {hh.hex()} (height: {height}) is orphaned, but in_main_chain is set")
num_orphans += 1
print("")

View File

@ -44,9 +44,7 @@ def farm_cmd() -> None:
@click.option(
"-fp",
"--farmer-rpc-port",
help=(
"Set the port where the Farmer is hosting the RPC interface. " "See the rpc_port under farmer in config.yaml"
),
help=("Set the port where the Farmer is hosting the RPC interface. See the rpc_port under farmer in config.yaml"),
type=int,
default=None,
show_default=True,

View File

@ -15,8 +15,7 @@ SECONDS_PER_BLOCK = (24 * 3600) / 4608
async def get_harvesters_summary(farmer_rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer", farmer_rpc_port) as node_config_fp:
async with get_any_service_client(FarmerRpcClient, farmer_rpc_port) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
return await farmer_client.get_harvesters_summary()
@ -24,8 +23,7 @@ async def get_harvesters_summary(farmer_rpc_port: Optional[int]) -> Optional[Dic
async def get_blockchain_state(rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
return await client.get_blockchain_state()
@ -33,8 +31,7 @@ async def get_blockchain_state(rpc_port: Optional[int]) -> Optional[Dict[str, An
async def get_average_block_time(rpc_port: Optional[int]) -> float:
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
blocks_to_compare = 500
@ -58,8 +55,7 @@ async def get_average_block_time(rpc_port: Optional[int]) -> float:
async def get_wallets_stats(wallet_rpc_port: Optional[int]) -> Optional[Dict[str, Any]]:
wallet_client: Optional[WalletRpcClient]
async with get_any_service_client("wallet", wallet_rpc_port, login_to_wallet=False) as node_config_fp:
async with get_any_service_client(WalletRpcClient, wallet_rpc_port, login_to_wallet=False) as node_config_fp:
wallet_client, _, _ = node_config_fp
if wallet_client is not None:
return await wallet_client.get_farmed_amount()
@ -67,8 +63,7 @@ async def get_wallets_stats(wallet_rpc_port: Optional[int]) -> Optional[Dict[str
async def get_challenges(farmer_rpc_port: Optional[int]) -> Optional[List[Dict[str, Any]]]:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer", farmer_rpc_port) as node_config_fp:
async with get_any_service_client(FarmerRpcClient, farmer_rpc_port) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
return await farmer_client.get_signage_points()
@ -145,7 +140,7 @@ async def summary(
harvesters_remote[ip] = {}
harvesters_remote[ip][harvester["connection"]["node_id"]] = harvester
def process_harvesters(harvester_peers_in: dict):
def process_harvesters(harvester_peers_in: Dict[str, Dict[str, Any]]) -> None:
for harvester_peer_id, harvester_dict in harvester_peers_in.items():
syncing = harvester_dict["syncing"]
if syncing is not None and syncing["initial"]:

View File

@ -24,7 +24,14 @@ import click
help="Initialize the blockchain database in v1 format (compatible with older versions of the full node)",
)
@click.pass_context
def init_cmd(ctx: click.Context, create_certs: str, fix_ssl_permissions: bool, testnet: bool, v1_db: bool, **kwargs):
def init_cmd(
ctx: click.Context,
create_certs: str,
fix_ssl_permissions: bool,
testnet: bool,
set_passphrase: bool,
v1_db: bool,
) -> None:
"""
Create a new configuration or migrate from previous versions to current
@ -43,7 +50,6 @@ def init_cmd(ctx: click.Context, create_certs: str, fix_ssl_permissions: bool, t
from .init_funcs import init
set_passphrase = kwargs.get("set_passphrase")
if set_passphrase:
initialize_passphrase()

View File

@ -41,7 +41,7 @@ from chia.wallet.derive_keys import (
)
def dict_add_new_default(updated: Dict, default: Dict, do_not_migrate_keys: Dict[str, Any]):
def dict_add_new_default(updated: Dict[str, Any], default: Dict[str, Any], do_not_migrate_keys: Dict[str, Any]) -> None:
for k in do_not_migrate_keys:
if k in updated and do_not_migrate_keys[k] == "":
updated.pop(k)
@ -155,7 +155,7 @@ def check_keys(new_root: Path, keychain: Optional[Keychain] = None) -> None:
save_config(new_root, "config.yaml", config)
def copy_files_rec(old_path: Path, new_path: Path):
def copy_files_rec(old_path: Path, new_path: Path) -> None:
if old_path.is_file():
print(f"{new_path}")
new_path.parent.mkdir(parents=True, exist_ok=True)
@ -171,7 +171,7 @@ def migrate_from(
new_root: Path,
manifest: List[str],
do_not_migrate_settings: List[str],
):
) -> int:
"""
Copy all the files in "manifest" to the new config directory.
"""
@ -193,7 +193,7 @@ def migrate_from(
with lock_and_load_config(new_root, "config.yaml") as config:
config_str: str = initial_config_file("config.yaml")
default_config: Dict = yaml.safe_load(config_str)
default_config: Dict[str, Any] = yaml.safe_load(config_str)
flattened_keys = unflatten_properties({k: "" for k in do_not_migrate_settings})
dict_add_new_default(config, default_config, flattened_keys)
@ -204,7 +204,7 @@ def migrate_from(
return 1
def copy_cert_files(cert_path: Path, new_path: Path):
def copy_cert_files(cert_path: Path, new_path: Path) -> None:
for old_path_child in cert_path.glob("*.crt"):
new_path_child = new_path / old_path_child.name
copy_files_rec(old_path_child, new_path_child)
@ -222,7 +222,7 @@ def init(
fix_ssl_permissions: bool = False,
testnet: bool = False,
v1_db: bool = False,
):
) -> Optional[int]:
if create_certs is not None:
if root_path.exists():
if os.path.isdir(create_certs):
@ -255,6 +255,8 @@ def init(
else:
return chia_init(root_path, fix_ssl_permissions=fix_ssl_permissions, testnet=testnet, v1_db=v1_db)
return None
def chia_version_number() -> Tuple[str, str, str, str]:
scm_full_version = __version__
@ -316,7 +318,7 @@ def chia_init(
fix_ssl_permissions: bool = False,
testnet: bool = False,
v1_db: bool = False,
):
) -> int:
"""
Standard first run initialization or migration steps. Handles config creation,
generation of SSL certs, and setting target addresses (via check_keys).
@ -383,7 +385,7 @@ def chia_init(
if should_check_keys:
check_keys(root_path)
config: Dict
config: Dict[str, Any]
db_path_replaced: str
if v1_db:

View File

@ -7,7 +7,7 @@ import click
@click.group("keys", short_help="Manage your keys")
@click.pass_context
def keys_cmd(ctx: click.Context):
def keys_cmd(ctx: click.Context) -> None:
"""Create, delete, view and use your key pairs"""
from pathlib import Path
@ -26,7 +26,7 @@ def keys_cmd(ctx: click.Context):
required=False,
)
@click.pass_context
def generate_cmd(ctx: click.Context, label: Optional[str]):
def generate_cmd(ctx: click.Context, label: Optional[str]) -> None:
from .init_funcs import check_keys
from .keys_funcs import generate_and_add
@ -66,7 +66,13 @@ def generate_cmd(ctx: click.Context, label: Optional[str]):
default=None,
)
@click.pass_context
def show_cmd(ctx: click.Context, show_mnemonic_seed, non_observer_derivation, json, fingerprint):
def show_cmd(
ctx: click.Context,
show_mnemonic_seed: bool,
non_observer_derivation: bool,
json: bool,
fingerprint: Optional[int],
) -> None:
from .keys_funcs import show_keys
show_keys(ctx.obj["root_path"], show_mnemonic_seed, non_observer_derivation, json, fingerprint)
@ -90,7 +96,7 @@ def show_cmd(ctx: click.Context, show_mnemonic_seed, non_observer_derivation, js
required=False,
)
@click.pass_context
def add_cmd(ctx: click.Context, filename: str, label: Optional[str]):
def add_cmd(ctx: click.Context, filename: str, label: Optional[str]) -> None:
from .init_funcs import check_keys
from .keys_funcs import query_and_add_private_key_seed
@ -105,12 +111,12 @@ def add_cmd(ctx: click.Context, filename: str, label: Optional[str]):
@keys_cmd.group("label", short_help="Manage your key labels")
def label_cmd():
def label_cmd() -> None:
pass
@label_cmd.command("show", short_help="Show the labels of all available keys")
def show_label_cmd():
def show_label_cmd() -> None:
from .keys_funcs import show_all_key_labels
show_all_key_labels()
@ -131,7 +137,7 @@ def show_label_cmd():
type=str,
required=True,
)
def set_label_cmd(fingerprint: int, label: str):
def set_label_cmd(fingerprint: int, label: str) -> None:
from .keys_funcs import set_key_label
set_key_label(fingerprint, label)
@ -145,7 +151,7 @@ def set_label_cmd(fingerprint: int, label: str):
type=int,
required=True,
)
def delete_label_cmd(fingerprint: int):
def delete_label_cmd(fingerprint: int) -> None:
from .keys_funcs import delete_key_label
delete_key_label(fingerprint)
@ -161,7 +167,7 @@ def delete_label_cmd(fingerprint: int):
required=True,
)
@click.pass_context
def delete_cmd(ctx: click.Context, fingerprint: int):
def delete_cmd(ctx: click.Context, fingerprint: int) -> None:
from .init_funcs import check_keys
from .keys_funcs import delete
@ -170,14 +176,14 @@ def delete_cmd(ctx: click.Context, fingerprint: int):
@keys_cmd.command("delete_all", short_help="Delete all private keys in keychain")
def delete_all_cmd():
def delete_all_cmd() -> None:
from chia.util.keychain import Keychain
Keychain().delete_all_keys()
@keys_cmd.command("generate_and_print", short_help="Generates but does NOT add to keychain")
def generate_and_print_cmd():
def generate_and_print_cmd() -> None:
from .keys_funcs import generate_and_print
generate_and_print()
@ -220,14 +226,14 @@ def generate_and_print_cmd():
)
def sign_cmd(
message: str, fingerprint: Optional[int], filename: Optional[str], hd_path: str, as_bytes: bool, json: bool
):
) -> None:
from .keys_funcs import resolve_derivation_master_key, sign
private_key = resolve_derivation_master_key(filename if filename is not None else fingerprint)
sign(message, private_key, hd_path, as_bytes, json)
def parse_signature_json(json_str: str):
def parse_signature_json(json_str: str) -> Tuple[str, str, str, str]:
import json
try:
@ -265,7 +271,7 @@ def parse_signature_json(json_str: str):
show_default=True,
type=str,
)
def verify_cmd(message: str, public_key: str, signature: str, as_bytes: bool, json: str):
def verify_cmd(message: str, public_key: str, signature: str, as_bytes: bool, json: str) -> None:
from .keys_funcs import as_bytes_from_signing_mode, verify
if json is not None:
@ -294,7 +300,7 @@ def verify_cmd(message: str, public_key: str, signature: str, as_bytes: bool, js
required=False,
)
@click.pass_context
def derive_cmd(ctx: click.Context, fingerprint: Optional[int], filename: Optional[str]):
def derive_cmd(ctx: click.Context, fingerprint: Optional[int], filename: Optional[str]) -> None:
ctx.obj["fingerprint"] = fingerprint
ctx.obj["filename"] = filename
@ -347,7 +353,7 @@ def search_cmd(
search_type: Tuple[str, ...],
derive_from_hd_path: Optional[str],
prefix: Optional[str],
):
) -> None:
import sys
from blspy import PrivateKey
@ -402,7 +408,7 @@ def search_cmd(
@click.pass_context
def wallet_address_cmd(
ctx: click.Context, index: int, count: int, prefix: Optional[str], non_observer_derivation: bool, show_hd_path: bool
):
) -> None:
from .keys_funcs import derive_wallet_address, resolve_derivation_master_key
fingerprint: Optional[int] = ctx.obj.get("fingerprint", None)
@ -467,7 +473,7 @@ def child_key_cmd(
non_observer_derivation: bool,
show_private_keys: bool,
show_hd_path: bool,
):
) -> None:
from .keys_funcs import derive_child_key, resolve_derivation_master_key
if key_type is None and derive_from_hd_path is None:

View File

@ -17,7 +17,7 @@ from chia.util.config import load_config
from chia.util.errors import KeychainException
from chia.util.file_keyring import MAX_LABEL_LENGTH
from chia.util.ints import uint32
from chia.util.keychain import Keychain, bytes_to_mnemonic, generate_mnemonic, mnemonic_to_seed
from chia.util.keychain import Keychain, KeyData, bytes_to_mnemonic, generate_mnemonic, mnemonic_to_seed
from chia.util.keyring_wrapper import KeyringWrapper
from chia.wallet.derive_keys import (
master_sk_to_farmer_sk,
@ -40,7 +40,7 @@ def unlock_keyring() -> None:
sys.exit(1)
def generate_and_print():
def generate_and_print() -> str:
"""
Generates a seed for a private key, and prints the mnemonic to the terminal.
"""
@ -52,7 +52,7 @@ def generate_and_print():
return mnemonic
def generate_and_add(label: Optional[str]):
def generate_and_add(label: Optional[str]) -> None:
"""
Generates a seed for a private key, prints the mnemonic to the terminal, and adds the key to the keyring.
"""
@ -61,7 +61,7 @@ def generate_and_add(label: Optional[str]):
query_and_add_private_key_seed(mnemonic=generate_mnemonic(), label=label)
def query_and_add_private_key_seed(mnemonic: Optional[str], label: Optional[str] = None):
def query_and_add_private_key_seed(mnemonic: Optional[str], label: Optional[str] = None) -> None:
unlock_keyring()
if mnemonic is None:
mnemonic = input("Enter the mnemonic you want to use: ")
@ -72,7 +72,7 @@ def query_and_add_private_key_seed(mnemonic: Optional[str], label: Optional[str]
add_private_key_seed(mnemonic, label)
def add_private_key_seed(mnemonic: str, label: Optional[str]):
def add_private_key_seed(mnemonic: str, label: Optional[str]) -> None:
"""
Add a private key seed to the keyring, with the given mnemonic and an optional label.
"""
@ -127,7 +127,7 @@ def delete_key_label(fingerprint: int) -> None:
def show_keys(
root_path: Path, show_mnemonic: bool, non_observer_derivation: bool, json_output: bool, fingerprint: Optional[int]
):
) -> None:
"""
Prints all keys and mnemonics (if available).
"""
@ -153,8 +153,8 @@ def show_keys(
msg = "Showing all public and private keys"
print(msg)
def process_key_data(key_data):
key = {}
def process_key_data(key_data: KeyData) -> Dict[str, Any]:
key: Dict[str, Any] = {}
sk = key_data.private_key
if key_data.label is not None:
key["label"] = key_data.label
@ -178,7 +178,7 @@ def show_keys(
key["mnemonic"] = bytes_to_mnemonic(key_data.entropy)
return key
keys = map(process_key_data, all_keys)
keys = [process_key_data(key) for key in all_keys]
if json_output:
print(json.dumps({"keys": list(keys)}))
@ -199,7 +199,7 @@ def show_keys(
print(key["mnemonic"])
def delete(fingerprint: int):
def delete(fingerprint: int) -> None:
"""
Delete a key by its public key fingerprint (which is an integer).
"""
@ -246,7 +246,7 @@ def derive_sk_from_hd_path(master_sk: PrivateKey, hd_path_root: str) -> Tuple[Pr
current_sk: PrivateKey = master_sk
# Derive keys along the path
for (current_index, derivation_type) in index_and_derivation_types:
for current_index, derivation_type in index_and_derivation_types:
if derivation_type == DerivationType.NONOBSERVER:
current_sk = _derive_path(current_sk, [current_index])
elif derivation_type == DerivationType.OBSERVER:
@ -257,7 +257,7 @@ def derive_sk_from_hd_path(master_sk: PrivateKey, hd_path_root: str) -> Tuple[Pr
return (current_sk, "m/" + "/".join(path) + "/")
def sign(message: str, private_key: PrivateKey, hd_path: str, as_bytes: bool, json_output: bool):
def sign(message: str, private_key: PrivateKey, hd_path: str, as_bytes: bool, json_output: bool) -> None:
sk: PrivateKey = derive_sk_from_hd_path(private_key, hd_path)[0]
data = bytes.fromhex(message) if as_bytes else bytes(message, "utf-8")
signing_mode: SigningMode = (
@ -283,7 +283,7 @@ def sign(message: str, private_key: PrivateKey, hd_path: str, as_bytes: bool, js
print(f"Signing Mode: {signing_mode.value}")
def verify(message: str, public_key: str, signature: str, as_bytes: bool):
def verify(message: str, public_key: str, signature: str, as_bytes: bool) -> None:
data = bytes.fromhex(message) if as_bytes else bytes(message, "utf-8")
public_key = G1Element.from_bytes(bytes.fromhex(public_key))
signature = G2Element.from_bytes(bytes.fromhex(signature))
@ -297,7 +297,7 @@ def as_bytes_from_signing_mode(signing_mode_str: str) -> bool:
return False
def _clear_line_part(n: int):
def _clear_line_part(n: int) -> None:
# Move backward, overwrite with spaces, then move backward again
sys.stdout.write("\b" * n)
sys.stdout.write(" " * n)
@ -384,7 +384,7 @@ def _search_derived(
if len(found_items) > 0 and show_progress:
print()
for (term, found_item, found_item_type) in found_items:
for term, found_item, found_item_type in found_items:
# Update remaining_search_terms and found_search_terms
del remaining_search_terms[term]
found_search_terms.append(term)
@ -440,7 +440,7 @@ def search_derive(
search_private_key = "private_key" in search_types
if prefix is None:
config: Dict = load_config(root_path, "config.yaml")
config: Dict[str, Any] = load_config(root_path, "config.yaml")
selected: str = config["selected_network"]
prefix = config["network_overrides"]["config"][selected]["address_prefix"]
@ -572,13 +572,13 @@ def derive_wallet_address(
prefix: Optional[str],
non_observer_derivation: bool,
show_hd_path: bool,
):
) -> None:
"""
Generate wallet addresses using keys derived from the provided private key.
"""
if prefix is None:
config: Dict = load_config(root_path, "config.yaml")
config: Dict[str, Any] = load_config(root_path, "config.yaml")
selected: str = config["selected_network"]
prefix = config["network_overrides"]["config"][selected]["address_prefix"]
path_indices: List[int] = [12381, 8444, 2]
@ -602,10 +602,10 @@ def derive_wallet_address(
print(f"Wallet address {i}: {address}")
def private_key_string_repr(private_key: PrivateKey):
def private_key_string_repr(private_key: PrivateKey) -> str:
"""Print a PrivateKey in a human-readable formats"""
s: str = str(private_key)
s = str(private_key)
return s[len("<PrivateKey ") : s.rfind(">")] if s.startswith("<PrivateKey ") else s
@ -618,7 +618,7 @@ def derive_child_key(
non_observer_derivation: bool,
show_private_keys: bool,
show_hd_path: bool,
):
) -> None:
"""
Derive child keys from the provided master key.
"""
@ -689,7 +689,7 @@ def private_key_for_fingerprint(fingerprint: int) -> Optional[PrivateKey]:
return None
def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]):
def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]) -> Optional[PrivateKey]:
"""
Get a private key with the specified fingerprint. If fingerprint is not
specified, prompt the user to select a key.

View File

@ -12,8 +12,7 @@ async def netstorge_async(rpc_port: Optional[int], delta_block_height: str, star
"""
Calculates the estimated space on the network given two block header hashes.
"""
client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port) as node_config_fp:
client, _, _ = node_config_fp
if client is not None:
if delta_block_height:

View File

@ -91,7 +91,7 @@ def verify_passphrase_meets_requirements(
def prompt_for_passphrase(prompt: str) -> str:
if sys.platform == "win32" or sys.platform == "cygwin":
print(prompt, end="")
print(prompt, end="", flush=True)
prompt = ""
return getpass(prompt)

View File

@ -20,9 +20,9 @@ from chia.cmds.peer_funcs import peer_async
default=None,
)
@click.option(
"-c", "--connections", help="List nodes connected to this Full Node", is_flag=True, type=bool, default=False
"-c", "--connections", help="List connections to the specified service", is_flag=True, type=bool, default=False
)
@click.option("-a", "--add-connection", help="Connect to another Full Node by ip:port", type=str, default="")
@click.option("-a", "--add-connection", help="Connect specified Chia service to ip:port", type=str, default="")
@click.option(
"-r", "--remove-connection", help="Remove a Node by the first 8 characters of NodeID", type=str, default=""
)

View File

@ -3,7 +3,7 @@ from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
from chia.cmds.cmds_util import get_any_service_client
from chia.cmds.cmds_util import NODE_TYPES, get_any_service_client
from chia.rpc.rpc_client import RpcClient
@ -113,8 +113,8 @@ async def peer_async(
add_connection: str,
remove_connection: str,
) -> None:
rpc_client: Optional[RpcClient]
async with get_any_service_client(node_type, rpc_port, root_path) as node_config_fp:
client_type = NODE_TYPES[node_type]
async with get_any_service_client(client_type, rpc_port, root_path) as node_config_fp:
rpc_client, config, _ = node_config_fp
if rpc_client is not None:
# Check or edit node connections

View File

@ -10,7 +10,7 @@ from chia.cmds.cmds_util import execute_with_wallet
MAX_CMDLINE_FEE = Decimal(0.5)
def validate_fee(ctx, param, value):
def validate_fee(ctx: click.Context, param: click.Parameter, value: str) -> str:
try:
fee = Decimal(value)
except ValueError:
@ -34,7 +34,7 @@ def plotnft_cmd() -> None:
default=None,
)
@click.option("-i", "--id", help="ID of the wallet to use", type=int, default=None, show_default=True, required=False)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
def show_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> None:
import asyncio
@ -57,7 +57,7 @@ def get_login_link_cmd(launcher_id: str) -> None:
@plotnft_cmd.command("create", short_help="Create a plot NFT")
@click.option("-y", "--yes", help="No prompts", is_flag=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-u", "--pool_url", help="HTTPS host:port of the pool to join", type=str, required=False)
@click.option("-s", "--state", help="Initial state of Plot NFT: local or pool", type=str, required=True)
@click.option(
@ -108,7 +108,7 @@ def create_cmd(
@plotnft_cmd.command("join", short_help="Join a plot NFT to a Pool")
@click.option("-y", "--yes", help="No prompts", is_flag=True)
@click.option("-i", "--id", help="ID of the wallet to use", type=int, default=None, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-u", "--pool_url", help="HTTPS host:port of the pool to join", type=str, required=True)
@click.option(
"-m",
@ -139,7 +139,7 @@ def join_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int, fee: int
@plotnft_cmd.command("leave", short_help="Leave a pool and return to self-farming")
@click.option("-y", "--yes", help="No prompts", is_flag=True)
@click.option("-i", "--id", help="ID of the wallet to use", type=int, default=None, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-m",
"--fee",
@ -168,7 +168,7 @@ def self_pool_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int, fee
@plotnft_cmd.command("inspect", short_help="Get Detailed plotnft information as JSON")
@click.option("-i", "--id", help="ID of the wallet to use", type=int, default=None, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-wp",
"--wallet-rpc-port",
@ -187,7 +187,7 @@ def inspect(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> None:
@plotnft_cmd.command("claim", short_help="Claim rewards from a plot NFT")
@click.option("-i", "--id", help="ID of the wallet to use", type=int, default=None, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-m",
"--fee",

View File

@ -7,7 +7,7 @@ import time
from dataclasses import replace
from decimal import Decimal
from pprint import pprint
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Awaitable, Callable, Dict, List, Optional
import aiohttp
@ -31,12 +31,12 @@ from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.util.wallet_types import WalletType
async def create_pool_args(pool_url: str) -> Dict:
async def create_pool_args(pool_url: str) -> Dict[str, Any]:
try:
async with aiohttp.ClientSession() as session:
async with session.get(f"{pool_url}/pool_info", ssl=ssl_context_for_root(get_mozilla_ca_crt())) as response:
if response.ok:
json_dict = json.loads(await response.text())
json_dict: Dict[str, Any] = json.loads(await response.text())
else:
raise ValueError(f"Response from {pool_url} not OK: {response.status}")
except Exception as e:
@ -54,7 +54,7 @@ async def create_pool_args(pool_url: str) -> Dict:
return json_dict
async def create(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
async def create(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
state = args["state"]
prompt = not args.get("yes", False)
fee = Decimal(args.get("fee", 0))
@ -116,8 +116,8 @@ async def pprint_pool_wallet_state(
pool_wallet_info: PoolWalletInfo,
address_prefix: str,
pool_state_dict: Optional[Dict[str, Any]],
):
if pool_wallet_info.current.state == PoolSingletonState.LEAVING_POOL and pool_wallet_info.target is None:
) -> None:
if pool_wallet_info.current.state == PoolSingletonState.LEAVING_POOL.value and pool_wallet_info.target is None:
expected_leave_height = pool_wallet_info.singleton_block_height + pool_wallet_info.current.relative_lock_height
print(f"Current state: INVALID_STATE. Please leave/join again after block height {expected_leave_height}")
else:
@ -139,12 +139,12 @@ async def pprint_pool_wallet_state(
print(f"Target state: {PoolSingletonState(pool_wallet_info.target.state).name}")
print(f"Target pool URL: {pool_wallet_info.target.pool_url}")
if pool_wallet_info.current.state == PoolSingletonState.SELF_POOLING.value:
balances: Dict = await wallet_client.get_wallet_balance(wallet_id)
balances: Dict[str, Any] = await wallet_client.get_wallet_balance(wallet_id)
balance = balances["confirmed_wallet_balance"]
typ = WalletType(int(WalletType.POOLING_WALLET))
address_prefix, scale = wallet_coin_unit(typ, address_prefix)
print(f"Claimable balance: {print_balance(balance, scale, address_prefix)}")
if pool_wallet_info.current.state == PoolSingletonState.FARMING_TO_POOL:
if pool_wallet_info.current.state == PoolSingletonState.FARMING_TO_POOL.value:
print(f"Current pool URL: {pool_wallet_info.current.pool_url}")
if pool_state_dict is not None:
print(f"Current difficulty: {pool_state_dict['current_difficulty']}")
@ -166,22 +166,21 @@ async def pprint_pool_wallet_state(
except Exception:
print(f"Payout instructions (pool will pay you with this): {payout_instructions}")
print(f"Relative lock height: {pool_wallet_info.current.relative_lock_height} blocks")
if pool_wallet_info.current.state == PoolSingletonState.LEAVING_POOL:
if pool_wallet_info.current.state == PoolSingletonState.LEAVING_POOL.value:
expected_leave_height = pool_wallet_info.singleton_block_height + pool_wallet_info.current.relative_lock_height
if pool_wallet_info.target is not None:
print(f"Expected to leave after block height: {expected_leave_height}")
async def show(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer") as node_config_fp:
async def show(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
async with get_any_service_client(FarmerRpcClient) as node_config_fp:
farmer_client, config, _ = node_config_fp
if farmer_client is not None:
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
summaries_response = await wallet_client.get_wallets()
wallet_id_passed_in = args.get("id", None)
pool_state_list = (await farmer_client.get_pool_state())["pool_state"]
pool_state_dict: Dict[bytes32, Dict] = {
pool_state_dict: Dict[bytes32, Dict[str, Any]] = {
bytes32.from_hexstr(pool_state_item["pool_config"]["launcher_id"]): pool_state_item
for pool_state_item in pool_state_list
}
@ -223,8 +222,7 @@ async def show(args: dict, wallet_client: WalletRpcClient, fingerprint: int) ->
async def get_login_link(launcher_id_str: str) -> None:
launcher_id: bytes32 = bytes32.from_hexstr(launcher_id_str)
farmer_client: Optional[FarmerRpcClient]
async with get_any_service_client("farmer") as node_config_fp:
async with get_any_service_client(FarmerRpcClient) as node_config_fp:
farmer_client, _, _ = node_config_fp
if farmer_client is not None:
login_link: Optional[str] = await farmer_client.get_pool_login_link(launcher_id)
@ -235,8 +233,13 @@ async def get_login_link(launcher_id_str: str) -> None:
async def submit_tx_with_confirmation(
message: str, prompt: bool, func: Callable, wallet_client: WalletRpcClient, fingerprint: int, wallet_id: int
):
message: str,
prompt: bool,
func: Callable[[], Awaitable[Dict[str, Any]]],
wallet_client: WalletRpcClient,
fingerprint: int,
wallet_id: int,
) -> None:
print(message)
if prompt:
user_input: str = input("Confirm [n]/y: ")
@ -245,7 +248,7 @@ async def submit_tx_with_confirmation(
if user_input.lower() == "y" or user_input.lower() == "yes":
try:
result: Dict = await func()
result = await func()
tx_record: TransactionRecord = result["transaction"]
start = time.time()
while time.time() - start < 10:
@ -261,7 +264,7 @@ async def submit_tx_with_confirmation(
print("Aborting.")
async def join_pool(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
async def join_pool(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
enforce_https = config["full_node"]["selected_network"] == "mainnet"
pool_url: str = args["pool_url"]
@ -306,7 +309,7 @@ async def join_pool(args: dict, wallet_client: WalletRpcClient, fingerprint: int
await submit_tx_with_confirmation(msg, prompt, func, wallet_client, fingerprint, wallet_id)
async def self_pool(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
async def self_pool(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
wallet_id = args.get("id", None)
prompt = not args.get("yes", False)
fee = Decimal(args.get("fee", 0))
@ -317,7 +320,7 @@ async def self_pool(args: dict, wallet_client: WalletRpcClient, fingerprint: int
await submit_tx_with_confirmation(msg, prompt, func, wallet_client, fingerprint, wallet_id)
async def inspect_cmd(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
async def inspect_cmd(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
wallet_id = args.get("id", None)
pool_wallet_info, unconfirmed_transactions = await wallet_client.pw_status(wallet_id)
print(
@ -330,7 +333,7 @@ async def inspect_cmd(args: dict, wallet_client: WalletRpcClient, fingerprint: i
)
async def claim_cmd(args: dict, wallet_client: WalletRpcClient, fingerprint: int) -> None:
async def claim_cmd(args: Dict[str, Any], wallet_client: WalletRpcClient, fingerprint: int) -> None:
wallet_id = args.get("id", None)
fee = Decimal(args.get("fee", 0))
fee_mojos = uint64(int(fee * units["chia"]))

View File

@ -9,7 +9,6 @@ import click
from chia.plotting.util import add_plot_directory, validate_plot_size
DEFAULT_STRIPE_SIZE = 65536
log = logging.getLogger(__name__)
@ -117,21 +116,21 @@ def create_cmd(
connect_to_daemon: bool,
):
from chia.plotting.create_plots import create_plots, resolve_plot_keys
from chia.plotting.util import Params
class Params(object):
def __init__(self):
self.size = size
self.num = num
self.buffer = buffer
self.num_threads = num_threads
self.buckets = buckets
self.stripe_size = DEFAULT_STRIPE_SIZE
self.tmp_dir = Path(tmp_dir)
self.tmp2_dir = Path(tmp2_dir) if tmp2_dir else None
self.final_dir = Path(final_dir)
self.plotid = plotid
self.memo = memo
self.nobitfield = nobitfield
params = Params(
size=size,
num=num,
buffer=buffer,
num_threads=num_threads,
buckets=buckets,
tmp_dir=Path(tmp_dir),
tmp2_dir=Path(tmp2_dir) if tmp2_dir else None,
final_dir=Path(final_dir),
plotid=plotid,
memo=memo,
nobitfield=nobitfield,
)
root_path: Path = ctx.obj["root_path"]
try:
@ -152,7 +151,7 @@ def create_cmd(
)
)
asyncio.run(create_plots(Params(), plot_keys))
asyncio.run(create_plots(params, plot_keys))
if not exclude_final_dir:
try:
add_plot_directory(root_path, final_dir)

View File

@ -181,7 +181,7 @@ async def print_fee_info(node_client: FullNodeRpcClient) -> None:
print("\nFee Rate Estimates:")
max_name_len = max(len(name) for name in target_times_names)
for (n, e) in zip(target_times_names, res["estimates"]):
for n, e in zip(target_times_names, res["estimates"]):
print(f" {n:>{max_name_len}}: {e:.3f} mojo per CLVM cost")
print("")
@ -196,8 +196,7 @@ async def show_async(
) -> None:
from chia.cmds.cmds_util import get_any_service_client
node_client: Optional[FullNodeRpcClient]
async with get_any_service_client("full_node", rpc_port, root_path) as node_config_fp:
async with get_any_service_client(FullNodeRpcClient, rpc_port, root_path) as node_config_fp:
node_client, config, _ = node_config_fp
if node_client is not None:
# Check State
@ -210,7 +209,7 @@ async def show_async(
if block_header_hash_by_height != "":
block_header = await node_client.get_block_record_by_height(block_header_hash_by_height)
if block_header is not None:
print(f"Header hash of block {block_header_hash_by_height}: " f"{block_header.header_hash.hex()}")
print(f"Header hash of block {block_header_hash_by_height}: {block_header.header_hash.hex()}")
else:
print("Block height", block_header_hash_by_height, "not found")
if block_by_header_hash != "":

View File

@ -5,6 +5,7 @@ from typing import Any, Dict, List, Optional, Tuple
import click
from chia.cmds.check_wallet_db import help_text as check_help_text
from chia.cmds.cmds_util import execute_with_wallet
from chia.cmds.coins import coins_cmd
from chia.cmds.plotnft import validate_fee
@ -27,7 +28,7 @@ def wallet_cmd(ctx: click.Context) -> None:
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option("-tx", "--tx_id", help="transaction id to search for", type=str, required=True)
@click.option("--verbose", "-v", count=True, type=int)
@ -48,7 +49,7 @@ def get_transaction_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: in
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option(
"-o",
@ -139,7 +140,7 @@ def get_transactions_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option("-a", "--amount", help="How much chia to send, in XCH", type=str, required=True)
@click.option("-e", "--memo", help="Additional memo for the transaction", type=str, default=None)
@ -178,6 +179,13 @@ def get_transactions_cmd(
multiple=True,
help="Exclude this coin from being spent.",
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the change.",
is_flag=True,
default=False,
)
def send_cmd(
wallet_rpc_port: Optional[int],
fingerprint: int,
@ -190,6 +198,7 @@ def send_cmd(
min_coin_amount: str,
max_coin_amount: str,
coins_to_exclude: Tuple[str],
reuse: bool,
) -> None:
extra_params = {
"id": id,
@ -201,6 +210,7 @@ def send_cmd(
"min_coin_amount": min_coin_amount,
"max_coin_amount": max_coin_amount,
"exclude_coin_ids": list(coins_to_exclude),
"reuse_puzhash": True if reuse else None,
}
import asyncio
@ -217,7 +227,7 @@ def send_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-w",
"--wallet_type",
@ -245,7 +255,7 @@ def show_cmd(wallet_rpc_port: Optional[int], fingerprint: int, wallet_type: Opti
default=None,
)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-n/-l",
"--new-address/--latest-address",
@ -276,7 +286,7 @@ def get_address_cmd(wallet_rpc_port: Optional[int], id, fingerprint: int, new_ad
default=None,
)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, default=1, show_default=True, required=True)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
def delete_unconfirmed_transactions_cmd(wallet_rpc_port: Optional[int], id, fingerprint: int) -> None:
extra_params = {"id": id}
import asyncio
@ -294,7 +304,7 @@ def delete_unconfirmed_transactions_cmd(wallet_rpc_port: Optional[int], id, fing
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
def get_derivation_index_cmd(wallet_rpc_port: Optional[int], fingerprint: int) -> None:
extra_params: Dict[str, Any] = {}
import asyncio
@ -312,7 +322,7 @@ def get_derivation_index_cmd(wallet_rpc_port: Optional[int], fingerprint: int) -
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-a", "--address", help="The address you want to use for signing", type=str, required=True)
@click.option("-m", "--hex_message", help="The hex message you want sign", type=str, required=True)
def address_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, address: str, hex_message: str) -> None:
@ -334,7 +344,7 @@ def address_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, addre
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-i", "--index", help="Index to set. Must be greater than the current derivation index", type=int, required=True
)
@ -390,7 +400,7 @@ def add_token_cmd(wallet_rpc_port: Optional[int], asset_id: str, token_name: str
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option(
"-o",
"--offer",
@ -409,10 +419,29 @@ def add_token_cmd(wallet_rpc_port: Optional[int], asset_id: str, token_name: str
@click.option(
"-m", "--fee", help="A fee to add to the offer when it gets taken, in XCH", default="0", show_default=True
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the offer.",
is_flag=True,
default=False,
)
def make_offer_cmd(
wallet_rpc_port: Optional[int], fingerprint: int, offer: Tuple[str], request: Tuple[str], filepath: str, fee: str
wallet_rpc_port: Optional[int],
fingerprint: int,
offer: Tuple[str],
request: Tuple[str],
filepath: str,
fee: str,
reuse: bool,
) -> None:
extra_params = {"offers": offer, "requests": request, "filepath": filepath, "fee": fee}
extra_params = {
"offers": offer,
"requests": request,
"filepath": filepath,
"fee": fee,
"reuse_puzhash": True if reuse else None,
}
import asyncio
from .wallet_funcs import make_offer
@ -430,7 +459,7 @@ def make_offer_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-id", "--id", help="The ID of the offer that you wish to examine")
@click.option("-p", "--filepath", help="The path to rewrite the offer file to (must be used in conjunction with --id)")
@click.option("-em", "--exclude-my-offers", help="Exclude your own offers from the output", is_flag=True)
@ -476,15 +505,32 @@ def get_offers_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-e", "--examine-only", help="Print the summary of the offer file but do not take it", is_flag=True)
@click.option(
"-m", "--fee", help="The fee to use when pushing the completed offer, in XCH", default="0", show_default=True
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the offer.",
is_flag=True,
default=False,
)
def take_offer_cmd(
path_or_hex: str, wallet_rpc_port: Optional[int], fingerprint: int, examine_only: bool, fee: str
path_or_hex: str,
wallet_rpc_port: Optional[int],
fingerprint: int,
examine_only: bool,
fee: str,
reuse: bool,
) -> None:
extra_params = {"file": path_or_hex, "examine_only": examine_only, "fee": fee}
extra_params = {
"file": path_or_hex,
"examine_only": examine_only,
"fee": fee,
"reuse_puzhash": True if reuse else None,
}
import asyncio
from .wallet_funcs import take_offer
@ -500,7 +546,7 @@ def take_offer_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-id", "--id", help="The offer ID that you wish to cancel", required=True)
@click.option("--insecure", help="Don't make an on-chain transaction, simply mark the offer as cancelled", is_flag=True)
@click.option(
@ -515,6 +561,21 @@ def cancel_offer_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: str,
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, cancel_offer))
@wallet_cmd.command("check", short_help="Check wallet DB integrity", help=check_help_text)
@click.option("-v", "--verbose", help="Print more information", is_flag=True)
@click.option("--db-path", help="The path to a wallet DB. Default is to scan all active wallet DBs.")
@click.pass_context
# TODO: accept multiple dbs on commandline
# TODO: Convert to Path earlier
def check_wallet_cmd(ctx: click.Context, db_path: str, verbose: bool) -> None:
"""check, scan, diagnose, fsck Chia Wallet DBs"""
import asyncio
from chia.cmds.check_wallet_db import scan
asyncio.run(scan(ctx.obj["root_path"], db_path, verbose=verbose))
@wallet_cmd.group("did", short_help="DID related actions")
def did_cmd():
pass
@ -528,7 +589,7 @@ def did_cmd():
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-n", "--name", help="Set the DID wallet name", type=str)
@click.option(
"-a",
@ -566,7 +627,7 @@ def did_create_wallet_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--did_id", help="DID ID you want to use for signing", type=str, required=True)
@click.option("-m", "--hex_message", help="The hex message you want to sign", type=str, required=True)
def did_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, did_id: str, hex_message: str) -> None:
@ -586,7 +647,7 @@ def did_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, did_id: s
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, required=True)
@click.option("-n", "--name", help="Set the DID wallet name", type=str, required=True)
def did_wallet_name_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int, name: str) -> None:
@ -606,7 +667,7 @@ def did_wallet_name_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: in
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the wallet to use", type=int, required=True)
def did_get_did_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> None:
import asyncio
@ -630,7 +691,7 @@ def nft_cmd():
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-di", "--did-id", help="DID Id to use", type=str)
@click.option("-n", "--name", help="Set the NFT wallet name", type=str)
def nft_wallet_create_cmd(
@ -652,7 +713,7 @@ def nft_wallet_create_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--nft_id", help="NFT ID you want to use for signing", type=str, required=True)
@click.option("-m", "--hex_message", help="The hex message you want to sign", type=str, required=True)
def nft_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, nft_id: str, hex_message: str) -> None:
@ -672,7 +733,7 @@ def nft_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, nft_id: s
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the NFT wallet to use", type=int, required=True)
@click.option("-ra", "--royalty-address", help="Royalty address", type=str)
@click.option("-ta", "--target-address", help="Target address", type=str)
@ -702,6 +763,13 @@ def nft_sign_message(wallet_rpc_port: Optional[int], fingerprint: int, nft_id: s
default=0,
show_default=True,
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the change.",
is_flag=True,
default=False,
)
def nft_mint_cmd(
wallet_rpc_port: Optional[int],
fingerprint: int,
@ -719,6 +787,7 @@ def nft_mint_cmd(
edition_number: Optional[int],
fee: str,
royalty_percentage_fraction: int,
reuse: bool,
) -> None:
import asyncio
@ -749,6 +818,7 @@ def nft_mint_cmd(
"edition_number": edition_number,
"fee": fee,
"royalty_percentage": royalty_percentage_fraction,
"reuse_puzhash": True if reuse else None,
}
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, mint_nft))
@ -761,7 +831,7 @@ def nft_mint_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the NFT wallet to use", type=int, required=True)
@click.option("-ni", "--nft-coin-id", help="Id of the NFT coin to add the URI to", type=str, required=True)
@click.option("-u", "--uri", help="URI to add to the NFT", type=str)
@ -776,6 +846,13 @@ def nft_mint_cmd(
show_default=True,
callback=validate_fee,
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the change.",
is_flag=True,
default=False,
)
def nft_add_uri_cmd(
wallet_rpc_port: Optional[int],
fingerprint: int,
@ -785,6 +862,7 @@ def nft_add_uri_cmd(
metadata_uri: str,
license_uri: str,
fee: str,
reuse: bool,
) -> None:
import asyncio
@ -797,6 +875,7 @@ def nft_add_uri_cmd(
"metadata_uri": metadata_uri,
"license_uri": license_uri,
"fee": fee,
"reuse_puzhash": True if reuse else None,
}
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, add_uri_to_nft))
@ -809,7 +888,7 @@ def nft_add_uri_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the NFT wallet to use", type=int, required=True)
@click.option("-ni", "--nft-coin-id", help="Id of the NFT coin to transfer", type=str, required=True)
@click.option("-ta", "--target-address", help="Target recipient wallet address", type=str, required=True)
@ -822,6 +901,13 @@ def nft_add_uri_cmd(
show_default=True,
callback=validate_fee,
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the change.",
is_flag=True,
default=False,
)
def nft_transfer_cmd(
wallet_rpc_port: Optional[int],
fingerprint: int,
@ -829,6 +915,7 @@ def nft_transfer_cmd(
nft_coin_id: str,
target_address: str,
fee: str,
reuse: bool,
) -> None:
import asyncio
@ -839,6 +926,7 @@ def nft_transfer_cmd(
"nft_coin_id": nft_coin_id,
"target_address": target_address,
"fee": fee,
"reuse_puzhash": True if reuse else None,
}
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, transfer_nft))
@ -851,7 +939,7 @@ def nft_transfer_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the NFT wallet to use", type=int, required=True)
def nft_list_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> None:
import asyncio
@ -870,7 +958,7 @@ def nft_list_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> N
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="Id of the NFT wallet to use", type=int, required=True)
@click.option("-di", "--did-id", help="DID Id to set on the NFT", type=str, required=True)
@click.option("-ni", "--nft-coin-id", help="Id of the NFT coin to set the DID on", type=str, required=True)
@ -883,6 +971,13 @@ def nft_list_cmd(wallet_rpc_port: Optional[int], fingerprint: int, id: int) -> N
show_default=True,
callback=validate_fee,
)
@click.option(
"-r",
"--reuse",
help="Reuse existing address for the change.",
is_flag=True,
default=False,
)
def nft_set_did_cmd(
wallet_rpc_port: Optional[int],
fingerprint: int,
@ -890,6 +985,7 @@ def nft_set_did_cmd(
did_id: str,
nft_coin_id: str,
fee: str,
reuse: bool,
) -> None:
import asyncio
@ -900,6 +996,7 @@ def nft_set_did_cmd(
"did_id": did_id,
"nft_coin_id": nft_coin_id,
"fee": fee,
"reuse_puzhash": True if reuse else None,
}
asyncio.run(execute_with_wallet(wallet_rpc_port, fingerprint, extra_params, set_nft_did))
@ -912,7 +1009,7 @@ def nft_set_did_cmd(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-ni", "--nft-coin-id", help="Id of the NFT coin to get information on", type=str, required=True)
def nft_get_info_cmd(
wallet_rpc_port: Optional[int],
@ -946,7 +1043,7 @@ def notification_cmd():
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-t", "--to-address", help="The address to send the notification to", type=str, required=True)
@click.option(
"-a",
@ -990,7 +1087,7 @@ def _send_notification(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="The specific notification ID to show", type=str, default=[], multiple=True)
@click.option("-s", "--start", help="The number of notifications to skip", type=int, default=None)
@click.option("-e", "--end", help="The number of notifications to stop at", type=int, default=None)
@ -1023,7 +1120,7 @@ def _get_notifications(
type=int,
default=None,
)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which wallet to use", type=int)
@click.option("-f", "--fingerprint", help="Set the fingerprint to specify which key to use", type=int)
@click.option("-i", "--id", help="A specific notification ID to delete", type=str, multiple=True)
@click.option("--all", help="All notifications can be deleted (they will be recovered during resync)", is_flag=True)
def _delete_notifications(

View File

@ -199,6 +199,7 @@ async def send(args: dict, wallet_client: WalletRpcClient, fingerprint: int) ->
max_coin_amount = Decimal(args["max_coin_amount"])
exclude_coin_ids: List[str] = args["exclude_coin_ids"]
memo = args["memo"]
reuse_puzhash = args["reuse_puzhash"]
if memo is None:
memos = None
else:
@ -236,6 +237,7 @@ async def send(args: dict, wallet_client: WalletRpcClient, fingerprint: int) ->
final_min_coin_amount,
final_max_coin_amount,
exclude_coin_ids=exclude_coin_ids,
reuse_puzhash=reuse_puzhash,
)
elif typ == WalletType.CAT:
print("Submitting transaction...")
@ -248,6 +250,7 @@ async def send(args: dict, wallet_client: WalletRpcClient, fingerprint: int) ->
final_min_coin_amount,
final_max_coin_amount,
exclude_coin_ids=exclude_coin_ids,
reuse_puzhash=reuse_puzhash,
)
else:
print("Only standard wallet and CAT wallets are supported")
@ -320,6 +323,7 @@ async def make_offer(args: dict, wallet_client: WalletRpcClient, fingerprint: in
requests: List[str] = args["requests"]
filepath: str = args["filepath"]
fee: int = int(Decimal(args["fee"]) * units["chia"])
reuse_puzhash: Optional[bool] = args["reuse_puzhash"]
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
if [] in [offers, requests]:
@ -459,7 +463,7 @@ async def make_offer(args: dict, wallet_client: WalletRpcClient, fingerprint: in
print("Not creating offer...")
else:
offer, trade_record = await wallet_client.create_offer_for_ids(
offer_dict, driver_dict=driver_dict, fee=fee
offer_dict, driver_dict=driver_dict, fee=fee, reuse_puzhash=reuse_puzhash
)
if offer is not None:
with open(pathlib.Path(filepath), "w") as file:
@ -520,7 +524,7 @@ async def print_trade_record(record, wallet_client: WalletRpcClient, summaries:
offer = Offer.from_bytes(record.offer)
offered, requested, _ = offer.summary()
outbound_balances: Dict[str, int] = offer.get_pending_amounts()
fees: Decimal = Decimal(offer.bundle.fees())
fees: Decimal = Decimal(offer.fees())
cat_name_resolver = wallet_client.cat_asset_id_to_name
print(" OFFERED:")
await print_offer_summary(cat_name_resolver, offered)
@ -678,7 +682,7 @@ async def take_offer(args: dict, wallet_client: WalletRpcClient, fingerprint: in
converted_amount = Decimal(amount) / divisor
print(f" - {converted_amount} {asset} ({amount} mojos)")
print(f"Included Fees: {Decimal(offer.bundle.fees()) / units['chia']} XCH, {offer.bundle.fees()} mojos")
print(f"Included Fees: {Decimal(offer.fees()) / units['chia']} XCH, {offer.fees()} mojos")
if not examine_only:
print()
@ -767,7 +771,7 @@ async def print_balances(args: dict, wallet_client: WalletRpcClient, fingerprint
print()
print(f"{summary['name']}:")
print(f"{indent}{'-Total Balance:'.ljust(23)} {total_balance}")
print(f"{indent}{'-Pending Total Balance:'.ljust(23)} " f"{unconfirmed_wallet_balance}")
print(f"{indent}{'-Pending Total Balance:'.ljust(23)} {unconfirmed_wallet_balance}")
print(f"{indent}{'-Spendable:'.ljust(23)} {spendable_balance}")
print(f"{indent}{'-Type:'.ljust(23)} {typ.name}")
if typ == WalletType.DECENTRALIZED_ID:
@ -889,6 +893,7 @@ async def mint_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint: int)
fee,
royalty_percentage,
did_id,
reuse_puzhash=args["reuse_puzhash"],
)
spend_bundle = response["spend_bundle"]
print(f"NFT minted Successfully with spend bundle: {spend_bundle}")
@ -917,7 +922,9 @@ async def add_uri_to_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint
else:
raise ValueError("You must provide at least one of the URI flags")
fee: int = int(Decimal(args["fee"]) * units["chia"])
response = await wallet_client.add_uri_to_nft(wallet_id, nft_coin_id, key, uri_value, fee)
response = await wallet_client.add_uri_to_nft(
wallet_id, nft_coin_id, key, uri_value, fee, args["reuse_puzhash"]
)
spend_bundle = response["spend_bundle"]
print(f"URI added successfully with spend bundle: {spend_bundle}")
except Exception as e:
@ -931,7 +938,9 @@ async def transfer_nft(args: Dict, wallet_client: WalletRpcClient, fingerprint:
config = load_config(DEFAULT_ROOT_PATH, "config.yaml")
target_address = ensure_valid_address(args["target_address"], allowed_types={AddressType.XCH}, config=config)
fee: int = int(Decimal(args["fee"]) * units["chia"])
response = await wallet_client.transfer_nft(wallet_id, nft_coin_id, target_address, fee)
response = await wallet_client.transfer_nft(
wallet_id, nft_coin_id, target_address, fee, reuse_puzhash=args["reuse_puzhash"]
)
spend_bundle = response["spend_bundle"]
print(f"NFT transferred successfully with spend bundle: {spend_bundle}")
except Exception as e:
@ -997,7 +1006,9 @@ async def set_nft_did(args: Dict, wallet_client: WalletRpcClient, fingerprint: i
nft_coin_id = args["nft_coin_id"]
fee: int = int(Decimal(args["fee"]) * units["chia"])
try:
response = await wallet_client.set_nft_did(wallet_id, did_id, nft_coin_id, fee)
response = await wallet_client.set_nft_did(
wallet_id, did_id, nft_coin_id, fee, reuse_puzhash=args["reuse_puzhash"]
)
spend_bundle = response["spend_bundle"]
print(f"Transaction to set DID on NFT has been initiated with: {spend_bundle}")
except Exception as e:

View File

@ -59,6 +59,7 @@ async def validate_block_body(
if isinstance(block, FullBlock):
assert height == block.height
prev_transaction_block_height: uint32 = uint32(0)
prev_transaction_block_timestamp: uint64 = uint64(0)
# 1. For non transaction-blocs: foliage block, transaction filter, transactions info, and generator must
# be empty. If it is a block but not a transaction block, there is no body to validate. Check that all fields are
@ -103,6 +104,8 @@ async def validate_block_body(
# Add reward claims for all blocks from the prev prev block, until the prev block (including the latter)
prev_transaction_block = blocks.block_record(block.foliage_transaction_block.prev_transaction_block_hash)
prev_transaction_block_height = prev_transaction_block.height
assert prev_transaction_block.timestamp
prev_transaction_block_timestamp = prev_transaction_block.timestamp
assert prev_transaction_block.fees is not None
pool_coin = create_pool_coin(
prev_transaction_block_height,
@ -316,9 +319,9 @@ async def validate_block_body(
curr_npc_result = get_name_puzzle_conditions(
curr_block_generator,
min(constants.MAX_BLOCK_COST_CLVM, curr.transactions_info.cost),
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=False,
height=curr.height,
constants=constants,
)
removals_in_curr, additions_in_curr = tx_removals_and_additions(curr_npc_result.conds)
else:
@ -456,11 +459,18 @@ async def validate_block_body(
# verify absolute/relative height/time conditions
if npc_result is not None:
assert npc_result.conds is not None
block_timestamp: uint64
if height < constants.SOFT_FORK2_HEIGHT:
block_timestamp = block.foliage_transaction_block.timestamp
else:
block_timestamp = prev_transaction_block_timestamp
error = mempool_check_time_locks(
removal_coin_records,
npc_result.conds,
prev_transaction_block_height,
block.foliage_transaction_block.timestamp,
block_timestamp,
)
if error:
return error, None
@ -470,7 +480,9 @@ async def validate_block_body(
pairs_msgs: List[bytes] = []
if npc_result:
assert npc_result.conds is not None
pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
pairs_pks, pairs_msgs = pkm_pairs(
npc_result.conds, constants.AGG_SIG_ME_ADDITIONAL_DATA, soft_fork=height >= constants.SOFT_FORK_HEIGHT
)
# 22. Verify aggregated signature
# TODO: move this to pre_validate_blocks_multiprocessing so we can sync faster

View File

@ -131,10 +131,7 @@ def create_foliage(
if block_generator is not None:
generator_block_heights_list = block_generator.block_height_list
result: NPCResult = get_name_puzzle_conditions(
block_generator,
constants.MAX_BLOCK_COST_CLVM,
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=True,
block_generator, constants.MAX_BLOCK_COST_CLVM, mempool_mode=True, height=height
)
cost = result.cost
@ -388,7 +385,7 @@ def create_unfinished_block(
additions = []
if removals is None:
removals = []
(foliage, foliage_transaction_block, transactions_info,) = create_foliage(
(foliage, foliage_transaction_block, transactions_info) = create_foliage(
constants,
rc_block,
block_generator,

View File

@ -894,7 +894,7 @@ def validate_finished_header_block(
# 27b. Check genesis block height, weight, and prev block hash
if header_block.height != uint32(0):
return None, ValidationError(Err.INVALID_HEIGHT)
if header_block.weight != constants.DIFFICULTY_STARTING:
if header_block.weight != uint128(constants.DIFFICULTY_STARTING):
return None, ValidationError(Err.INVALID_WEIGHT)
if header_block.prev_header_hash != constants.GENESIS_CHALLENGE:
return None, ValidationError(Err.INVALID_PREV_BLOCK_HASH)

View File

@ -3,6 +3,8 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional
from typing_extensions import Protocol
from chia.consensus.constants import ConsensusConstants
from chia.consensus.pot_iterations import calculate_ip_iters, calculate_sp_iters
from chia.types.blockchain_format.classgroup import ClassgroupElement
@ -13,6 +15,32 @@ from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.streamable import Streamable, streamable
class BlockRecordProtocol(Protocol):
@property
def header_hash(self) -> bytes32:
...
@property
def height(self) -> uint32:
...
@property
def timestamp(self) -> Optional[uint64]:
...
@property
def prev_transaction_block_height(self) -> uint32:
...
@property
def prev_transaction_block_hash(self) -> Optional[bytes32]:
...
@property
def is_transaction_block(self) -> bool:
return self.timestamp is not None
@streamable
@dataclass(frozen=True)
class BlockRecord(Streamable):

View File

@ -54,9 +54,9 @@ from chia.util.setproctitle import getproctitle, setproctitle
log = logging.getLogger(__name__)
class ReceiveBlockResult(Enum):
class AddBlockResult(Enum):
"""
When Blockchain.receive_block(b) is called, one of these results is returned,
When Blockchain.add_block(b) is called, one of these results is returned,
showing whether the block was added to the chain (extending the peak),
and if not, why it was not added.
"""
@ -193,12 +193,12 @@ class Blockchain(BlockchainInterface):
async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]:
return await self.block_store.get_full_block(header_hash)
async def receive_block(
async def add_block(
self,
block: FullBlock,
pre_validation_result: PreValidationResult,
fork_point_with_peak: Optional[uint32] = None,
) -> Tuple[ReceiveBlockResult, Optional[Err], Optional[StateChangeSummary]]:
) -> Tuple[AddBlockResult, Optional[Err], Optional[StateChangeSummary]]:
"""
This method must be called under the blockchain lock
Adds a new block into the blockchain, if it's valid and connected to the current
@ -223,18 +223,18 @@ class Blockchain(BlockchainInterface):
genesis: bool = block.height == 0
if self.contains_block(block.header_hash):
return ReceiveBlockResult.ALREADY_HAVE_BLOCK, None, None
return AddBlockResult.ALREADY_HAVE_BLOCK, None, None
if not self.contains_block(block.prev_header_hash) and not genesis:
return ReceiveBlockResult.DISCONNECTED_BLOCK, Err.INVALID_PREV_BLOCK_HASH, None
return AddBlockResult.DISCONNECTED_BLOCK, Err.INVALID_PREV_BLOCK_HASH, None
if not genesis and (self.block_record(block.prev_header_hash).height + 1) != block.height:
return ReceiveBlockResult.INVALID_BLOCK, Err.INVALID_HEIGHT, None
return AddBlockResult.INVALID_BLOCK, Err.INVALID_HEIGHT, None
npc_result: Optional[NPCResult] = pre_validation_result.npc_result
required_iters = pre_validation_result.required_iters
if pre_validation_result.error is not None:
return ReceiveBlockResult.INVALID_BLOCK, Err(pre_validation_result.error), None
return AddBlockResult.INVALID_BLOCK, Err(pre_validation_result.error), None
assert required_iters is not None
error_code, _ = await validate_block_body(
@ -252,7 +252,7 @@ class Blockchain(BlockchainInterface):
validate_signature=not pre_validation_result.validated_signature,
)
if error_code is not None:
return ReceiveBlockResult.INVALID_BLOCK, error_code, None
return AddBlockResult.INVALID_BLOCK, error_code, None
block_record = block_to_block_record(
self.constants,
@ -300,9 +300,9 @@ class Blockchain(BlockchainInterface):
if state_change_summary is not None:
# new coin records added
return ReceiveBlockResult.NEW_PEAK, None, state_change_summary
return AddBlockResult.NEW_PEAK, None, state_change_summary
else:
return ReceiveBlockResult.ADDED_AS_ORPHAN, None, None
return AddBlockResult.ADDED_AS_ORPHAN, None, None
async def _reconsider_peak(
self,
@ -428,7 +428,6 @@ class Blockchain(BlockchainInterface):
async def get_tx_removals_and_additions(
self, block: FullBlock, npc_result: Optional[NPCResult] = None
) -> Tuple[List[bytes32], List[Coin], Optional[NPCResult]]:
if not block.is_transaction_block():
return [], [], None
@ -439,11 +438,7 @@ class Blockchain(BlockchainInterface):
block_generator: Optional[BlockGenerator] = await self.get_block_generator(block)
assert block_generator is not None
npc_result = get_name_puzzle_conditions(
block_generator,
self.constants.MAX_BLOCK_COST_CLVM,
cost_per_byte=self.constants.COST_PER_BYTE,
mempool_mode=False,
height=block.height,
block_generator, self.constants.MAX_BLOCK_COST_CLVM, mempool_mode=False, height=block.height
)
tx_removals, tx_additions = tx_removals_and_additions(npc_result.conds)
return tx_removals, tx_additions, npc_result
@ -540,6 +535,9 @@ class Blockchain(BlockchainInterface):
async def validate_unfinished_block_header(
self, block: UnfinishedBlock, skip_overflow_ss_validation: bool = True
) -> Tuple[Optional[uint64], Optional[Err]]:
if len(block.transactions_generator_ref_list) > self.constants.MAX_GENERATOR_REF_LIST_SIZE:
return None, Err.TOO_MANY_GENERATOR_REFS
if (
not self.contains_block(block.prev_header_hash)
and block.prev_header_hash != self.constants.GENESIS_CHALLENGE

View File

@ -61,8 +61,12 @@ class ConsensusConstants:
MAX_GENERATOR_SIZE: uint32
MAX_GENERATOR_REF_LIST_SIZE: uint32
POOL_SUB_SLOT_ITERS: uint64
# soft fork initiated in 1.7.0 release
SOFT_FORK_HEIGHT: uint32
# soft fork initiated in 1.8.0 release
SOFT_FORK2_HEIGHT: uint32
def replace(self, **changes: object) -> "ConsensusConstants":
return dataclasses.replace(self, **changes)

View File

@ -56,6 +56,7 @@ default_kwargs = {
"MAX_GENERATOR_REF_LIST_SIZE": 512, # Number of references allowed in the block generator ref list
"POOL_SUB_SLOT_ITERS": 37600000000, # iters limit * NUM_SPS
"SOFT_FORK_HEIGHT": 3630000,
"SOFT_FORK2_HEIGHT": 4000000,
}

View File

@ -27,7 +27,6 @@ def block_to_block_record(
header_block: Optional[HeaderBlock],
sub_slot_iters: Optional[uint64] = None,
) -> BlockRecord:
if full_block is None:
assert header_block is not None
block: Union[HeaderBlock, FullBlock] = header_block
@ -99,7 +98,6 @@ def header_block_to_sub_block_record(
prev_transaction_block_height: uint32,
ses: Optional[SubEpochSummary],
) -> BlockRecord:
reward_claims_incorporated = (
block.transactions_info.reward_claims_incorporated if block.transactions_info is not None else None
)

View File

@ -91,9 +91,9 @@ def batch_pre_validate_blocks(
npc_result = get_name_puzzle_conditions(
block_generator,
min(constants.MAX_BLOCK_COST_CLVM, block.transactions_info.cost),
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=False,
height=block.height,
constants=constants,
)
removals, tx_additions = tx_removals_and_additions(npc_result.conds)
if npc_result is not None and npc_result.error is not None:
@ -116,14 +116,17 @@ def batch_pre_validate_blocks(
successfully_validated_signatures = False
# If we failed CLVM, no need to validate signature, the block is already invalid
if error_int is None:
# If this is False, it means either we don't have a signature (not a tx block) or we have an invalid
# signature (which also puts in an error) or we didn't validate the signature because we want to
# validate it later. receive_block will attempt to validate the signature later.
# validate it later. add_block will attempt to validate the signature later.
if validate_signatures:
if npc_result is not None and block.transactions_info is not None:
assert npc_result.conds
pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, constants.AGG_SIG_ME_ADDITIONAL_DATA)
pairs_pks, pairs_msgs = pkm_pairs(
npc_result.conds,
constants.AGG_SIG_ME_ADDITIONAL_DATA,
soft_fork=block.height >= constants.SOFT_FORK_HEIGHT,
)
# Using AugSchemeMPL.aggregate_verify, so it's safe to use from_bytes_unchecked
pks_objects: List[G1Element] = [G1Element.from_bytes_unchecked(pk) for pk in pairs_pks]
if not AugSchemeMPL.aggregate_verify(
@ -383,7 +386,6 @@ def _run_generator(
npc_result: NPCResult = get_name_puzzle_conditions(
block_generator,
min(constants.MAX_BLOCK_COST_CLVM, unfinished_block.transactions_info.cost),
cost_per_byte=constants.COST_PER_BYTE,
mempool_mode=False,
height=height,
)

View File

@ -1,7 +1,7 @@
from __future__ import annotations
import logging
from dataclasses import dataclass
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List, Optional, Type
@ -95,23 +95,22 @@ class DeleteLabelRequest(Streamable):
return EmptyResponse()
@dataclass
class KeychainServer:
"""
Implements a remote keychain service for clients to perform key operations on
"""
def __init__(self):
self._default_keychain = Keychain()
self._alt_keychains = {}
_default_keychain: Keychain = field(default_factory=Keychain)
_alt_keychains: Dict[str, Keychain] = field(default_factory=dict)
def get_keychain_for_request(self, request: Dict[str, Any]):
def get_keychain_for_request(self, request: Dict[str, Any]) -> Keychain:
"""
Keychain instances can have user and service strings associated with them.
The keychain backends ultimately point to the same data stores, but the user
and service strings are used to partition those data stores. We attempt to
maintain a mapping of user/service pairs to their corresponding Keychain.
"""
keychain = None
user = request.get("kc_user", self._default_keychain.user)
service = request.get("kc_service", self._default_keychain.service)
if user == self._default_keychain.user and service == self._default_keychain.service:

View File

@ -13,9 +13,10 @@ import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional, TextIO, Tuple
from typing import Any, AsyncIterator, Dict, List, Optional, Set, TextIO, Tuple
from chia import __version__
from chia.cmds.init_funcs import check_keys, chia_full_version_str, chia_init
@ -128,15 +129,13 @@ class WebSocketServer:
ca_key_path: Path,
crt_path: Path,
key_path: Path,
shutdown_event: asyncio.Event,
run_check_keys_on_unlock: bool = False,
):
self.root_path = root_path
self.log = log
self.services: Dict = dict()
self.services: Dict[str, List[subprocess.Popen]] = dict()
self.plots_queue: List[Dict] = []
self.connections: Dict[str, List[WebSocketResponse]] = dict() # service_name : [WebSocket]
self.remote_address_map: Dict[WebSocketResponse, str] = dict() # socket: service_name
self.connections: Dict[str, Set[WebSocketResponse]] = dict() # service name : {WebSocketResponse}
self.ping_job: Optional[asyncio.Task] = None
self.net_config = load_config(root_path, "config.yaml")
self.self_hostname = self.net_config["self_hostname"]
@ -147,9 +146,10 @@ class WebSocketServer:
self.ssl_context = ssl_context_for_server(ca_crt_path, ca_key_path, crt_path, key_path, log=self.log)
self.keychain_server = KeychainServer()
self.run_check_keys_on_unlock = run_check_keys_on_unlock
self.shutdown_event = shutdown_event
self.shutdown_event = asyncio.Event()
async def start(self) -> None:
@asynccontextmanager
async def run(self) -> AsyncIterator[None]:
self.log.info("Starting Daemon Server")
# Note: the minimum_version has been already set to TLSv1_2
@ -180,6 +180,12 @@ class WebSocketServer:
ssl_context=self.ssl_context,
logger=self.log,
)
try:
yield
finally:
if not self.shutdown_event.is_set():
await self.stop()
await self.exit()
async def setup_process_global_state(self) -> None:
try:
@ -213,11 +219,11 @@ class WebSocketServer:
if stop_service_jobs:
await asyncio.wait(stop_service_jobs)
self.services.clear()
asyncio.create_task(self.exit())
self.shutdown_event.set()
log.info(f"Daemon Server stopping, Services stopped: {service_names}")
return {"success": True, "services_stopped": service_names}
async def incoming_connection(self, request):
async def incoming_connection(self, request: web.Request) -> web.StreamResponse:
ws: WebSocketResponse = web.WebSocketResponse(
max_msg_size=self.daemon_max_message_size, heartbeat=self.heartbeat
)
@ -226,79 +232,106 @@ class WebSocketServer:
while True:
msg = await ws.receive()
self.log.debug("Received message: %s", msg)
decoded: WsRpcMessage = {
"command": "",
"ack": False,
"data": {},
"request_id": "",
"destination": "",
"origin": "",
}
if msg.type == WSMsgType.TEXT:
try:
decoded = json.loads(msg.data)
if "data" not in decoded:
decoded["data"] = {}
response, sockets_to_use = await self.handle_message(ws, decoded)
maybe_response = await self.handle_message(ws, decoded)
if maybe_response is None:
continue
response, connections = maybe_response
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Error while handling message: {tb}")
error = {"success": False, "error": f"{e}"}
response = format_response(decoded, error)
sockets_to_use = []
if len(sockets_to_use) > 0:
for socket in sockets_to_use:
try:
await socket.send_str(response)
except Exception as e:
tb = traceback.format_exc()
self.log.error(f"Unexpected exception trying to send to websocket: {e} {tb}")
self.remove_connection(socket)
await socket.close()
else:
service_name = "Unknown"
if ws in self.remote_address_map:
service_name = self.remote_address_map[ws]
if msg.type == WSMsgType.CLOSE:
self.log.info(f"ConnectionClosed. Closing websocket with {service_name}")
elif msg.type == WSMsgType.ERROR:
self.log.info(f"Websocket exception. Closing websocket with {service_name}. {ws.exception()}")
connections = {ws} # send error back to the sender
await self.send_all_responses(connections, response)
else:
service_names = self.remove_connection(ws)
if len(service_names) == 0:
service_names = ["Unknown"]
if msg.type == WSMsgType.CLOSE:
self.log.info(f"ConnectionClosed. Closing websocket with {service_names}")
elif msg.type == WSMsgType.ERROR:
self.log.info(f"Websocket exception. Closing websocket with {service_names}. {ws.exception()}")
else:
self.log.info(f"Unexpected message type. Closing websocket with {service_names}. {msg.type}")
self.remove_connection(ws)
await ws.close()
break
def remove_connection(self, websocket: WebSocketResponse):
service_name = None
if websocket in self.remote_address_map:
service_name = self.remote_address_map[websocket]
self.remote_address_map.pop(websocket)
if service_name in self.connections:
after_removal = []
for connection in self.connections[service_name]:
if connection == websocket:
continue
return ws
async def send_all_responses(self, connections: Set[WebSocketResponse], response: str) -> None:
for connection in connections.copy():
try:
await connection.send_str(response)
except Exception as e:
service_names = self.remove_connection(connection)
if len(service_names) == 0:
service_names = ["Unknown"]
if isinstance(e, ConnectionResetError):
self.log.info(f"Peer disconnected. Closing websocket with {service_names}")
else:
after_removal.append(connection)
self.connections[service_name] = after_removal
tb = traceback.format_exc()
self.log.error(f"Unexpected exception trying to send to {service_names} (websocket: {e} {tb})")
self.log.info(f"Closing websocket with {service_names}")
await connection.close()
def remove_connection(self, websocket: WebSocketResponse) -> List[str]:
"""Returns a list of service names from which the connection was removed"""
service_names = []
for service_name, connections in self.connections.items():
try:
connections.remove(websocket)
except KeyError:
continue
service_names.append(service_name)
return service_names
async def ping_task(self) -> None:
restart = True
await asyncio.sleep(30)
for remote_address, service_name in self.remote_address_map.items():
if service_name in self.connections:
sockets = self.connections[service_name]
for socket in sockets:
try:
self.log.debug(f"About to ping: {service_name}")
await socket.ping()
except asyncio.CancelledError:
self.log.warning("Ping task received Cancel")
restart = False
break
except Exception:
self.log.exception("Ping error")
self.log.error("Ping failed, connection closed.")
self.remove_connection(socket)
await socket.close()
for service_name, connections in self.connections.items():
if service_name == service_plotter:
continue
for connection in connections:
try:
self.log.debug(f"About to ping: {service_name}")
await connection.ping()
except asyncio.CancelledError:
self.log.warning("Ping task received Cancel")
restart = False
break
except Exception:
self.log.exception(f"Ping error to {service_name}")
self.log.error(f"Ping failed, connection closed to {service_name}.")
self.remove_connection(connection)
await connection.close()
if restart is True:
self.ping_job = asyncio.create_task(self.ping_task())
async def handle_message(
self, websocket: WebSocketResponse, message: WsRpcMessage
) -> Tuple[Optional[str], List[Any]]:
) -> Optional[Tuple[str, Set[WebSocketResponse]]]:
"""
This function gets called when new message is received via websocket.
"""
@ -310,7 +343,7 @@ class WebSocketServer:
sockets = self.connections[destination]
return dict_to_json_str(message), sockets
return None, []
return None
data = message["data"]
commands_with_data = [
@ -367,7 +400,7 @@ class WebSocketServer:
response = {"success": False, "error": f"unknown_command {command}"}
full_response = format_response(message, response)
return full_response, [websocket]
return full_response, {websocket}
async def is_keyring_locked(self) -> Dict[str, Any]:
locked: bool = Keychain.is_keyring_locked()
@ -1020,7 +1053,7 @@ class WebSocketServer:
run_next = True
config["state"] = PlotState.REMOVING
self.state_changed(service_plotter, self.prepare_plot_state_message(PlotEvent.STATE_CHANGED, id))
await kill_process(process, self.root_path, service_plotter, id)
await kill_processes([process], self.root_path, service_plotter, id)
config["state"] = PlotState.FINISHED
config["deleted"] = True
@ -1056,9 +1089,8 @@ class WebSocketServer:
error = "unknown service"
if service_command in self.services:
service = self.services[service_command]
r = service is not None and service.poll() is None
if r is False:
processes = self.services[service_command]
if all(process.poll() is not None for process in processes):
self.services.pop(service_command)
error = None
else:
@ -1078,7 +1110,7 @@ class WebSocketServer:
if testing is True:
exe_command = f"{service_command} --testing=true"
process, pid_path = launch_service(self.root_path, exe_command)
self.services[service_command] = process
self.services[service_command] = [process]
success = True
except (subprocess.SubprocessError, IOError):
log.exception(f"problem starting {service_command}")
@ -1094,12 +1126,13 @@ class WebSocketServer:
return response
def is_service_running(self, service_name: str) -> bool:
processes: List[subprocess.Popen]
if service_name == service_plotter:
processes = self.services.get(service_name)
is_running = processes is not None and len(processes) > 0
processes = self.services.get(service_name, [])
is_running = len(processes) > 0
else:
process = self.services.get(service_name)
is_running = process is not None and process.poll() is None
processes = self.services.get(service_name, [])
is_running = any(process.poll() is None for process in processes)
if not is_running:
# Check if we have a connection to the requested service. This might be the
# case if the service was started manually (i.e. not started by the daemon).
@ -1123,7 +1156,6 @@ class WebSocketServer:
if self.webserver is not None:
self.webserver.close()
await self.webserver.await_closed()
self.shutdown_event.set()
log.info("chia daemon exiting")
async def register_service(self, websocket: WebSocketResponse, request: Dict[str, Any]) -> Dict[str, Any]:
@ -1133,8 +1165,8 @@ class WebSocketServer:
self.log.error("Service Name missing from request to 'register_service'")
return {"success": False}
if service not in self.connections:
self.connections[service] = []
self.connections[service].append(websocket)
self.connections[service] = set()
self.connections[service].add(websocket)
response: Dict[str, Any] = {"success": True}
if service == service_plotter:
@ -1144,7 +1176,6 @@ class WebSocketServer:
"queue": self.extract_plot_queue(),
}
else:
self.remote_address_map[websocket] = service
if self.ping_job is None:
self.ping_job = asyncio.create_task(self.ping_task())
self.log.info(f"registered for service {service}")
@ -1266,31 +1297,39 @@ def launch_service(root_path: Path, service_command) -> Tuple[subprocess.Popen,
return process, pid_path
async def kill_process(
process: subprocess.Popen, root_path: Path, service_name: str, id: str, delay_before_kill: int = 15
async def kill_processes(
processes: List[subprocess.Popen],
root_path: Path,
service_name: str,
id: str,
delay_before_kill: int = 15,
) -> bool:
pid_path = pid_path_for_service(root_path, service_name, id)
if sys.platform == "win32" or sys.platform == "cygwin":
log.info("sending CTRL_BREAK_EVENT signal to %s", service_name)
# pylint: disable=E1101
kill(process.pid, signal.SIGBREAK)
for process in processes:
kill(process.pid, signal.SIGBREAK)
else:
log.info("sending term signal to %s", service_name)
process.terminate()
for process in processes:
process.terminate()
count: float = 0
while count < delay_before_kill:
if process.poll() is not None:
if all(process.poll() is not None for process in processes):
break
await asyncio.sleep(0.5)
count += 0.5
else:
process.kill()
for process in processes:
process.kill()
log.info("sending kill signal to %s", service_name)
r = process.wait()
log.info("process %s returned %d", service_name, r)
for process in processes:
r = process.wait()
log.info("process %s returned %d", service_name, r)
try:
pid_path_killed = pid_path.with_suffix(".pid-killed")
if pid_path_killed.exists():
@ -1303,13 +1342,13 @@ async def kill_process(
async def kill_service(
root_path: Path, services: Dict[str, subprocess.Popen], service_name: str, delay_before_kill: int = 15
root_path: Path, services: Dict[str, List[subprocess.Popen]], service_name: str, delay_before_kill: int = 15
) -> bool:
process = services.get(service_name)
if process is None:
processes = services.get(service_name)
if processes is None:
return False
del services[service_name]
result = await kill_process(process, root_path, service_name, "", delay_before_kill)
result = await kill_processes(processes, root_path, service_name, "", delay_before_kill)
return result
@ -1350,20 +1389,17 @@ async def async_run_daemon(root_path: Path, wait_for_unlock: bool = False) -> in
beta_metrics = BetaMetricsLogger(root_path)
beta_metrics.start_logging()
shutdown_event = asyncio.Event()
ws_server = WebSocketServer(
root_path,
ca_crt_path,
ca_key_path,
crt_path,
key_path,
shutdown_event,
run_check_keys_on_unlock=wait_for_unlock,
)
await ws_server.setup_process_global_state()
await ws_server.start()
await shutdown_event.wait()
async with ws_server.run():
await ws_server.shutdown_event.wait()
if beta_metrics is not None:
await beta_metrics.stop_logging()

View File

@ -117,7 +117,6 @@ class DataLayerServer:
async def async_start(root_path: Path) -> int:
shutdown_event = asyncio.Event()
dl_config = load_config(

View File

@ -78,7 +78,7 @@ async def _dot_dump(data_store: DataStore, store_id: bytes32, root_hash: bytes32
dot_connections.append(f"""node_{hash} -> node_{left} [label="L"];""")
dot_connections.append(f"""node_{hash} -> node_{right} [label="R"];""")
dot_pair_boxes.append(
f"node [shape = box]; " f"{{rank = same; node_{left}->node_{right}[style=invis]; rankdir = LR}}"
f"node [shape = box]; {{rank = same; node_{left}->node_{right}[style=invis]; rankdir = LR}}"
)
lines = [

View File

@ -4,10 +4,11 @@ import dataclasses
import logging
import time
from operator import attrgetter
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, TypeVar
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from blspy import G1Element, G2Element
from clvm.EvalError import EvalError
from typing_extensions import final
from chia.consensus.block_record import BlockRecord
from chia.data_layer.data_layer_errors import OfferIntegrityError
@ -100,9 +101,7 @@ class Mirror:
)
_T_DataLayerWallet = TypeVar("_T_DataLayerWallet", bound="DataLayerWallet")
@final
class DataLayerWallet:
wallet_state_manager: WalletStateManager
log: logging.Logger
@ -115,57 +114,47 @@ class DataLayerWallet:
@classmethod
async def create(
cls: Type[_T_DataLayerWallet],
cls,
wallet_state_manager: WalletStateManager,
wallet: Wallet,
wallet_info: WalletInfo,
name: Optional[str] = None,
) -> _T_DataLayerWallet:
) -> DataLayerWallet:
self = cls()
self.wallet_state_manager = wallet_state_manager
self.log = logging.getLogger(name if name else __name__)
self.standard_wallet = wallet
self.log = logging.getLogger(__name__)
self.standard_wallet = wallet_state_manager.main_wallet
self.wallet_info = wallet_info
self.wallet_id = uint8(self.wallet_info.id)
return self
@classmethod
def type(cls) -> uint8:
return uint8(WalletType.DATA_LAYER)
def type(cls) -> WalletType:
return WalletType.DATA_LAYER
def id(self) -> uint32:
return self.wallet_info.id
@classmethod
async def create_new_dl_wallet(
cls: Type[_T_DataLayerWallet],
wallet_state_manager: WalletStateManager,
wallet: Wallet,
name: Optional[str] = "DataLayer Wallet",
) -> _T_DataLayerWallet:
async def create_new_dl_wallet(cls, wallet_state_manager: WalletStateManager) -> DataLayerWallet:
"""
This must be called under the wallet state manager lock
"""
self = cls()
self.wallet_state_manager = wallet_state_manager
self.log = logging.getLogger(name if name else __name__)
self.standard_wallet = wallet
self.log = logging.getLogger(__name__)
self.standard_wallet = wallet_state_manager.main_wallet
for _, w in self.wallet_state_manager.wallets.items():
if w.type() == uint8(WalletType.DATA_LAYER):
if w.type() == WalletType.DATA_LAYER:
raise ValueError("DataLayer Wallet already exists for this key")
assert name is not None
self.wallet_info = await wallet_state_manager.user_store.create_wallet(
name,
"DataLayer Wallet",
WalletType.DATA_LAYER.value,
"",
)
self.wallet_id = uint8(self.wallet_info.id)
await self.wallet_state_manager.add_new_wallet(self, self.wallet_info.id)
await self.wallet_state_manager.add_new_wallet(self)
return self
@ -613,10 +602,12 @@ class DataLayerWallet:
type=uint32(TransactionType.OUTGOING_TX.value),
name=singleton_record.coin_id,
)
assert dl_tx.spend_bundle is not None
if fee > 0:
chia_tx = await self.create_tandem_xch_tx(
fee, Announcement(current_coin.name(), b"$"), coin_announcement=True
)
assert chia_tx.spend_bundle is not None
aggregate_bundle = SpendBundle.aggregate([dl_tx.spend_bundle, chia_tx.spend_bundle])
dl_tx = dataclasses.replace(dl_tx, spend_bundle=aggregate_bundle)
chia_tx = dataclasses.replace(chia_tx, spend_bundle=None)
@ -818,6 +809,8 @@ class DataLayerWallet:
fee=uint64(excess_fee),
coin_announcements_to_consume={Announcement(mirror_coin.name(), b"$")},
)
assert txs[0].spend_bundle is not None
assert chia_tx.spend_bundle is not None
txs = [
dataclasses.replace(
txs[0], spend_bundle=SpendBundle.aggregate([txs[0].spend_bundle, chia_tx.spend_bundle])
@ -1216,8 +1209,8 @@ class DataLayerWallet:
# Build a mapping of launcher IDs to their new innerpuz
singleton_to_innerpuzhash: Dict[bytes32, bytes32] = {}
singleton_to_root: Dict[bytes32, bytes32] = {}
all_parent_ids: List[bytes32] = [cs.coin.parent_coin_info for cs in offer.bundle.coin_spends]
for spend in offer.bundle.coin_spends:
all_parent_ids: List[bytes32] = [cs.coin.parent_coin_info for cs in offer.coin_spends()]
for spend in offer.coin_spends():
matched, curried_args = match_dl_singleton(spend.puzzle_reveal.to_program())
if matched and spend.coin.name() not in all_parent_ids:
innerpuz, root, launcher_id = curried_args
@ -1227,7 +1220,7 @@ class DataLayerWallet:
# Create all of the new solutions
new_spends: List[CoinSpend] = []
for spend in offer.bundle.coin_spends:
for spend in offer.coin_spends():
solution = spend.solution.to_program()
if match_dl_singleton(spend.puzzle_reveal.to_program())[0]:
try:
@ -1283,12 +1276,12 @@ class DataLayerWallet:
spend = new_spend
new_spends.append(spend)
return Offer({}, SpendBundle(new_spends, offer.bundle.aggregated_signature), offer.driver_dict, offer.old)
return Offer({}, SpendBundle(new_spends, offer.aggregated_signature()), offer.driver_dict, offer.old)
@staticmethod
async def get_offer_summary(offer: Offer) -> Dict[str, Any]:
summary: Dict[str, Any] = {"offered": []}
for spend in offer.bundle.coin_spends:
for spend in offer.coin_spends():
solution = spend.solution.to_program()
matched, curried_args = match_dl_singleton(spend.puzzle_reveal.to_program())
if matched:
@ -1299,7 +1292,7 @@ class DataLayerWallet:
mod, graftroot_curried_args = graftroot.uncurry()
if mod == GRAFTROOT_DL_OFFERS:
child_spend: CoinSpend = next(
cs for cs in offer.bundle.coin_spends if cs.coin.parent_coin_info == spend.coin.name()
cs for cs in offer.coin_spends() if cs.coin.parent_coin_info == spend.coin.name()
)
_, child_curried_args = match_dl_singleton(child_spend.puzzle_reveal.to_program())
singleton_summary = {

View File

@ -172,9 +172,11 @@ class DataStore:
async with self.db_wrapper.writer() as writer:
if generation is None:
existing_generation = await self.get_tree_generation(tree_id=tree_id)
if existing_generation is None:
try:
existing_generation = await self.get_tree_generation(tree_id=tree_id)
except Exception as e:
if not str(e).startswith("No generations found for tree ID:"):
raise
generation = 0
else:
generation = existing_generation + 1
@ -456,10 +458,13 @@ class DataStore:
)
row = await cursor.fetchone()
if row is None:
raise Exception(f"No generations found for tree ID: {tree_id.hex()}")
generation: int = row["MAX(generation)"]
return generation
if row is not None:
generation: Optional[int] = row["MAX(generation)"]
if generation is not None:
return generation
raise Exception(f"No generations found for tree ID: {tree_id.hex()}")
async def get_tree_root(self, tree_id: bytes32, generation: Optional[int] = None) -> Root:
async with self.db_wrapper.reader() as reader:

View File

@ -31,7 +31,7 @@ from chia.protocols.pool_protocol import (
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.rpc.rpc_server import StateChangedProtocol, default_get_connections
from chia.server.outbound_message import NodeType, make_msg
from chia.server.server import ssl_context_for_root
from chia.server.server import ChiaServer, ssl_context_for_root
from chia.server.ws_connection import WSChiaConnection
from chia.ssl.create_ssl import get_mozilla_ca_crt
from chia.types.blockchain_format.proof_of_space import ProofOfSpace
@ -41,7 +41,7 @@ from chia.util.byte_types import hexstr_to_bytes
from chia.util.config import config_path_for_filename, load_config, lock_and_load_config, save_config
from chia.util.errors import KeychainProxyConnectionFailure
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64
from chia.util.ints import uint8, uint16, uint64
from chia.util.keychain import Keychain
from chia.util.logging import TimedDuplicateFilter
from chia.wallet.derive_keys import (
@ -70,8 +70,8 @@ class Farmer:
def __init__(
self,
root_path: Path,
farmer_config: Dict,
pool_config: Dict,
farmer_config: Dict[str, Any],
pool_config: Dict[str, Any],
consensus_constants: ConsensusConstants,
local_keychain: Optional[Keychain] = None,
):
@ -98,8 +98,8 @@ class Farmer:
self.plot_sync_receivers: Dict[bytes32, Receiver] = {}
self.cache_clear_task: Optional[asyncio.Task] = None
self.update_pool_state_task: Optional[asyncio.Task] = None
self.cache_clear_task: Optional[asyncio.Task[None]] = None
self.update_pool_state_task: Optional[asyncio.Task[None]] = None
self.constants = consensus_constants
self._shut_down = False
self.server: Any = None
@ -109,16 +109,18 @@ class Farmer:
self.log.addFilter(TimedDuplicateFilter("No pool specific difficulty has been set.*", 60 * 10))
self.started = False
self.harvester_handshake_task: Optional[asyncio.Task] = None
self.harvester_handshake_task: Optional[asyncio.Task[None]] = None
# From p2_singleton_puzzle_hash to pool state dict
self.pool_state: Dict[bytes32, Dict] = {}
self.pool_state: Dict[bytes32, Dict[str, Any]] = {}
# From p2_singleton to auth PrivateKey
self.authentication_keys: Dict[bytes32, PrivateKey] = {}
# Last time we updated pool_state based on the config file
self.last_config_access_time: uint64 = uint64(0)
self.last_config_access_time: float = 0
self.all_root_sks: List[PrivateKey] = []
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
return default_get_connections(server=self.server, request_node_type=request_node_type)
@ -133,14 +135,14 @@ class Farmer:
raise KeychainProxyConnectionFailure()
return self.keychain_proxy
async def get_all_private_keys(self):
async def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]:
keychain_proxy = await self.ensure_keychain_proxy()
return await keychain_proxy.get_all_private_keys()
async def setup_keys(self) -> bool:
no_keys_error_str = "No keys exist. Please run 'chia keys generate' or open the UI."
try:
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
self.all_root_sks = [sk for sk, _ in await self.get_all_private_keys()]
except KeychainProxyConnectionFailure:
return False
@ -170,9 +172,7 @@ class Farmer:
# This is the self pooling configuration, which is only used for original self-pooled plots
self.pool_target_encoded = self.pool_config["xch_target_address"]
self.pool_target = decode_puzzle_hash(self.pool_target_encoded)
self.pool_sks_map: Dict = {}
for key in self.get_private_keys():
self.pool_sks_map[bytes(key.get_g1())] = key
self.pool_sks_map = {bytes(key.get_g1()): key for key in self.get_private_keys()}
assert len(self.farmer_target) == 32
assert len(self.pool_target) == 32
@ -182,8 +182,8 @@ class Farmer:
return True
async def _start(self):
async def start_task():
async def _start(self) -> None:
async def start_task() -> None:
# `Farmer.setup_keys` returns `False` if there are no keys setup yet. In this case we just try until it
# succeeds or until we need to shut down.
while not self._shut_down:
@ -197,10 +197,10 @@ class Farmer:
asyncio.create_task(start_task())
def _close(self):
def _close(self) -> None:
self._shut_down = True
async def _await_closed(self, shutting_down: bool = True):
async def _await_closed(self, shutting_down: bool = True) -> None:
if self.cache_clear_task is not None:
await self.cache_clear_task
if self.update_pool_state_task is not None:
@ -215,10 +215,10 @@ class Farmer:
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback
async def on_connect(self, peer: WSChiaConnection):
async def on_connect(self, peer: WSChiaConnection) -> None:
self.state_changed("add_connection", {})
async def handshake_task():
async def handshake_task() -> None:
# Wait until the task in `Farmer._start` is done so that we have keys available for the handshake. Bail out
# early if we need to shut down or if the harvester is not longer connected.
while not self.started and not self._shut_down and peer in self.server.get_connections():
@ -247,20 +247,20 @@ class Farmer:
self.plot_sync_receivers[peer.peer_node_id] = Receiver(peer, self.plot_sync_callback)
self.harvester_handshake_task = asyncio.create_task(handshake_task())
def set_server(self, server):
def set_server(self, server: ChiaServer) -> None:
self.server = server
def state_changed(self, change: str, data: Dict[str, Any]):
def state_changed(self, change: str, data: Dict[str, Any]) -> None:
if self.state_changed_callback is not None:
self.state_changed_callback(change, data)
def handle_failed_pool_response(self, p2_singleton_puzzle_hash: bytes32, error_message: str):
def handle_failed_pool_response(self, p2_singleton_puzzle_hash: bytes32, error_message: str) -> None:
self.log.error(error_message)
self.pool_state[p2_singleton_puzzle_hash]["pool_errors_24h"].append(
ErrorResponse(uint16(PoolErrorCode.REQUEST_FAILED.value), error_message).to_json_dict()
)
def on_disconnect(self, connection: WSChiaConnection):
def on_disconnect(self, connection: WSChiaConnection) -> None:
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self.state_changed("close_connection", {})
if connection.connection_type is NodeType.HARVESTER:
@ -274,14 +274,14 @@ class Farmer:
if receiver.initial_sync() or harvester_updated:
self.state_changed("harvester_update", receiver.to_dict(True))
async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict]:
async def _pool_get_pool_info(self, pool_config: PoolWalletConfig) -> Optional[Dict[str, Any]]:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
f"{pool_config.pool_url}/pool_info", ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log)
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
self.log.info(f"GET /pool_info response: {response}")
return response
else:
@ -299,7 +299,7 @@ class Farmer:
async def _pool_get_farmer(
self, pool_config: PoolWalletConfig, authentication_token_timeout: uint8, authentication_sk: PrivateKey
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
authentication_token = get_current_authentication_token(authentication_token_timeout)
message: bytes32 = std_hash(
AuthenticationPayload(
@ -320,7 +320,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -340,7 +340,7 @@ class Farmer:
async def _pool_post_farmer(
self, pool_config: PoolWalletConfig, authentication_token_timeout: uint8, owner_sk: PrivateKey
) -> Optional[Dict]:
) -> Optional[Dict[str, Any]]:
auth_sk: Optional[PrivateKey] = self.get_authentication_sk(pool_config)
assert auth_sk is not None
post_farmer_payload: PostFarmerPayload = PostFarmerPayload(
@ -362,7 +362,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -404,7 +404,7 @@ class Farmer:
ssl=ssl_context_for_root(get_mozilla_ca_crt(), log=self.log),
) as resp:
if resp.ok:
response: Dict = json.loads(await resp.text())
response: Dict[str, Any] = json.loads(await resp.text())
log_level = logging.INFO
if "error_code" in response:
log_level = logging.WARNING
@ -428,7 +428,7 @@ class Farmer:
self.authentication_keys[pool_config.p2_singleton_puzzle_hash] = auth_sk
return auth_sk
async def update_pool_state(self):
async def update_pool_state(self) -> None:
config = load_config(self._root_path, "config.yaml")
pool_config_list: List[PoolWalletConfig] = load_pool_config(self._root_path)
@ -514,9 +514,7 @@ class Farmer:
farmer_info, error_code = await update_pool_farmer_info()
if error_code == PoolErrorCode.FARMER_NOT_KNOWN:
# Make the farmer known on the pool with a POST /farmer
owner_sk_and_index: Optional[Tuple[PrivateKey, uint32]] = find_owner_sk(
self.all_root_sks, pool_config.owner_public_key
)
owner_sk_and_index = find_owner_sk(self.all_root_sks, pool_config.owner_public_key)
assert owner_sk_and_index is not None
post_response = await self._pool_post_farmer(
pool_config, authentication_token_timeout, owner_sk_and_index[0]
@ -538,9 +536,7 @@ class Farmer:
and pool_config.payout_instructions.lower() != farmer_info.payout_instructions.lower()
)
if payout_instructions_update_required or error_code == PoolErrorCode.INVALID_SIGNATURE:
owner_sk_and_index: Optional[Tuple[PrivateKey, uint32]] = find_owner_sk(
self.all_root_sks, pool_config.owner_public_key
)
owner_sk_and_index = find_owner_sk(self.all_root_sks, pool_config.owner_public_key)
assert owner_sk_and_index is not None
await self._pool_put_farmer(
pool_config, authentication_token_timeout, owner_sk_and_index[0]
@ -555,13 +551,13 @@ class Farmer:
tb = traceback.format_exc()
self.log.error(f"Exception in update_pool_state for {pool_config.pool_url}, {e} {tb}")
def get_public_keys(self):
def get_public_keys(self) -> List[G1Element]:
return [child_sk.get_g1() for child_sk in self._private_keys]
def get_private_keys(self):
def get_private_keys(self) -> List[PrivateKey]:
return self._private_keys
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict:
async def get_reward_targets(self, search_for_private_key: bool, max_ph_to_search: int = 500) -> Dict[str, Any]:
if search_for_private_key:
all_sks = await self.get_all_private_keys()
have_farmer_sk, have_pool_sk = False, False
@ -591,7 +587,7 @@ class Farmer:
"pool_target": self.pool_target_encoded,
}
def set_reward_targets(self, farmer_target_encoded: Optional[str], pool_target_encoded: Optional[str]):
def set_reward_targets(self, farmer_target_encoded: Optional[str], pool_target_encoded: Optional[str]) -> None:
with lock_and_load_config(self._root_path, "config.yaml") as config:
if farmer_target_encoded is not None:
self.farmer_target_encoded = farmer_target_encoded
@ -603,7 +599,7 @@ class Farmer:
config["pool"]["xch_target_address"] = pool_target_encoded
save_config(self._root_path, "config.yaml", config)
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str):
async def set_payout_instructions(self, launcher_id: bytes32, payout_instructions: str) -> None:
for p2_singleton_puzzle_hash, pool_state_dict in self.pool_state.items():
if launcher_id == pool_state_dict["pool_config"].launcher_id:
with lock_and_load_config(self._root_path, "config.yaml") as config:
@ -627,7 +623,6 @@ class Farmer:
for pool_state in self.pool_state.values():
pool_config: PoolWalletConfig = pool_state["pool_config"]
if pool_config.launcher_id == launcher_id:
authentication_sk: Optional[PrivateKey] = self.get_authentication_sk(pool_config)
if authentication_sk is None:
self.log.error(f"Could not find authentication sk for {pool_config.p2_singleton_puzzle_hash}")
@ -655,8 +650,8 @@ class Farmer:
return None
async def get_harvesters(self, counts_only: bool = False) -> Dict:
harvesters: List = []
async def get_harvesters(self, counts_only: bool = False) -> Dict[str, Any]:
harvesters: List[Dict[str, Any]] = []
for connection in self.server.get_connections(NodeType.HARVESTER):
self.log.debug(f"get_harvesters host: {connection.peer_host}, node_id: {connection.peer_node_id}")
receiver = self.plot_sync_receivers.get(connection.peer_node_id)
@ -675,26 +670,26 @@ class Farmer:
raise KeyError(f"Receiver missing for {node_id}")
return receiver
async def _periodically_update_pool_state_task(self):
time_slept: uint64 = uint64(0)
async def _periodically_update_pool_state_task(self) -> None:
time_slept = 0
config_path: Path = config_path_for_filename(self._root_path, "config.yaml")
while not self._shut_down:
# Every time the config file changes, read it to check the pool state
stat_info = config_path.stat()
if stat_info.st_mtime > self.last_config_access_time:
# If we detect the config file changed, refresh private keys first just in case
self.all_root_sks: List[PrivateKey] = [sk for sk, _ in await self.get_all_private_keys()]
self.all_root_sks = [sk for sk, _ in await self.get_all_private_keys()]
self.last_config_access_time = stat_info.st_mtime
await self.update_pool_state()
time_slept = uint64(0)
time_slept = 0
elif time_slept > 60:
await self.update_pool_state()
time_slept = uint64(0)
time_slept = 0
time_slept += 1
await asyncio.sleep(1)
async def _periodically_clear_cache_and_refresh_task(self):
time_slept: uint64 = uint64(0)
async def _periodically_clear_cache_and_refresh_task(self) -> None:
time_slept = 0
refresh_slept = 0
while not self._shut_down:
try:
@ -710,7 +705,7 @@ class Farmer:
removed_keys.append(key)
for key in removed_keys:
self.cache_add_time.pop(key, None)
time_slept = uint64(0)
time_slept = 0
log.debug(
f"Cleared farmer cache. Num sps: {len(self.sps)} {len(self.proofs_of_space)} "
f"{len(self.quality_str_to_identifiers)} {len(self.number_of_responses)}"

View File

@ -1,13 +1,12 @@
from __future__ import annotations
from chia.full_node.fee_estimate_store import FeeStore
from chia.full_node.fee_estimation import EmptyFeeMempoolInfo, FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import EmptyFeeMempoolInfo, FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator import SmartFeeEstimator
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.fee_tracker import FeeTracker
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32, uint64
@ -35,11 +34,11 @@ class BitcoinFeeEstimator(FeeEstimatorInterface):
self.block_height = block_info.block_height
self.tracker.process_block(block_info.block_height, block_info.included_items)
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
self.last_mempool_info = mempool_info
self.tracker.add_tx(mempool_item)
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
self.last_mempool_info = mempool_info
self.tracker.remove_tx(mempool_item)

View File

@ -157,7 +157,6 @@ class BlockHeightMap:
# time until we hit a match in the existing map, at which point we can
# assume all previous blocks have already been populated
async def _load_blocks_from(self, height: uint32, prev_hash: bytes32) -> None:
while height > 0:
# load 5000 blocks at a time
window_end = max(0, height - 5000)
@ -175,7 +174,6 @@ class BlockHeightMap:
async with self.db.reader_no_transaction() as conn:
async with conn.execute(query, (window_end, height)) as cursor:
# maps block-hash -> (height, prev-hash, sub-epoch-summary)
ordered: Dict[bytes32, Tuple[uint32, bytes32, Optional[bytes]]] = {}
@ -195,7 +193,6 @@ class BlockHeightMap:
assert height == entry[0] + 1
height = entry[0]
if entry[2] is not None:
if (
self.get_hash(height) == prev_hash
and height in self.__sub_epoch_summaries

View File

@ -34,10 +34,8 @@ class BlockStore:
self = cls(LRUCache(1000), db_wrapper, LRUCache(50))
async with self.db_wrapper.writer_maybe_transaction() as conn:
log.info("DB: Creating block store tables and indexes.")
if self.db_wrapper.db_version == 2:
# TODO: most data in block is duplicated in block_record. The only
# reason for this is that our parsing of a FullBlock is so slow,
# it's faster to store duplicate data to parse less when we just
@ -84,7 +82,6 @@ class BlockStore:
)
else:
await conn.execute(
"CREATE TABLE IF NOT EXISTS full_blocks(header_hash text PRIMARY KEY, height bigint,"
" is_block tinyint, is_fully_compactified tinyint, block blob)"
@ -168,7 +165,6 @@ class BlockStore:
raise RuntimeError(f"The blockchain database is corrupt. All of {header_hashes} should exist")
async def replace_proof(self, header_hash: bytes32, block: FullBlock) -> None:
assert header_hash == block.header_hash
block_bytes: bytes
@ -193,7 +189,6 @@ class BlockStore:
self.block_cache.put(header_hash, block)
if self.db_wrapper.db_version == 2:
ses: Optional[bytes] = (
None
if block_record.sub_epoch_summary_included is None
@ -284,9 +279,7 @@ class BlockStore:
async def get_full_block(self, header_hash: bytes32) -> Optional[FullBlock]:
cached: Optional[FullBlock] = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
return cached
log.debug(f"cache miss for block {header_hash.hex()}")
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
@ -301,9 +294,7 @@ class BlockStore:
async def get_full_block_bytes(self, header_hash: bytes32) -> Optional[bytes]:
cached = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
return bytes(cached)
log.debug(f"cache miss for block {header_hash.hex()}")
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT block from full_blocks WHERE header_hash=?", (self.maybe_to_hex(header_hash),)
@ -331,10 +322,8 @@ class BlockStore:
return ret
async def get_block_info(self, header_hash: bytes32) -> Optional[GeneratorBlockInfo]:
cached = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
return GeneratorBlockInfo(
cached.foliage.prev_block_hash, cached.transactions_generator, cached.transactions_generator_ref_list
)
@ -362,10 +351,8 @@ class BlockStore:
)
async def get_generator(self, header_hash: bytes32) -> Optional[SerializedProgram]:
cached = self.block_cache.get(header_hash)
if cached is not None:
log.debug(f"cache hit for block {header_hash.hex()}")
return cached.transactions_generator
formatted_str = "SELECT block, height from full_blocks WHERE header_hash=?"
@ -521,9 +508,7 @@ class BlockStore:
return ret
async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT block_record FROM full_blocks WHERE header_hash=?",
@ -556,7 +541,6 @@ class BlockStore:
ret: Dict[bytes32, BlockRecord] = {}
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ? AND height <= ?",
@ -567,7 +551,6 @@ class BlockStore:
ret[header_hash] = BlockRecord.from_bytes(row[1])
else:
formatted_str = f"SELECT header_hash, block from block_records WHERE height >= {start} and height <= {stop}"
async with self.db_wrapper.reader_no_transaction() as conn:
@ -601,7 +584,6 @@ class BlockStore:
return [maybe_decompress_blob(row[0]) for row in rows]
async def get_peak(self) -> Optional[Tuple[bytes32, uint32]]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute("SELECT hash FROM current_peak WHERE key = 0") as cursor:
@ -636,7 +618,6 @@ class BlockStore:
ret: Dict[bytes32, BlockRecord] = {}
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(
"SELECT header_hash, block_record FROM full_blocks WHERE height >= ?",
@ -683,7 +664,6 @@ class BlockStore:
return bool(row[0])
async def get_random_not_compactified(self, number: int) -> List[int]:
if self.db_wrapper.db_version == 2:
async with self.db_wrapper.reader_no_transaction() as conn:
async with conn.execute(

View File

@ -36,10 +36,8 @@ class CoinStore:
self = CoinStore(db_wrapper, LRUCache(100))
async with self.db_wrapper.writer_maybe_transaction() as conn:
log.info("DB: Creating coin store tables and indexes.")
if self.db_wrapper.db_version == 2:
# the coin_name is unique in this table because the CoinStore always
# only represent a single peak
await conn.execute(
@ -55,7 +53,6 @@ class CoinStore:
)
else:
# the coin_name is unique in this table because the CoinStore always
# only represent a single peak
await conn.execute(
@ -193,7 +190,7 @@ class CoinStore:
if self.db_wrapper.db_version == 2:
names_db = tuple(names_chunk)
else:
names_db = tuple([n.hex() for n in names_chunk])
names_db = tuple(n.hex() for n in names_chunk)
cursors.append(
await conn.execute(
f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
@ -273,7 +270,6 @@ class CoinStore:
start_height: uint32 = uint32(0),
end_height: uint32 = uint32((2**32) - 1),
) -> List[CoinRecord]:
coins = set()
async with self.db_wrapper.reader_no_transaction() as conn:
@ -284,7 +280,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
(self.maybe_to_hex(puzzle_hash), start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -316,7 +311,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
puzzle_hashes_db + (start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -348,7 +342,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
names_db + (start_height, end_height),
) as cursor:
for row in await cursor.fetchall():
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -370,11 +363,13 @@ class CoinStore:
include_spent_coins: bool,
puzzle_hashes: List[bytes32],
min_height: uint32 = uint32(0),
*,
max_items: int = 50000,
) -> List[CoinState]:
if len(puzzle_hashes) == 0:
return []
coins = set()
coins: Set[CoinState] = set()
async with self.db_wrapper.reader_no_transaction() as conn:
for puzzles in chunks(puzzle_hashes, SQLITE_MAX_VARIABLE_NUMBER):
puzzle_hashes_db: Tuple[Any, ...]
@ -387,13 +382,17 @@ class CoinStore:
f"coin_parent, amount, timestamp FROM coin_record INDEXED BY coin_puzzle_hash "
f'WHERE puzzle_hash in ({"?," * (len(puzzles) - 1)}?) '
f"AND (confirmed_index>=? OR spent_index>=?)"
f"{'' if include_spent_coins else 'AND spent_index=0'}",
puzzle_hashes_db + (min_height, min_height),
f"{'' if include_spent_coins else 'AND spent_index=0'}"
" LIMIT ?",
puzzle_hashes_db + (min_height, min_height, max_items - len(coins)),
) as cursor:
row: sqlite3.Row
async for row in cursor:
for row in await cursor.fetchall():
coins.add(self.row_to_coin_state(row))
if len(coins) >= max_items:
break
return list(coins)
async def get_coin_records_by_parent_ids(
@ -421,7 +420,6 @@ class CoinStore:
f"{'' if include_spent_coins else 'AND spent_index=0'}",
parent_ids_db + (start_height, end_height),
) as cursor:
async for row in cursor:
coin = self.row_to_coin(row)
coins.add(CoinRecord(coin, row[0], row[1], row[2], row[6]))
@ -433,11 +431,13 @@ class CoinStore:
include_spent_coins: bool,
coin_ids: List[bytes32],
min_height: uint32 = uint32(0),
*,
max_items: int = 50000,
) -> List[CoinState]:
if len(coin_ids) == 0:
return []
coins = set()
coins: Set[CoinState] = set()
async with self.db_wrapper.reader_no_transaction() as conn:
for ids in chunks(coin_ids, SQLITE_MAX_VARIABLE_NUMBER):
coin_ids_db: Tuple[Any, ...]
@ -449,11 +449,15 @@ class CoinStore:
f"SELECT confirmed_index, spent_index, coinbase, puzzle_hash, "
f'coin_parent, amount, timestamp FROM coin_record WHERE coin_name in ({"?," * (len(ids) - 1)}?) '
f"AND (confirmed_index>=? OR spent_index>=?)"
f"{'' if include_spent_coins else 'AND spent_index=0'}",
coin_ids_db + (min_height, min_height),
f"{'' if include_spent_coins else 'AND spent_index=0'}"
" LIMIT ?",
coin_ids_db + (min_height, min_height, max_items - len(coins)),
) as cursor:
async for row in cursor:
for row in await cursor.fetchall():
coins.add(self.row_to_coin_state(row))
if len(coins) >= max_items:
break
return list(coins)
async def rollback_to_block(self, block_index: int) -> List[CoinRecord]:
@ -501,7 +505,6 @@ class CoinStore:
# Store CoinRecord in DB
async def _add_coin_records(self, records: List[CoinRecord]) -> None:
if self.db_wrapper.db_version == 2:
values2 = []
for record in records:
@ -548,7 +551,6 @@ class CoinStore:
# Update coin_record to be spent in DB
async def _set_spent(self, coin_names: List[bytes32], index: uint32) -> None:
assert len(coin_names) == 0 or index > 0
if len(coin_names) == 0:

View File

@ -6,11 +6,26 @@ from typing import List
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRate
from chia.types.mempool_item import MempoolItem
from chia.types.mojos import Mojos
from chia.util.ints import uint32, uint64
@dataclass(frozen=True)
class MempoolItemInfo:
"""
The information the fee estimator is passed for each mempool item that's
added, removed from the mempool and included in blocks
"""
cost: int
fee: int
height_added_to_mempool: uint32
@property
def fee_per_cost(self) -> float:
return self.fee / self.cost
@dataclass(frozen=True)
class MempoolInfo:
"""
@ -37,7 +52,7 @@ class FeeMempoolInfo:
Attributes:
mempool_info (MempoolInfo): A `MempoolInfo`, defined above. Parameters of our mempool.
current_mempool_cost (uint64):This is the current capacity of the mempool, measured in XCH per CLVM Cost
current_mempool_fees (Mojos): Sum of fees for all spends waiting in the Mempool
current_mempool_fees (int): Sum of fees for all spends waiting in the Mempool
time (datetime): Local time this sample was taken
Note that we use the node's local time, not "Blockchain time" for the timestamp above
@ -45,7 +60,7 @@ class FeeMempoolInfo:
mempool_info: MempoolInfo
current_mempool_cost: CLVMCost # Current sum of CLVM cost of all SpendBundles in mempool (mempool "size")
current_mempool_fees: Mojos # Sum of fees for all spends waiting in the Mempool
current_mempool_fees: int # Sum of fees for all spends waiting in the Mempool
time: datetime # Local time this sample was taken
@ -57,7 +72,7 @@ EmptyMempoolInfo = MempoolInfo(
EmptyFeeMempoolInfo = FeeMempoolInfo(
EmptyMempoolInfo,
CLVMCost(uint64(0)),
Mojos(uint64(0)),
0,
datetime.min,
)
@ -75,4 +90,4 @@ class FeeBlockInfo: # See BlockRecord
"""
block_height: uint32
included_items: List[MempoolItem]
included_items: List[MempoolItemInfo]

View File

@ -3,11 +3,10 @@ from __future__ import annotations
from typing import Any, Dict, List
from chia.full_node.fee_estimate import FeeEstimateV2
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint64
MIN_MOJO_PER_COST = 5
@ -31,10 +30,10 @@ class FeeEstimatorExample(FeeEstimatorInterface):
def new_block(self, block_info: FeeBlockInfo) -> None:
pass
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
pass
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
pass
def estimate_fee_rate(self, *, time_offset_seconds: int) -> FeeRateV2:

View File

@ -2,10 +2,9 @@ from __future__ import annotations
from typing_extensions import Protocol
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, FeeMempoolInfo, MempoolItemInfo
from chia.types.clvm_cost import CLVMCost
from chia.types.fee_rate import FeeRateV2
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint32
@ -18,11 +17,11 @@ class FeeEstimatorInterface(Protocol):
"""A new transaction block has been added to the blockchain"""
pass
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def add_mempool_item(self, mempool_item_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
"""A MempoolItem (transaction and associated info) has been added to the mempool"""
pass
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItem) -> None:
def remove_mempool_item(self, mempool_info: FeeMempoolInfo, mempool_item: MempoolItemInfo) -> None:
"""A MempoolItem (transaction and associated info) has been removed from the mempool"""
pass

View File

@ -6,6 +6,7 @@ from dataclasses import dataclass
from typing import List, Optional, Tuple
from chia.full_node.fee_estimate_store import FeeStore
from chia.full_node.fee_estimation import MempoolItemInfo
from chia.full_node.fee_estimator_constants import (
FEE_ESTIMATOR_VERSION,
INFINITE_FEE_RATE,
@ -26,7 +27,6 @@ from chia.full_node.fee_estimator_constants import (
SUFFICIENT_FEE_TXS,
)
from chia.full_node.fee_history import FeeStatBackup, FeeTrackerBackup
from chia.types.mempool_item import MempoolItem
from chia.util.ints import uint8, uint32, uint64
@ -129,7 +129,7 @@ class FeeStat: # TxConfirmStats
self.old_unconfirmed_txs = [0 for _ in range(0, len(buckets))]
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItem) -> None:
def tx_confirmed(self, blocks_to_confirm: int, item: MempoolItemInfo) -> None:
if blocks_to_confirm < 1:
raise ValueError("tx_confirmed called with < 1 block to confirm")
@ -164,7 +164,7 @@ class FeeStat: # TxConfirmStats
self.unconfirmed_txs[block_index][bucket_index] += 1
return bucket_index
def remove_tx(self, latest_seen_height: uint32, item: MempoolItem, bucket_index: int) -> None:
def remove_tx(self, latest_seen_height: uint32, item: MempoolItemInfo, bucket_index: int) -> None:
if item.height_added_to_mempool is None:
return
block_ago = latest_seen_height - item.height_added_to_mempool
@ -475,7 +475,7 @@ class FeeTracker:
)
self.fee_store.store_fee_data(backup)
def process_block(self, block_height: uint32, items: List[MempoolItem]) -> None:
def process_block(self, block_height: uint32, items: List[MempoolItemInfo]) -> None:
"""A new block has been farmed and these transactions have been included in that block"""
if block_height <= self.latest_seen_height:
# Ignore reorgs
@ -498,7 +498,7 @@ class FeeTracker:
self.first_recorded_height = block_height
self.log.info(f"Fee Estimator first recorded height: {self.first_recorded_height}")
def process_block_tx(self, current_height: uint32, item: MempoolItem) -> None:
def process_block_tx(self, current_height: uint32, item: MempoolItemInfo) -> None:
if item.height_added_to_mempool is None:
raise ValueError("process_block_tx called with item.height_added_to_mempool=None")
@ -510,8 +510,7 @@ class FeeTracker:
self.med_horizon.tx_confirmed(blocks_to_confirm, item)
self.long_horizon.tx_confirmed(blocks_to_confirm, item)
def add_tx(self, item: MempoolItem) -> None:
def add_tx(self, item: MempoolItemInfo) -> None:
if item.height_added_to_mempool < self.latest_seen_height:
self.log.info(f"Processing Item from pending pool: cost={item.cost} fee={item.fee}")
@ -522,7 +521,7 @@ class FeeTracker:
self.med_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
self.long_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
def remove_tx(self, item: MempoolItem) -> None:
def remove_tx(self, item: MempoolItemInfo) -> None:
bucket_index = get_bucket_index(self.buckets, item.fee_per_cost * 1000)
self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index)

View File

@ -17,7 +17,7 @@ from blspy import AugSchemeMPL
from chia.consensus.block_creation import unfinished_block_to_full_block
from chia.consensus.block_record import BlockRecord
from chia.consensus.blockchain import Blockchain, ReceiveBlockResult, StateChangeSummary
from chia.consensus.blockchain import AddBlockResult, Blockchain, StateChangeSummary
from chia.consensus.blockchain_interface import BlockchainInterface
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
@ -115,7 +115,7 @@ class FullNode:
_transaction_queue: Optional[TransactionQueue]
_compact_vdf_sem: Optional[LimitedSemaphore]
_new_peak_sem: Optional[LimitedSemaphore]
_respond_transaction_semaphore: Optional[asyncio.Semaphore]
_add_transaction_semaphore: Optional[asyncio.Semaphore]
_db_wrapper: Optional[DBWrapper2]
_hint_store: Optional[HintStore]
transaction_responses: List[Tuple[bytes32, MempoolInclusionStatus, Optional[Err]]]
@ -180,7 +180,7 @@ class FullNode:
self._transaction_queue = None
self._compact_vdf_sem = None
self._new_peak_sem = None
self._respond_transaction_semaphore = None
self._add_transaction_semaphore = None
self._db_wrapper = None
self._hint_store = None
self.transaction_responses = []
@ -231,9 +231,9 @@ class FullNode:
return self._coin_store
@property
def respond_transaction_semaphore(self) -> asyncio.Semaphore:
assert self._respond_transaction_semaphore is not None
return self._respond_transaction_semaphore
def add_transaction_semaphore(self) -> asyncio.Semaphore:
assert self._add_transaction_semaphore is not None
return self._add_transaction_semaphore
@property
def transaction_queue(self) -> TransactionQueue:
@ -308,7 +308,7 @@ class FullNode:
self._new_peak_sem = LimitedSemaphore.create(active_limit=2, waiting_limit=20)
# These many respond_transaction tasks can be active at any point in time
self._respond_transaction_semaphore = asyncio.Semaphore(200)
self._add_transaction_semaphore = asyncio.Semaphore(200)
sql_log_path: Optional[Path] = None
if self.config.get("log_sqlite_cmds", False):
@ -442,7 +442,7 @@ class FullNode:
async def _handle_one_transaction(self, entry: TransactionQueueEntry) -> None:
peer = entry.peer
try:
inc_status, err = await self.respond_transaction(entry.transaction, entry.spend_name, peer, entry.test)
inc_status, err = await self.add_transaction(entry.transaction, entry.spend_name, peer, entry.test)
self.transaction_responses.append((entry.spend_name, inc_status, err))
if len(self.transaction_responses) > 50:
self.transaction_responses = self.transaction_responses[1:]
@ -455,14 +455,14 @@ class FullNode:
if peer is not None:
await peer.close()
finally:
self.respond_transaction_semaphore.release()
self.add_transaction_semaphore.release()
async def _handle_transactions(self) -> None:
try:
while not self._shut_down:
# We use a semaphore to make sure we don't send more than 200 concurrent calls of respond_transaction.
# However, doing them one at a time would be slow, because they get sent to other processes.
await self.respond_transaction_semaphore.acquire()
await self.add_transaction_semaphore.acquire()
item: TransactionQueueEntry = await self.transaction_queue.pop()
asyncio.create_task(self._handle_one_transaction(item))
except asyncio.CancelledError:
@ -578,7 +578,7 @@ class FullNode:
raise ValueError(f"Error short batch syncing, invalid/no response for {height}-{end_height}")
async with self._blockchain_lock_high_priority:
state_change_summary: Optional[StateChangeSummary]
success, state_change_summary = await self.receive_block_batch(response.blocks, peer, None)
success, state_change_summary = await self.add_block_batch(response.blocks, peer, None)
if not success:
raise ValueError(f"Error short batch syncing, failed to validate blocks {height}-{end_height}")
if state_change_summary is not None:
@ -629,7 +629,7 @@ class FullNode:
unfinished_block: Optional[UnfinishedBlock] = self.full_node_store.get_unfinished_block(target_unf_hash)
curr_height: int = target_height
found_fork_point = False
responses = []
blocks = []
while curr_height > peak_height - 5:
# If we already have the unfinished block, don't fetch the transactions. In the normal case, we will
# already have the unfinished block, from when it was broadcast, so we just need to download the header,
@ -644,14 +644,14 @@ class FullNode:
raise ValueError(
f"Failed to fetch block {curr_height} from {peer.get_peer_logging()}, wrong type {type(curr)}"
)
responses.append(curr)
blocks.append(curr.block)
if self.blockchain.contains_block(curr.block.prev_header_hash) or curr_height == 0:
found_fork_point = True
break
curr_height -= 1
if found_fork_point:
for response in reversed(responses):
await self.respond_block(response, peer)
for block in reversed(blocks):
await self.add_block(block, peer)
except (asyncio.CancelledError, Exception):
self.sync_store.backtrack_syncing[peer.peer_node_id] -= 1
raise
@ -988,7 +988,7 @@ class FullNode:
self.log.info(f"Total of {len(peers_with_peak)} peers with peak {target_peak.height}")
weight_proof_peer: WSChiaConnection = random.choice(peers_with_peak)
self.log.info(
f"Requesting weight proof from peer {weight_proof_peer.peer_host} up to height" f" {target_peak.height}"
f"Requesting weight proof from peer {weight_proof_peer.peer_host} up to height {target_peak.height}"
)
cur_peak: Optional[BlockRecord] = self.blockchain.get_peak()
if cur_peak is not None and target_peak.weight <= cur_peak.weight:
@ -1110,7 +1110,7 @@ class FullNode:
peer, blocks = res
start_height = blocks[0].height
end_height = blocks[-1].height
success, state_change_summary = await self.receive_block_batch(
success, state_change_summary = await self.add_block_batch(
blocks, peer, None if advanced_peak else uint32(fork_point_height), summaries
)
if success is False:
@ -1171,7 +1171,7 @@ class FullNode:
lookup_coin_ids: List[bytes32],
) -> None:
# Looks up coin records in DB for the coins that wallets are interested in
new_states: List[CoinRecord] = await self.coin_store.get_coin_records(list(lookup_coin_ids))
new_states: List[CoinRecord] = await self.coin_store.get_coin_records(lookup_coin_ids)
# Re-arrange to a map, and filter out any non-ph sized hint
coin_id_to_ph_hint: Dict[bytes32, bytes32] = {
@ -1179,7 +1179,7 @@ class FullNode:
}
changes_for_peer: Dict[bytes32, Set[CoinState]] = {}
for coin_record in state_change_summary.rolled_back_records + [s for s in new_states if s is not None]:
for coin_record in state_change_summary.rolled_back_records + new_states:
cr_name: bytes32 = coin_record.name
for peer in self.subscriptions.peers_for_coin_id(cr_name):
@ -1211,7 +1211,7 @@ class FullNode:
msg = make_msg(ProtocolMessageTypes.coin_state_update, state)
await ws_peer.send_message(msg)
async def receive_block_batch(
async def add_block_batch(
self,
all_blocks: List[FullBlock],
peer: WSChiaConnection,
@ -1256,11 +1256,11 @@ class FullNode:
assert pre_validation_results[i].required_iters is not None
state_change_summary: Optional[StateChangeSummary]
advanced_peak = agg_state_change_summary is not None
result, error, state_change_summary = await self.blockchain.receive_block(
result, error, state_change_summary = await self.blockchain.add_block(
block, pre_validation_results[i], None if advanced_peak else fork_point
)
if result == ReceiveBlockResult.NEW_PEAK:
if result == AddBlockResult.NEW_PEAK:
assert state_change_summary is not None
# Since all blocks are contiguous, we can simply append the rollback changes and npc results
if agg_state_change_summary is None:
@ -1275,7 +1275,7 @@ class FullNode:
agg_state_change_summary.new_npc_results + state_change_summary.new_npc_results,
agg_state_change_summary.new_rewards + state_change_summary.new_rewards,
)
elif result == ReceiveBlockResult.INVALID_BLOCK or result == ReceiveBlockResult.DISCONNECTED_BLOCK:
elif result == AddBlockResult.INVALID_BLOCK or result == AddBlockResult.DISCONNECTED_BLOCK:
if error is not None:
self.log.error(f"Error: {error}, Invalid block from peer: {peer.get_peer_logging()} ")
return False, agg_state_change_summary
@ -1581,16 +1581,15 @@ class FullNode:
await self.server.send_to_all([msg], NodeType.WALLET)
self._state_changed("new_peak")
async def respond_block(
async def add_block(
self,
respond_block: full_node_protocol.RespondBlock,
block: FullBlock,
peer: Optional[WSChiaConnection] = None,
raise_on_disconnected: bool = False,
) -> Optional[Message]:
"""
Receive a full block from a peer full node (or ourselves).
Add a full block from a peer full node (or ourselves).
"""
block: FullBlock = respond_block.block
if self.sync_store.get_sync_mode():
return None
@ -1649,7 +1648,7 @@ class FullNode:
f"same farmer with the same pospace."
)
# This recursion ends here, we cannot recurse again because transactions_generator is not None
return await self.respond_block(block_response, peer)
return await self.add_block(new_block, peer)
state_change_summary: Optional[StateChangeSummary] = None
ppp_result: Optional[PeakPostProcessingResult] = None
async with self._blockchain_lock_high_priority:
@ -1667,14 +1666,14 @@ class FullNode:
pre_validation_results = await self.blockchain.pre_validate_blocks_multiprocessing(
[block], npc_results, validate_signatures=False
)
added: Optional[ReceiveBlockResult] = None
added: Optional[AddBlockResult] = None
pre_validation_time = time.time() - validation_start
try:
if len(pre_validation_results) < 1:
raise ValueError(f"Failed to validate block {header_hash} height {block.height}")
if pre_validation_results[0].error is not None:
if Err(pre_validation_results[0].error) == Err.INVALID_PREV_BLOCK_HASH:
added = ReceiveBlockResult.DISCONNECTED_BLOCK
added = AddBlockResult.DISCONNECTED_BLOCK
error_code: Optional[Err] = Err.INVALID_PREV_BLOCK_HASH
else:
raise ValueError(
@ -1686,36 +1685,36 @@ class FullNode:
pre_validation_results[0] if pre_validation_result is None else pre_validation_result
)
assert result_to_validate.required_iters == pre_validation_results[0].required_iters
(added, error_code, state_change_summary) = await self.blockchain.receive_block(
(added, error_code, state_change_summary) = await self.blockchain.add_block(
block, result_to_validate, None
)
if added == ReceiveBlockResult.ALREADY_HAVE_BLOCK:
if added == AddBlockResult.ALREADY_HAVE_BLOCK:
return None
elif added == ReceiveBlockResult.INVALID_BLOCK:
elif added == AddBlockResult.INVALID_BLOCK:
assert error_code is not None
self.log.error(f"Block {header_hash} at height {block.height} is invalid with code {error_code}.")
raise ConsensusError(error_code, [header_hash])
elif added == ReceiveBlockResult.DISCONNECTED_BLOCK:
elif added == AddBlockResult.DISCONNECTED_BLOCK:
self.log.info(f"Disconnected block {header_hash} at height {block.height}")
if raise_on_disconnected:
raise RuntimeError("Expected block to be added, received disconnected block.")
return None
elif added == ReceiveBlockResult.NEW_PEAK:
elif added == AddBlockResult.NEW_PEAK:
# Only propagate blocks which extend the blockchain (becomes one of the heads)
assert state_change_summary is not None
ppp_result = await self.peak_post_processing(block, state_change_summary, peer)
elif added == ReceiveBlockResult.ADDED_AS_ORPHAN:
elif added == AddBlockResult.ADDED_AS_ORPHAN:
self.log.info(
f"Received orphan block of height {block.height} rh " f"{block.reward_chain_block.get_hash()}"
f"Received orphan block of height {block.height} rh {block.reward_chain_block.get_hash()}"
)
else:
# Should never reach here, all the cases are covered
raise RuntimeError(f"Invalid result from receive_block {added}")
raise RuntimeError(f"Invalid result from add_block {added}")
except asyncio.CancelledError:
# We need to make sure to always call this method even when we get a cancel exception, to make sure
# the node stays in sync
if added == ReceiveBlockResult.NEW_PEAK:
if added == AddBlockResult.NEW_PEAK:
assert state_change_summary is not None
await self.peak_post_processing(block, state_change_summary, peer)
raise
@ -1786,19 +1785,18 @@ class FullNode:
self._segment_task = asyncio.create_task(self.weight_proof_handler.create_prev_sub_epoch_segments())
return None
async def respond_unfinished_block(
async def add_unfinished_block(
self,
respond_unfinished_block: full_node_protocol.RespondUnfinishedBlock,
block: UnfinishedBlock,
peer: Optional[WSChiaConnection],
farmed_block: bool = False,
block_bytes: Optional[bytes] = None,
) -> None:
"""
We have received an unfinished block, either created by us, or from another peer.
We can validate it and if it's a good block, propagate it to other peers and
We can validate and add it and if it's a good block, propagate it to other peers and
timelords.
"""
block = respond_unfinished_block.unfinished_block
receive_time = time.time()
if block.prev_header_hash != self.constants.GENESIS_CHALLENGE and not self.blockchain.contains_block(
@ -1883,7 +1881,11 @@ class FullNode:
# blockchain.run_generator throws on errors, so npc_result is
# guaranteed to represent a successful run
assert npc_result.conds is not None
pairs_pks, pairs_msgs = pkm_pairs(npc_result.conds, self.constants.AGG_SIG_ME_ADDITIONAL_DATA)
pairs_pks, pairs_msgs = pkm_pairs(
npc_result.conds,
self.constants.AGG_SIG_ME_ADDITIONAL_DATA,
soft_fork=height >= self.constants.SOFT_FORK_HEIGHT,
)
if not cached_bls.aggregate_verify(
pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature, True
):
@ -2090,7 +2092,7 @@ class FullNode:
self.log.warning("Trying to make a pre-farm block but height is not 0")
return None
try:
await self.respond_block(full_node_protocol.RespondBlock(block), raise_on_disconnected=True)
await self.add_block(block, raise_on_disconnected=True)
except Exception as e:
self.log.warning(f"Consensus error validating block: {e}")
if timelord_peer is not None:
@ -2098,11 +2100,10 @@ class FullNode:
await self.send_peak_to_timelords(peer=timelord_peer)
return None
async def respond_end_of_sub_slot(
self, request: full_node_protocol.RespondEndOfSubSlot, peer: WSChiaConnection
async def add_end_of_sub_slot(
self, end_of_slot_bundle: EndOfSubSlotBundle, peer: WSChiaConnection
) -> Tuple[Optional[Message], bool]:
fetched_ss = self.full_node_store.get_sub_slot(request.end_of_slot_bundle.challenge_chain.get_hash())
fetched_ss = self.full_node_store.get_sub_slot(end_of_slot_bundle.challenge_chain.get_hash())
# We are not interested in sub-slots which have the same challenge chain but different reward chain. If there
# is a reorg, we will find out through the broadcast of blocks instead.
@ -2112,16 +2113,16 @@ class FullNode:
async with self.timelord_lock:
fetched_ss = self.full_node_store.get_sub_slot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
)
if (
(fetched_ss is None)
and request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
and end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge
!= self.constants.GENESIS_CHALLENGE
):
# If we don't have the prev, request the prev instead
full_node_request = full_node_protocol.RequestSignagePointOrEndOfSubSlot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
uint8(0),
bytes32([0] * 32),
)
@ -2140,7 +2141,7 @@ class FullNode:
# Adds the sub slot and potentially get new infusions
new_infusions = self.full_node_store.new_finished_sub_slot(
request.end_of_slot_bundle,
end_of_slot_bundle,
self.blockchain,
peak,
await self.blockchain.get_full_peak(),
@ -2149,19 +2150,19 @@ class FullNode:
if new_infusions is not None:
self.log.info(
f"⏲️ Finished sub slot, SP {self.constants.NUM_SPS_SUB_SLOT}/{self.constants.NUM_SPS_SUB_SLOT}, "
f"{request.end_of_slot_bundle.challenge_chain.get_hash()}, "
f"{end_of_slot_bundle.challenge_chain.get_hash()}, "
f"number of sub-slots: {len(self.full_node_store.finished_sub_slots)}, "
f"RC hash: {request.end_of_slot_bundle.reward_chain.get_hash()}, "
f"Deficit {request.end_of_slot_bundle.reward_chain.deficit}"
f"RC hash: {end_of_slot_bundle.reward_chain.get_hash()}, "
f"Deficit {end_of_slot_bundle.reward_chain.deficit}"
)
# Reset farmer response timer for sub slot (SP 0)
self.signage_point_times[0] = time.time()
# Notify full nodes of the new sub-slot
broadcast = full_node_protocol.NewSignagePointOrEndOfSubSlot(
request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
request.end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge,
end_of_slot_bundle.challenge_chain.get_hash(),
uint8(0),
request.end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge,
end_of_slot_bundle.reward_chain.end_of_slot_vdf.challenge,
)
msg = make_msg(ProtocolMessageTypes.new_signage_point_or_end_of_sub_slot, broadcast)
await self.server.send_to_all([msg], NodeType.FULL_NODE, peer.peer_node_id)
@ -2171,9 +2172,9 @@ class FullNode:
# Notify farmers of the new sub-slot
broadcast_farmer = farmer_protocol.NewSignagePoint(
request.end_of_slot_bundle.challenge_chain.get_hash(),
request.end_of_slot_bundle.challenge_chain.get_hash(),
request.end_of_slot_bundle.reward_chain.get_hash(),
end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.challenge_chain.get_hash(),
end_of_slot_bundle.reward_chain.get_hash(),
next_difficulty,
next_sub_slot_iters,
uint8(0),
@ -2184,11 +2185,11 @@ class FullNode:
else:
self.log.info(
f"End of slot not added CC challenge "
f"{request.end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge}"
f"{end_of_slot_bundle.challenge_chain.challenge_chain_end_of_slot_vdf.challenge}"
)
return None, False
async def respond_transaction(
async def add_transaction(
self,
transaction: SpendBundle,
spend_name: bytes32,
@ -2234,8 +2235,8 @@ class FullNode:
if status == MempoolInclusionStatus.SUCCESS:
self.log.debug(
f"Added transaction to mempool: {spend_name} mempool size: "
f"{self.mempool_manager.mempool.total_mempool_cost} normalized "
f"{self.mempool_manager.mempool.total_mempool_cost / 5000000}"
f"{self.mempool_manager.mempool.total_mempool_cost()} normalized "
f"{self.mempool_manager.mempool.total_mempool_cost() / 5000000}"
)
# Only broadcast successful transactions, not pending ones. Otherwise it's a DOS
@ -2259,9 +2260,7 @@ class FullNode:
await self.simulator_transaction_callback(spend_name) # pylint: disable=E1102
else:
self.mempool_manager.remove_seen(spend_name)
self.log.debug(
f"Wasn't able to add transaction with id {spend_name}, " f"status {status} error: {error}"
)
self.log.debug(f"Wasn't able to add transaction with id {spend_name}, status {status} error: {error}")
return status, error
async def _needs_compact_proof(
@ -2355,7 +2354,6 @@ class FullNode:
header_hash: bytes32,
field_vdf: CompressibleVDFField,
) -> bool:
block = await self.block_store.get_full_block(header_hash)
if block is None:
return False
@ -2403,7 +2401,7 @@ class FullNode:
)
raise
async def respond_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
async def add_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
field_vdf = CompressibleVDFField(int(request.field_vdf))
if not await self._can_accept_compact_proof(
request.vdf_info, request.vdf_proof, request.height, request.header_hash, field_vdf
@ -2438,7 +2436,7 @@ class FullNode:
)
response = await peer.call_api(FullNodeAPI.request_compact_vdf, peer_request, timeout=10)
if response is not None and isinstance(response, full_node_protocol.RespondCompactVDF):
await self.respond_compact_vdf(response, peer)
await self.add_compact_vdf(response, peer)
async def request_compact_vdf(self, request: full_node_protocol.RequestCompactVDF, peer: WSChiaConnection) -> None:
header_block = await self.blockchain.get_header_block_by_height(
@ -2484,7 +2482,7 @@ class FullNode:
msg = make_msg(ProtocolMessageTypes.respond_compact_vdf, compact_vdf)
await peer.send_message(msg)
async def respond_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
async def add_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
field_vdf = CompressibleVDFField(int(request.field_vdf))
if not await self._can_accept_compact_proof(
request.vdf_info, request.vdf_proof, request.height, request.header_hash, field_vdf
@ -2521,7 +2519,6 @@ class FullNode:
self.log.info("Heights found for bluebox to compact: [%s]" % ", ".join(map(str, heights)))
for h in heights:
headers = await self.blockchain.get_header_blocks_in_range(h, h, tx_filter=False)
records: Dict[bytes32, BlockRecord] = {}
if sanitize_weight_proof_only:
@ -2612,7 +2609,6 @@ class FullNode:
async def node_next_block_check(
peer: WSChiaConnection, potential_peek: uint32, blockchain: BlockchainInterface
) -> bool:
block_response: Optional[Any] = await peer.call_api(
FullNodeAPI.request_block, full_node_protocol.RequestBlock(potential_peek, True)
)

View File

@ -49,8 +49,8 @@ from chia.types.end_of_slot_bundle import EndOfSubSlotBundle
from chia.types.full_block import FullBlock
from chia.types.generator_types import BlockGenerator
from chia.types.mempool_inclusion_status import MempoolInclusionStatus
from chia.types.mempool_item import MempoolItem
from chia.types.peer_info import PeerInfo
from chia.types.spend_bundle import SpendBundle
from chia.types.transaction_queue_entry import TransactionQueueEntry
from chia.types.unfinished_block import UnfinishedBlock
from chia.util.api_decorators import api_request
@ -106,7 +106,7 @@ class FullNodeAPI:
) -> Optional[Message]:
self.log.debug(f"Received {len(request.peer_list)} peers")
if self.full_node.full_node_peers is not None:
await self.full_node.full_node_peers.respond_peers(request, peer.get_peer_info(), True)
await self.full_node.full_node_peers.add_peers(request.peer_list, peer.get_peer_info(), True)
return None
@api_request(peer_required=True)
@ -115,7 +115,7 @@ class FullNodeAPI:
) -> Optional[Message]:
self.log.debug(f"Received {len(request.peer_list)} peers from introducer")
if self.full_node.full_node_peers is not None:
await self.full_node.full_node_peers.respond_peers(request, peer.get_peer_info(), False)
await self.full_node.full_node_peers.add_peers(request.peer_list, peer.get_peer_info(), False)
await peer.close()
return None
@ -465,8 +465,8 @@ class FullNodeAPI:
) -> Optional[Message]:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_unfinished_block(
respond_unfinished_block, peer, block_bytes=respond_unfinished_block_bytes
await self.full_node.add_unfinished_block(
respond_unfinished_block.unfinished_block, peer, block_bytes=respond_unfinished_block_bytes
)
return None
@ -558,7 +558,6 @@ class FullNodeAPI:
async def request_signage_point_or_end_of_sub_slot(
self, request: full_node_protocol.RequestSignagePointOrEndOfSubSlot
) -> Optional[Message]:
if request.index_from_challenge == 0:
sub_slot: Optional[Tuple[EndOfSubSlotBundle, int, uint128]] = self.full_node.full_node_store.get_sub_slot(
request.challenge_hash
@ -619,7 +618,6 @@ class FullNodeAPI:
return None
peak = self.full_node.blockchain.get_peak()
if peak is not None and peak.height > self.full_node.constants.MAX_SUB_SLOT_BLOCKS:
next_sub_slot_iters = self.full_node.blockchain.get_next_slot_iters(peak.header_hash, True)
sub_slots_for_peak = await self.full_node.blockchain.get_sp_and_ip_sub_slots(peak.header_hash)
assert sub_slots_for_peak is not None
@ -658,7 +656,7 @@ class FullNodeAPI:
) -> Optional[Message]:
if self.full_node.sync_store.get_sync_mode():
return None
msg, _ = await self.full_node.respond_end_of_sub_slot(request, peer)
msg, _ = await self.full_node.add_end_of_sub_slot(request.end_of_slot_bundle, peer)
return msg
@api_request(peer_required=True)
@ -669,10 +667,10 @@ class FullNodeAPI:
) -> Optional[Message]:
received_filter = PyBIP158(bytearray(request.filter))
items: List[MempoolItem] = await self.full_node.mempool_manager.get_items_not_in_filter(received_filter)
items: List[SpendBundle] = self.full_node.mempool_manager.get_items_not_in_filter(received_filter)
for item in items:
transaction = full_node_protocol.RespondTransaction(item.spend_bundle)
transaction = full_node_protocol.RespondTransaction(item)
msg = make_msg(ProtocolMessageTypes.respond_transaction, transaction)
await peer.send_message(msg)
return None
@ -1008,12 +1006,11 @@ class FullNodeAPI:
return None
# Propagate to ourselves (which validates and does further propagations)
request = full_node_protocol.RespondUnfinishedBlock(new_candidate)
try:
await self.full_node.respond_unfinished_block(request, None, True)
await self.full_node.add_unfinished_block(new_candidate, None, True)
except Exception as e:
# If we have an error with this block, try making an empty block
self.full_node.log.error(f"Error farming block {e} {request}")
self.full_node.log.error(f"Error farming block {e} {new_candidate}")
candidate_tuple = self.full_node.full_node_store.get_candidate_block(
farmer_request.quality_string, backup=True
)
@ -1071,8 +1068,7 @@ class FullNodeAPI:
):
return None
# Calls our own internal message to handle the end of sub slot, and potentially broadcasts to other peers.
full_node_message = full_node_protocol.RespondEndOfSubSlot(request.end_of_sub_slot_bundle)
msg, added = await self.full_node.respond_end_of_sub_slot(full_node_message, peer)
msg, added = await self.full_node.add_end_of_sub_slot(request.end_of_sub_slot_bundle, peer)
if not added:
self.log.error(
f"Was not able to add end of sub-slot: "
@ -1098,7 +1094,6 @@ class FullNodeAPI:
tx_additions: List[Coin] = []
if block.transactions_generator is not None:
block_generator: Optional[BlockGenerator] = await self.full_node.blockchain.get_block_generator(block)
# get_block_generator() returns None in case the block we specify
# does not have a generator (i.e. is not a transaction block).
@ -1112,7 +1107,6 @@ class FullNodeAPI:
get_name_puzzle_conditions,
block_generator,
self.full_node.constants.MAX_BLOCK_COST_CLVM,
cost_per_byte=self.full_node.constants.COST_PER_BYTE,
mempool_mode=False,
height=request.height,
),
@ -1411,7 +1405,7 @@ class FullNodeAPI:
async def respond_compact_proof_of_time(self, request: timelord_protocol.RespondCompactProofOfTime) -> None:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_compact_proof_of_time(request)
await self.full_node.add_compact_proof_of_time(request)
return None
@api_request(peer_required=True, bytes_required=True, execute_task=True)
@ -1452,37 +1446,77 @@ class FullNodeAPI:
async def respond_compact_vdf(self, request: full_node_protocol.RespondCompactVDF, peer: WSChiaConnection) -> None:
if self.full_node.sync_store.get_sync_mode():
return None
await self.full_node.respond_compact_vdf(request, peer)
await self.full_node.add_compact_vdf(request, peer)
return None
@api_request(peer_required=True)
async def register_interest_in_puzzle_hash(
self, request: wallet_protocol.RegisterForPhUpdates, peer: WSChiaConnection
) -> Message:
if self.is_trusted(peer):
max_items = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
trusted = self.is_trusted(peer)
if trusted:
max_subscriptions = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
max_items = self.full_node.config.get("trusted_max_subscribe_response_items", 500000)
else:
max_items = self.full_node.config.get("max_subscribe_items", 200000)
max_subscriptions = self.full_node.config.get("max_subscribe_items", 200000)
max_items = self.full_node.config.get("max_subscribe_response_items", 100000)
self.full_node.subscriptions.add_ph_subscriptions(peer.peer_node_id, request.puzzle_hashes, max_items)
# the returned puzzle hashes are the ones we ended up subscribing to.
# It will have filtered duplicates and ones exceeding the subscription
# limit.
puzzle_hashes = self.full_node.subscriptions.add_ph_subscriptions(
peer.peer_node_id, request.puzzle_hashes, max_subscriptions
)
hint_coin_ids = []
for puzzle_hash in request.puzzle_hashes:
ph_hint_coins = await self.full_node.hint_store.get_coin_ids(puzzle_hash)
hint_coin_ids.extend(ph_hint_coins)
start_time = time.monotonic()
# Note that coin state updates may arrive out-of-order on the client side.
# We add the subscription before we're done collecting all the coin
# state that goes into the response. CoinState updates may be sent
# before we send the response
# Send all coins with requested puzzle hash that have been created after the specified height
states: List[CoinState] = await self.full_node.coin_store.get_coin_states_by_puzzle_hashes(
include_spent_coins=True, puzzle_hashes=request.puzzle_hashes, min_height=request.min_height
include_spent_coins=True, puzzle_hashes=puzzle_hashes, min_height=request.min_height, max_items=max_items
)
max_items -= len(states)
hint_coin_ids: Set[bytes32] = set()
if max_items > 0:
for puzzle_hash in puzzle_hashes:
ph_hint_coins = await self.full_node.hint_store.get_coin_ids(puzzle_hash, max_items=max_items)
hint_coin_ids.update(ph_hint_coins)
max_items -= len(ph_hint_coins)
if max_items <= 0:
break
hint_states: List[CoinState] = []
if len(hint_coin_ids) > 0:
hint_states = await self.full_node.coin_store.get_coin_states_by_ids(
include_spent_coins=True, coin_ids=hint_coin_ids, min_height=request.min_height
include_spent_coins=True,
coin_ids=list(hint_coin_ids),
min_height=request.min_height,
max_items=len(hint_coin_ids),
)
states.extend(hint_states)
end_time = time.monotonic()
truncated = max_items <= 0
if truncated or end_time - start_time > 5:
self.log.log(
logging.WARNING if trusted and truncated else logging.INFO,
"RegisterForPhUpdates resulted in %d coin states. "
"Request had %d (unique) puzzle hashes and matched %d hints. %s"
"The request took %0.2fs",
len(states),
len(puzzle_hashes),
len(hint_states),
"The response was truncated. " if truncated else "",
end_time - start_time,
)
response = wallet_protocol.RespondToPhUpdates(request.puzzle_hashes, request.min_height, states)
msg = make_msg(ProtocolMessageTypes.respond_to_ph_update, response)
return msg
@ -1491,15 +1525,20 @@ class FullNodeAPI:
async def register_interest_in_coin(
self, request: wallet_protocol.RegisterForCoinUpdates, peer: WSChiaConnection
) -> Message:
if self.is_trusted(peer):
max_items = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
max_subscriptions = self.full_node.config.get("trusted_max_subscribe_items", 2000000)
max_items = self.full_node.config.get("trusted_max_subscribe_response_items", 500000)
else:
max_items = self.full_node.config.get("max_subscribe_items", 200000)
self.full_node.subscriptions.add_coin_subscriptions(peer.peer_node_id, request.coin_ids, max_items)
max_subscriptions = self.full_node.config.get("max_subscribe_items", 200000)
max_items = self.full_node.config.get("max_subscribe_response_items", 100000)
# TODO: apparently we have tests that expect to receive a
# RespondToCoinUpdates even when subscribing to the same coin multiple
# times, so we can't optimize away such DB lookups (yet)
self.full_node.subscriptions.add_coin_subscriptions(peer.peer_node_id, request.coin_ids, max_subscriptions)
states: List[CoinState] = await self.full_node.coin_store.get_coin_states_by_ids(
include_spent_coins=True, coin_ids=request.coin_ids, min_height=request.min_height
include_spent_coins=True, coin_ids=request.coin_ids, min_height=request.min_height, max_items=max_items
)
response = wallet_protocol.RespondToCoinUpdates(request.coin_ids, request.min_height, states)

View File

@ -805,6 +805,6 @@ class FullNodeStore:
found_last_challenge = True
break
if not found_last_challenge:
log.warning(f"Did not find hash {last_challenge_to_add} connected to " f"{challenge_in_chain}")
log.warning(f"Did not find hash {last_challenge_to_add} connected to {challenge_in_chain}")
return None
return collected_sub_slots

View File

@ -33,15 +33,12 @@ class HintStore:
await conn.execute("CREATE INDEX IF NOT EXISTS hint_index on hints(hint)")
return self
async def get_coin_ids(self, hint: bytes) -> List[bytes32]:
async def get_coin_ids(self, hint: bytes, *, max_items: int = 50000) -> List[bytes32]:
async with self.db_wrapper.reader_no_transaction() as conn:
cursor = await conn.execute("SELECT coin_id from hints WHERE hint=?", (hint,))
cursor = await conn.execute("SELECT coin_id from hints WHERE hint=? LIMIT ?", (hint, max_items))
rows = await cursor.fetchall()
await cursor.close()
coin_ids = []
for row in rows:
coin_ids.append(row[0])
return coin_ids
return [bytes32(row[0]) for row in rows]
async def add_hints(self, coin_hint_list: List[Tuple[bytes32, bytes]]) -> None:
if len(coin_hint_list) == 0:

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import asyncio
import dataclasses
import logging
import traceback
from types import TracebackType
from typing import Awaitable, Callable
@ -58,8 +57,7 @@ class LockQueue:
await prioritized_callback.af()
await self._release_event.wait()
except asyncio.CancelledError:
error_stack = traceback.format_exc()
log.debug(f"LockQueue._run() cancelled: {error_stack}")
log.debug("LockQueue._run() cancelled")
def close(self) -> None:
self._run_task.cancel()

View File

@ -1,38 +1,179 @@
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional
from typing import Callable, Dict, Iterator, List, Optional, Tuple
from sortedcontainers import SortedDict
from chia_rs import Coin
from chia.full_node.fee_estimation import FeeMempoolInfo, MempoolInfo
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.fee_estimation import FeeMempoolInfo, MempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.clvm_cost import CLVMCost
from chia.types.mempool_item import MempoolItem
from chia.types.mojos import Mojos
from chia.util.ints import uint64
from chia.types.spend_bundle import SpendBundle
from chia.util.chunks import chunks
from chia.util.db_wrapper import SQLITE_MAX_VARIABLE_NUMBER
from chia.util.ints import uint32, uint64
log = logging.getLogger(__name__)
# We impose a limit on the fee a single transaction can pay in order to have the
# sum of all fees in the mempool be less than 2^63. That's the limit of sqlite's
# integers, which we rely on for computing fee per cost as well as the fee sum
MEMPOOL_ITEM_FEE_LIMIT = 2**50
SQLITE_NO_GENERATED_COLUMNS: bool = sqlite3.sqlite_version_info < (3, 31, 0)
class MempoolRemoveReason(Enum):
CONFLICT = 1
BLOCK_INCLUSION = 2
POOL_FULL = 3
EXPIRED = 4
@dataclass(frozen=True)
class InternalMempoolItem:
spend_bundle: SpendBundle
npc_result: NPCResult
height_added_to_mempool: uint32
class Mempool:
_db_conn: sqlite3.Connection
# it's expensive to serialize and deserialize G2Element, so we keep those in
# this separate dictionary
_items: Dict[bytes32, InternalMempoolItem]
def __init__(self, mempool_info: MempoolInfo, fee_estimator: FeeEstimatorInterface):
self.log: logging.Logger = logging.getLogger(__name__)
self.spends: Dict[bytes32, MempoolItem] = {}
self.sorted_spends: SortedDict = SortedDict()
self._db_conn = sqlite3.connect(":memory:")
self._items = {}
with self._db_conn:
# name means SpendBundle hash
# assert_height may be NIL
generated = ""
if not SQLITE_NO_GENERATED_COLUMNS:
generated = " GENERATED ALWAYS AS (CAST(fee AS REAL) / cost) VIRTUAL"
# the seq field indicates the order of items being added to the
# mempool. It's used as a tie-breaker for items with the same fee
# rate
self._db_conn.execute(
f"""CREATE TABLE tx(
name BLOB,
cost INT NOT NULL,
fee INT NOT NULL,
assert_height INT,
assert_before_height INT,
assert_before_seconds INT,
fee_per_cost REAL{generated},
seq INTEGER PRIMARY KEY AUTOINCREMENT)
"""
)
self._db_conn.execute("CREATE INDEX name_idx ON tx(name)")
self._db_conn.execute("CREATE INDEX fee_sum ON tx(fee)")
self._db_conn.execute("CREATE INDEX cost_sum ON tx(cost)")
self._db_conn.execute("CREATE INDEX feerate ON tx(fee_per_cost)")
self._db_conn.execute(
"CREATE INDEX assert_before_height ON tx(assert_before_height) WHERE assert_before_height != NULL"
)
self._db_conn.execute(
"CREATE INDEX assert_before_seconds ON tx(assert_before_seconds) WHERE assert_before_seconds != NULL"
)
# This table maps coin IDs to spend bundles hashes
self._db_conn.execute(
"""CREATE TABLE spends(
coin_id BLOB NOT NULL,
tx BLOB NOT NULL,
UNIQUE(coin_id, tx))
"""
)
self._db_conn.execute("CREATE INDEX spend_by_coin ON spends(coin_id)")
self._db_conn.execute("CREATE INDEX spend_by_bundle ON spends(tx)")
self.mempool_info: MempoolInfo = mempool_info
self.fee_estimator: FeeEstimatorInterface = fee_estimator
self.removal_coin_id_to_spendbundle_ids: Dict[bytes32, List[bytes32]] = {}
self.total_mempool_cost: CLVMCost = CLVMCost(uint64(0))
self.total_mempool_fees: Mojos = Mojos(uint64(0))
def __del__(self) -> None:
self._db_conn.close()
def _row_to_item(self, row: sqlite3.Row) -> MempoolItem:
name = bytes32(row[0])
fee = int(row[2])
assert_height = row[3]
assert_before_height = row[4]
assert_before_seconds = row[5]
item = self._items[name]
return MempoolItem(
item.spend_bundle,
uint64(fee),
item.npc_result,
name,
uint32(item.height_added_to_mempool),
assert_height,
assert_before_height,
assert_before_seconds,
)
def total_mempool_fees(self) -> int:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(fee) FROM tx")
val = cursor.fetchone()[0]
return uint64(0) if val is None else uint64(val)
def total_mempool_cost(self) -> CLVMCost:
with self._db_conn:
cursor = self._db_conn.execute("SELECT SUM(cost) FROM tx")
val = cursor.fetchone()[0]
return CLVMCost(uint64(0) if val is None else uint64(val))
def all_spends(self) -> Iterator[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT * FROM tx")
for row in cursor:
yield self._row_to_item(row)
def all_spend_ids(self) -> List[bytes32]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT name FROM tx")
return [bytes32(row[0]) for row in cursor]
# TODO: move "process_mempool_items()" into this class in order to do this a
# bit more efficiently
def spends_by_feerate(self) -> Iterator[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT * FROM tx ORDER BY fee_per_cost DESC, seq ASC")
for row in cursor:
yield self._row_to_item(row)
def size(self) -> int:
with self._db_conn:
cursor = self._db_conn.execute("SELECT Count(name) FROM tx")
val = cursor.fetchone()
return 0 if val is None else int(val[0])
def get_spend_by_id(self, spend_bundle_id: bytes32) -> Optional[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute("SELECT * FROM tx WHERE name=?", (spend_bundle_id,))
row = cursor.fetchone()
return None if row is None else self._row_to_item(row)
# TODO: we need a bulk lookup function like this too
def get_spends_by_coin_id(self, spent_coin_id: bytes32) -> List[MempoolItem]:
with self._db_conn:
cursor = self._db_conn.execute(
"SELECT * FROM tx WHERE name in (SELECT tx FROM spends WHERE coin_id=?)",
(spent_coin_id,),
)
return [self._row_to_item(row) for row in cursor]
def get_min_fee_rate(self, cost: int) -> float:
"""
@ -40,82 +181,185 @@ class Mempool:
"""
if self.at_full_capacity(cost):
current_cost = self.total_mempool_cost
# TODO: make MempoolItem.cost be CLVMCost
current_cost = int(self.total_mempool_cost())
# Iterates through all spends in increasing fee per cost
fee_per_cost: float
for fee_per_cost, spends_with_fpc in self.sorted_spends.items():
for spend_name, item in spends_with_fpc.items():
current_cost -= item.cost
with self._db_conn:
cursor = self._db_conn.execute("SELECT cost,fee_per_cost FROM tx ORDER BY fee_per_cost ASC, seq DESC")
item_cost: int
fee_per_cost: float
for item_cost, fee_per_cost in cursor:
current_cost -= item_cost
# Removing one at a time, until our transaction of size cost fits
if current_cost + cost <= self.mempool_info.max_size_in_cost:
return fee_per_cost
raise ValueError(
f"Transaction with cost {cost} does not fit in mempool of max cost {self.mempool_info.max_size_in_cost}"
)
else:
return 0
def new_tx_block(self, block_height: uint32, timestamp: uint64) -> None:
"""
Remove all items that became invalid because of this new height and
timestamp. (we don't know about which coins were spent in this new block
here, so those are handled separately)
"""
with self._db_conn:
cursor = self._db_conn.execute(
"SELECT name FROM tx WHERE assert_before_seconds <= ? OR assert_before_height <= ?",
(timestamp, block_height),
)
to_remove = [bytes32(row[0]) for row in cursor]
self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED)
def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> None:
"""
Removes an item from the mempool.
"""
for spend_bundle_id in items:
item: Optional[MempoolItem] = self.spends.get(spend_bundle_id)
if item is None:
continue
assert item.name == spend_bundle_id
removals: List[Coin] = item.removals
for rem in removals:
rem_name: bytes32 = rem.name()
self.removal_coin_id_to_spendbundle_ids[rem_name].remove(spend_bundle_id)
if len(self.removal_coin_id_to_spendbundle_ids[rem_name]) == 0:
del self.removal_coin_id_to_spendbundle_ids[rem_name]
del self.spends[item.name]
del self.sorted_spends[item.fee_per_cost][item.name]
dic = self.sorted_spends[item.fee_per_cost]
if len(dic.values()) == 0:
del self.sorted_spends[item.fee_per_cost]
self.total_mempool_cost = CLVMCost(uint64(self.total_mempool_cost - item.cost))
self.total_mempool_fees = Mojos(uint64(self.total_mempool_fees - item.fee))
assert self.total_mempool_cost >= 0
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost, self.total_mempool_fees, datetime.now())
if reason != MempoolRemoveReason.BLOCK_INCLUSION:
self.fee_estimator.remove_mempool_item(info, item)
if items == []:
return
removed_items: List[MempoolItemInfo] = []
if reason != MempoolRemoveReason.BLOCK_INCLUSION:
for spend_bundle_ids in chunks(items, SQLITE_MAX_VARIABLE_NUMBER):
args = ",".join(["?"] * len(spend_bundle_ids))
with self._db_conn:
cursor = self._db_conn.execute(
f"SELECT name, cost, fee FROM tx WHERE name in ({args})", spend_bundle_ids
)
for row in cursor:
name = bytes32(row[0])
internal_item = self._items[name]
item = MempoolItemInfo(int(row[1]), int(row[2]), internal_item.height_added_to_mempool)
removed_items.append(item)
for name in items:
self._items.pop(name)
for spend_bundle_ids in chunks(items, SQLITE_MAX_VARIABLE_NUMBER):
args = ",".join(["?"] * len(spend_bundle_ids))
with self._db_conn:
self._db_conn.execute(f"DELETE FROM tx WHERE name in ({args})", spend_bundle_ids)
self._db_conn.execute(f"DELETE FROM spends WHERE tx in ({args})", spend_bundle_ids)
if reason != MempoolRemoveReason.BLOCK_INCLUSION:
info = FeeMempoolInfo(
self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now()
)
for iteminfo in removed_items:
self.fee_estimator.remove_mempool_item(info, iteminfo)
def add_to_pool(self, item: MempoolItem) -> None:
"""
Adds an item to the mempool by kicking out transactions (if it doesn't fit), in order of increasing fee per cost
"""
while self.at_full_capacity(item.cost):
# Val is Dict[hash, MempoolItem]
fee_per_cost, val = self.sorted_spends.peekitem(index=0)
to_remove: MempoolItem = list(val.values())[0]
self.remove_from_pool([to_remove.name], MempoolRemoveReason.POOL_FULL)
assert item.fee < MEMPOOL_ITEM_FEE_LIMIT
assert item.npc_result.conds is not None
assert item.cost <= self.mempool_info.max_block_clvm_cost
self.spends[item.name] = item
with self._db_conn:
total_cost = int(self.total_mempool_cost())
if total_cost + item.cost > self.mempool_info.max_size_in_cost:
# pick the items with the lowest fee per cost to remove
cursor = self._db_conn.execute(
"""SELECT name FROM tx
WHERE name NOT IN (
SELECT name FROM (
SELECT name,
SUM(cost) OVER (ORDER BY fee_per_cost DESC, seq ASC) AS total_cost
FROM tx) AS tx_with_cost
WHERE total_cost <= ?)
""",
(self.mempool_info.max_size_in_cost - item.cost,),
)
to_remove: List[bytes32] = [bytes32(row[0]) for row in cursor]
self.remove_from_pool(to_remove, MempoolRemoveReason.POOL_FULL)
# sorted_spends is Dict[float, Dict[bytes32, MempoolItem]]
if item.fee_per_cost not in self.sorted_spends:
self.sorted_spends[item.fee_per_cost] = {}
if SQLITE_NO_GENERATED_COLUMNS:
self._db_conn.execute(
"INSERT INTO "
"tx(name,cost,fee,assert_height,assert_before_height,assert_before_seconds,fee_per_cost) "
"VALUES(?, ?, ?, ?, ?, ?, ?)",
(
item.name,
item.cost,
item.fee,
item.assert_height,
item.assert_before_height,
item.assert_before_seconds,
item.fee / item.cost,
),
)
else:
self._db_conn.execute(
"INSERT INTO "
"tx(name,cost,fee,assert_height,assert_before_height,assert_before_seconds) "
"VALUES(?, ?, ?, ?, ?, ?)",
(
item.name,
item.cost,
item.fee,
item.assert_height,
item.assert_before_height,
item.assert_before_seconds,
),
)
self.sorted_spends[item.fee_per_cost][item.name] = item
all_coin_spends = [(s.coin_id, item.name) for s in item.npc_result.conds.spends]
self._db_conn.executemany("INSERT INTO spends VALUES(?, ?)", all_coin_spends)
for coin in item.removals:
coin_id = coin.name()
if coin_id not in self.removal_coin_id_to_spendbundle_ids:
self.removal_coin_id_to_spendbundle_ids[coin_id] = []
self.removal_coin_id_to_spendbundle_ids[coin_id].append(item.name)
self._items[item.name] = InternalMempoolItem(
item.spend_bundle, item.npc_result, item.height_added_to_mempool
)
self.total_mempool_cost = CLVMCost(uint64(self.total_mempool_cost + item.cost))
self.total_mempool_fees = Mojos(uint64(self.total_mempool_fees + item.fee))
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost, self.total_mempool_fees, datetime.now())
self.fee_estimator.add_mempool_item(info, item)
info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now())
self.fee_estimator.add_mempool_item(info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
def at_full_capacity(self, cost: int) -> bool:
"""
Checks whether the mempool is at full capacity and cannot accept a transaction with size cost.
"""
return self.total_mempool_cost + cost > self.mempool_info.max_size_in_cost
return self.total_mempool_cost() + cost > self.mempool_info.max_size_in_cost
def create_bundle_from_mempool_items(
self, item_inclusion_filter: Callable[[bytes32], bool]
) -> Optional[Tuple[SpendBundle, List[Coin], List[Coin]]]:
cost_sum = 0 # Checks that total cost does not exceed block maximum
fee_sum = 0 # Checks that total fees don't exceed 64 bits
spend_bundles: List[SpendBundle] = []
removals: List[Coin] = []
additions: List[Coin] = []
log.info(f"Starting to make block, max cost: {self.mempool_info.max_block_clvm_cost}")
for item in self.spends_by_feerate():
if not item_inclusion_filter(item.name):
continue
log.info("Cumulative cost: %d, fee per cost: %0.4f", cost_sum, item.fee_per_cost)
if (
item.cost + cost_sum > self.mempool_info.max_block_clvm_cost
or item.fee + fee_sum > DEFAULT_CONSTANTS.MAX_COIN_AMOUNT
):
break
spend_bundles.append(item.spend_bundle)
cost_sum += item.cost
fee_sum += item.fee
removals.extend(item.removals)
if item.npc_result.conds is not None:
for spend in item.npc_result.conds.spends:
for puzzle_hash, amount, _ in spend.create_coin:
coin = Coin(spend.coin_id, puzzle_hash, amount)
additions.append(coin)
if len(spend_bundles) == 0:
return None
log.info(
f"Cumulative cost of block (real cost should be less) {cost_sum}. Proportion "
f"full: {cost_sum / self.mempool_info.max_block_clvm_cost}"
)
agg = SpendBundle.aggregate(spend_bundles)
return agg, additions, removals

View File

@ -3,14 +3,14 @@ from __future__ import annotations
import logging
from typing import Dict, List, Optional, Tuple
from chia_rs import LIMIT_STACK, MEMPOOL_MODE, NO_NEG_DIV
from chia_rs import ENABLE_ASSERT_BEFORE, LIMIT_STACK, MEMPOOL_MODE, NO_RELATIVE_CONDITIONS_ON_EPHEMERAL
from chia_rs import get_puzzle_and_solution_for_coin as get_puzzle_and_solution_for_coin_rust
from chia_rs import run_chia_program
from chia_rs import run_block_generator, run_chia_program
from clvm.casts import int_from_bytes
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.full_node.generator import setup_generator_args
from chia.types.blockchain_format.coin import Coin
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.serialized_program import SerializedProgram
@ -34,36 +34,32 @@ log = logging.getLogger(__name__)
def get_name_puzzle_conditions(
generator: BlockGenerator, max_cost: int, *, cost_per_byte: int, mempool_mode: bool, height: Optional[uint32] = None
generator: BlockGenerator,
max_cost: int,
*,
mempool_mode: bool,
height: uint32,
constants: ConsensusConstants = DEFAULT_CONSTANTS,
) -> NPCResult:
block_program, block_program_args = setup_generator_args(generator)
size_cost = len(bytes(generator.program)) * cost_per_byte
max_cost -= size_cost
if max_cost < 0:
return NPCResult(uint16(Err.INVALID_BLOCK_COST.value), None, uint64(0))
# in mempool mode, the height doesn't matter, because it's always strict.
# But otherwise, height must be specified to know which rules to apply
assert mempool_mode or height is not None
if mempool_mode:
flags = MEMPOOL_MODE
elif height is not None and height >= DEFAULT_CONSTANTS.SOFT_FORK_HEIGHT:
flags = NO_NEG_DIV | LIMIT_STACK
elif height >= constants.SOFT_FORK_HEIGHT:
flags = LIMIT_STACK
else:
# conditions must use integers in canonical encoding (i.e. no redundant
# leading zeros)
# the division operator may not be used with negative operands
flags = NO_NEG_DIV
flags = 0
if height >= constants.SOFT_FORK2_HEIGHT:
flags = flags | ENABLE_ASSERT_BEFORE | NO_RELATIVE_CONDITIONS_ON_EPHEMERAL
try:
err, result = GENERATOR_MOD.run_as_generator(max_cost, flags, block_program, block_program_args)
block_args = [bytes(gen) for gen in generator.generator_refs]
err, result = run_block_generator(bytes(generator.program), block_args, max_cost, flags)
assert (err is None) != (result is None)
if err is not None:
return NPCResult(uint16(err), None, uint64(0))
else:
assert result is not None
return NPCResult(None, result, uint64(result.cost + size_cost))
return NPCResult(None, result, uint64(result.cost))
except BaseException:
log.exception("get_name_puzzle_condition failed")
return NPCResult(uint16(Err.GENERATOR_RUNTIME_ERROR.value), None, uint64(0))
@ -132,12 +128,32 @@ def mempool_check_time_locks(
return Err.ASSERT_HEIGHT_ABSOLUTE_FAILED
if timestamp < bundle_conds.seconds_absolute:
return Err.ASSERT_SECONDS_ABSOLUTE_FAILED
if bundle_conds.before_height_absolute is not None:
if prev_transaction_block_height >= bundle_conds.before_height_absolute:
return Err.ASSERT_BEFORE_HEIGHT_ABSOLUTE_FAILED
if bundle_conds.before_seconds_absolute is not None:
if timestamp >= bundle_conds.before_seconds_absolute:
return Err.ASSERT_BEFORE_SECONDS_ABSOLUTE_FAILED
for spend in bundle_conds.spends:
unspent = removal_coin_records[bytes32(spend.coin_id)]
if spend.birth_height is not None:
if spend.birth_height != unspent.confirmed_block_index:
return Err.ASSERT_MY_BIRTH_HEIGHT_FAILED
if spend.birth_seconds is not None:
if spend.birth_seconds != unspent.timestamp:
return Err.ASSERT_MY_BIRTH_SECONDS_FAILED
if spend.height_relative is not None:
if prev_transaction_block_height < unspent.confirmed_block_index + spend.height_relative:
return Err.ASSERT_HEIGHT_RELATIVE_FAILED
if timestamp < unspent.timestamp + spend.seconds_relative:
return Err.ASSERT_SECONDS_RELATIVE_FAILED
if spend.seconds_relative is not None:
if timestamp < unspent.timestamp + spend.seconds_relative:
return Err.ASSERT_SECONDS_RELATIVE_FAILED
if spend.before_height_relative is not None:
if prev_transaction_block_height >= unspent.confirmed_block_index + spend.before_height_relative:
return Err.ASSERT_BEFORE_HEIGHT_RELATIVE_FAILED
if spend.before_seconds_relative is not None:
if timestamp >= unspent.timestamp + spend.before_seconds_relative:
return Err.ASSERT_BEFORE_SECONDS_RELATIVE_FAILED
return None

View File

@ -5,20 +5,21 @@ import logging
import time
from concurrent.futures import Executor
from concurrent.futures.process import ProcessPoolExecutor
from dataclasses import dataclass
from multiprocessing.context import BaseContext
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple
from typing import Awaitable, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from blspy import GTElement
from chiabip158 import PyBIP158
from chia.consensus.block_record import BlockRecord
from chia.consensus.block_record import BlockRecordProtocol
from chia.consensus.constants import ConsensusConstants
from chia.consensus.cost_calculator import NPCResult
from chia.full_node.bitcoin_fee_estimator import create_bitcoin_fee_estimator
from chia.full_node.bundle_tools import simple_solution_generator
from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo
from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo, MempoolItemInfo
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.mempool import Mempool, MempoolRemoveReason
from chia.full_node.mempool import MEMPOOL_ITEM_FEE_LIMIT, Mempool, MempoolRemoveReason
from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, mempool_check_time_locks
from chia.full_node.pending_tx_cache import ConflictTxCache, PendingTxCache
from chia.types.blockchain_format.coin import Coin
@ -33,6 +34,7 @@ from chia.types.spend_bundle_conditions import SpendBundleConditions
from chia.util import cached_bls
from chia.util.cached_bls import LOCAL_CACHE
from chia.util.condition_tools import pkm_pairs
from chia.util.db_wrapper import SQLITE_INT_MAX
from chia.util.errors import Err, ValidationError
from chia.util.generator_tools import additions_for_npc
from chia.util.inline_executor import InlineExecutor
@ -42,21 +44,30 @@ from chia.util.setproctitle import getproctitle, setproctitle
log = logging.getLogger(__name__)
# mempool items replacing existing ones must increase the total fee at least by
# this amount. 0.00001 XCH
MEMPOOL_MIN_FEE_INCREASE = uint64(10000000)
# TODO: once the 1.8.0 soft-fork has activated, we don't really need to pass
# the constants through here
def validate_clvm_and_signature(
spend_bundle_bytes: bytes, max_cost: int, cost_per_byte: int, additional_data: bytes
spend_bundle_bytes: bytes, max_cost: int, constants: ConsensusConstants, height: uint32
) -> Tuple[Optional[Err], bytes, Dict[bytes32, bytes]]:
"""
Validates CLVM and aggregate signature for a spendbundle. This is meant to be called under a ProcessPoolExecutor
in order to validate the heavy parts of a transaction in a different thread. Returns an optional error,
the NPCResult and a cache of the new pairings validated (if not error)
"""
additional_data = constants.AGG_SIG_ME_ADDITIONAL_DATA
try:
bundle: SpendBundle = SpendBundle.from_bytes(spend_bundle_bytes)
program = simple_solution_generator(bundle)
# npc contains names of the coins removed, puzzle_hashes and their spend conditions
result: NPCResult = get_name_puzzle_conditions(
program, max_cost, cost_per_byte=cost_per_byte, mempool_mode=True
program, max_cost, mempool_mode=True, constants=constants, height=height
)
if result.error is not None:
@ -65,7 +76,7 @@ def validate_clvm_and_signature(
pks: List[bytes48] = []
msgs: List[bytes] = []
assert result.conds is not None
pks, msgs = pkm_pairs(result.conds, additional_data)
pks, msgs = pkm_pairs(result.conds, additional_data, soft_fork=True)
# Verify aggregated signature
cache: LRUCache[bytes32, GTElement] = LRUCache(10000)
@ -82,25 +93,54 @@ def validate_clvm_and_signature(
return None, bytes(result), new_cache_entries
@dataclass
class TimelockConditions:
assert_height: uint32 = uint32(0)
assert_before_height: Optional[uint32] = None
assert_before_seconds: Optional[uint64] = None
def compute_assert_height(
removal_coin_records: Dict[bytes32, CoinRecord],
conds: SpendBundleConditions,
) -> uint32:
) -> TimelockConditions:
"""
Computes the most restrictive height assertion in the spend bundle. Relative
height assertions are resolved using the confirmed heights from the coin
records.
Computes the most restrictive height- and seconds assertion in the spend bundle.
Relative heights and times are resolved using the confirmed heights and
timestamps from the coin records.
"""
height: uint32 = uint32(conds.height_absolute)
ret = TimelockConditions()
ret.assert_height = uint32(conds.height_absolute)
ret.assert_before_height = (
uint32(conds.before_height_absolute) if conds.before_height_absolute is not None else None
)
ret.assert_before_seconds = (
uint64(conds.before_seconds_absolute) if conds.before_seconds_absolute is not None else None
)
for spend in conds.spends:
if spend.height_relative is None:
continue
h = uint32(removal_coin_records[bytes32(spend.coin_id)].confirmed_block_index + spend.height_relative)
height = max(height, h)
if spend.height_relative is not None:
h = uint32(removal_coin_records[bytes32(spend.coin_id)].confirmed_block_index + spend.height_relative)
ret.assert_height = max(ret.assert_height, h)
return height
if spend.before_height_relative is not None:
h = uint32(
removal_coin_records[bytes32(spend.coin_id)].confirmed_block_index + spend.before_height_relative
)
if ret.assert_before_height is not None:
ret.assert_before_height = min(ret.assert_before_height, h)
else:
ret.assert_before_height = h
if spend.before_seconds_relative is not None:
s = uint64(removal_coin_records[bytes32(spend.coin_id)].timestamp + spend.before_seconds_relative)
if ret.assert_before_seconds is not None:
ret.assert_before_seconds = min(ret.assert_before_seconds, s)
else:
ret.assert_before_seconds = s
return ret
class MempoolManager:
@ -115,7 +155,7 @@ class MempoolManager:
# cache of MempoolItems with height conditions making them not valid yet
_pending_cache: PendingTxCache
seen_cache_size: int
peak: Optional[BlockRecord]
peak: Optional[BlockRecordProtocol]
mempool: Mempool
def __init__(
@ -157,7 +197,7 @@ class MempoolManager:
)
# The mempool will correspond to a certain peak
self.peak: Optional[BlockRecord] = None
self.peak: Optional[BlockRecordProtocol] = None
self.fee_estimator: FeeEstimatorInterface = create_bitcoin_fee_estimator(self.max_block_clvm_cost)
mempool_info = MempoolInfo(
CLVMCost(uint64(self.mempool_max_total_cost)),
@ -169,35 +209,10 @@ class MempoolManager:
def shut_down(self) -> None:
self.pool.shutdown(wait=True)
def process_mempool_items(
self, item_inclusion_filter: Callable[[MempoolManager, MempoolItem], bool]
) -> Tuple[List[SpendBundle], uint64, List[Coin], List[Coin]]:
cost_sum = 0 # Checks that total cost does not exceed block maximum
fee_sum = 0 # Checks that total fees don't exceed 64 bits
spend_bundles: List[SpendBundle] = []
removals: List[Coin] = []
additions: List[Coin] = []
for dic in reversed(self.mempool.sorted_spends.values()):
for item in dic.values():
if not item_inclusion_filter(self, item):
continue
log.info(f"Cumulative cost: {cost_sum}, fee per cost: {item.fee / item.cost}")
if (
item.cost + cost_sum > self.max_block_clvm_cost
or item.fee + fee_sum > self.constants.MAX_COIN_AMOUNT
):
return (spend_bundles, uint64(cost_sum), additions, removals)
spend_bundles.append(item.spend_bundle)
cost_sum += item.cost
fee_sum += item.fee
removals.extend(item.removals)
additions.extend(item.additions)
return (spend_bundles, uint64(cost_sum), additions, removals)
def create_bundle_from_mempool(
self,
last_tb_header_hash: bytes32,
item_inclusion_filter: Optional[Callable[[MempoolManager, MempoolItem], bool]] = None,
item_inclusion_filter: Optional[Callable[[bytes32], bool]] = None,
) -> Optional[Tuple[SpendBundle, List[Coin], List[Coin]]]:
"""
Returns aggregated spendbundle that can be used for creating new block,
@ -205,29 +220,18 @@ class MempoolManager:
"""
if self.peak is None or self.peak.header_hash != last_tb_header_hash:
return None
if item_inclusion_filter is None:
def always(mm: MempoolManager, mi: MempoolItem) -> bool:
def always(bundle_name: bytes32) -> bool:
return True
item_inclusion_filter = always
log.info(f"Starting to make block, max cost: {self.max_block_clvm_cost}")
spend_bundles, cost_sum, additions, removals = self.process_mempool_items(item_inclusion_filter)
if len(spend_bundles) == 0:
return None
log.info(
f"Cumulative cost of block (real cost should be less) {cost_sum}. Proportion "
f"full: {cost_sum / self.max_block_clvm_cost}"
)
agg = SpendBundle.aggregate(spend_bundles)
return agg, additions, removals
return self.mempool.create_bundle_from_mempool_items(item_inclusion_filter)
def get_filter(self) -> bytes:
all_transactions: Set[bytes32] = set()
byte_array_list = []
for key, _ in self.mempool.spends.items():
for key in self.mempool.all_spend_ids():
if key not in all_transactions:
all_transactions.add(key)
byte_array_list.append(bytearray(key))
@ -263,52 +267,6 @@ class MempoolManager:
if bundle_hash in self.seen_bundle_hashes:
self.seen_bundle_hashes.pop(bundle_hash)
@staticmethod
def get_min_fee_increase() -> int:
# 0.00001 XCH
return 10000000
def can_replace(
self,
conflicting_items: Dict[bytes32, MempoolItem],
removals: Dict[bytes32, CoinRecord],
fees: uint64,
fees_per_cost: float,
) -> bool:
conflicting_fees = 0
conflicting_cost = 0
for item in conflicting_items.values():
conflicting_fees += item.fee
conflicting_cost += item.cost
# All coins spent in all conflicting items must also be spent in the new item. (superset rule). This is
# important because otherwise there exists an attack. A user spends coin A. An attacker replaces the
# bundle with AB with a higher fee. An attacker then replaces the bundle with just B with a higher
# fee than AB therefore kicking out A altogether. The better way to solve this would be to keep a cache
# of booted transactions like A, and retry them after they get removed from mempool due to a conflict.
for coin in item.removals:
if coin.name() not in removals:
log.debug(f"Rejecting conflicting tx as it does not spend conflicting coin {coin.name()}")
return False
# New item must have higher fee per cost
conflicting_fees_per_cost = conflicting_fees / conflicting_cost
if fees_per_cost <= conflicting_fees_per_cost:
log.debug(
f"Rejecting conflicting tx due to not increasing fees per cost "
f"({fees_per_cost} <= {conflicting_fees_per_cost})"
)
return False
# New item must increase the total fee at least by a certain amount
fee_increase = fees - conflicting_fees
if fee_increase < self.get_min_fee_increase():
log.debug(f"Rejecting conflicting tx due to low fee increase ({fee_increase})")
return False
log.info(f"Replacing conflicting tx in mempool. New tx fee: {fees}, old tx fees: {conflicting_fees}")
return True
async def pre_validate_spendbundle(
self, new_spend: SpendBundle, new_spend_bytes: Optional[bytes], spend_name: bytes32
) -> NPCResult:
@ -323,13 +281,15 @@ class MempoolManager:
if new_spend.coin_spends == []:
raise ValidationError(Err.INVALID_SPEND_BUNDLE, "Empty SpendBundle")
assert self.peak is not None
err, cached_result_bytes, new_cache_entries = await asyncio.get_running_loop().run_in_executor(
self.pool,
validate_clvm_and_signature,
new_spend_bytes,
self.max_block_clvm_cost,
self.constants.COST_PER_BYTE,
self.constants.AGG_SIG_ME_ADDITIONAL_DATA,
self.constants,
self.peak.height,
)
if err is not None:
@ -366,10 +326,9 @@ class MempoolManager:
"""
# Skip if already added
if spend_name in self.mempool.spends:
cost: Optional[uint64] = self.mempool.spends[spend_name].cost
assert cost is not None
return uint64(cost), MempoolInclusionStatus.SUCCESS, None
existing_item = self.mempool.get_spend_by_id(spend_name)
if existing_item is not None:
return existing_item.cost, MempoolInclusionStatus.SUCCESS, None
err, item, remove_items = await self.validate_spend_bundle(
new_spend, npc_result, spend_name, first_added_height
@ -430,9 +389,9 @@ class MempoolManager:
log.debug(f"Cost: {cost}")
assert npc_result.conds is not None
# build removal list
removal_names: List[bytes32] = [bytes32(spend.coin_id) for spend in npc_result.conds.spends]
if set(removal_names) != set([s.name() for s in new_spend.removals()]):
# build set of removals
removal_names: Set[bytes32] = set(bytes32(spend.coin_id) for spend in npc_result.conds.spends)
if removal_names != set(s.name() for s in new_spend.removals()):
# If you reach here it's probably because your program reveal doesn't match the coin's puzzle hash
return Err.INVALID_SPEND_BUNDLE, None, []
@ -470,18 +429,21 @@ class MempoolManager:
removal_amount = removal_amount + removal_record.coin.amount
removal_record_dict[name] = removal_record
if addition_amount > removal_amount:
return Err.MINTING_COIN, None, []
fees = uint64(removal_amount - addition_amount)
assert_fee_sum: uint64 = uint64(npc_result.conds.reserve_fee)
if fees < assert_fee_sum:
return Err.RESERVE_FEE_CONDITION_FAILED, None, []
if cost == 0:
return Err.UNKNOWN, None, []
if cost > self.max_block_clvm_cost:
return Err.BLOCK_COST_EXCEEDS_MAX, None, []
# this is not very likely to happen, but it's here to ensure SQLite
# never runs out of precision in its computation of fees.
# sqlite's integers are signed int64, so the max value they can
# represent is 2^63-1
if fees > MEMPOOL_ITEM_FEE_LIMIT or SQLITE_INT_MAX - self.mempool.total_mempool_fees() <= fees:
return Err.INVALID_BLOCK_FEE_AMOUNT, None, []
fees_per_cost: float = fees / cost
# If pool is at capacity check the fee, if not then accept even without the fee
if self.mempool.at_full_capacity(cost):
@ -492,8 +454,6 @@ class MempoolManager:
# Check removals against UnspentDB + DiffStore + Mempool + SpendBundle
# Use this information later when constructing a block
fail_reason, conflicts = self.check_removals(removal_record_dict)
# If there is a mempool conflict check if this SpendBundle has a higher fee per cost than all others
conflicting_pool_items: Dict[bytes32, MempoolItem] = {}
# If we have a mempool conflict, continue, since we still want to keep around the TX in the pending pool.
if fail_reason is not None and fail_reason is not Err.MEMPOOL_CONFLICT:
@ -508,24 +468,30 @@ class MempoolManager:
log.warning(f"{spend.puzzle_hash.hex()} != {coin_record.coin.puzzle_hash.hex()}")
return Err.WRONG_PUZZLE_HASH, None, []
chialisp_height = (
self.peak.prev_transaction_block_height if not self.peak.is_transaction_block else self.peak.height
)
# the height and time we pass in here represent the previous transaction
# block's height and timestamp. In the mempool, the most recent peak
# block we've received will be the previous transaction block, from the
# point-of-view of the next block to be farmed. Therefore we pass in the
# current peak's height and timestamp
assert self.peak.timestamp is not None
tl_error: Optional[Err] = mempool_check_time_locks(
removal_record_dict,
npc_result.conds,
uint32(chialisp_height),
self.peak.height,
self.peak.timestamp,
)
assert_height: Optional[uint32] = None
if tl_error:
assert_height = compute_assert_height(removal_record_dict, npc_result.conds)
timelocks: TimelockConditions = compute_assert_height(removal_record_dict, npc_result.conds)
potential = MempoolItem(
new_spend, uint64(fees), npc_result, cost, spend_name, additions, first_added_height, assert_height
new_spend,
uint64(fees),
npc_result,
spend_name,
first_added_height,
timelocks.assert_height,
timelocks.assert_before_height,
timelocks.assert_before_seconds,
)
if tl_error:
@ -535,12 +501,8 @@ class MempoolManager:
return tl_error, None, [] # MempoolInclusionStatus.FAILED
if fail_reason is Err.MEMPOOL_CONFLICT:
for conflicting in conflicts:
for c_sb_id in self.mempool.removal_coin_id_to_spendbundle_ids[conflicting.name()]:
sb: MempoolItem = self.mempool.spends[c_sb_id]
conflicting_pool_items[sb.name] = sb
log.debug(f"Replace attempted. number of MempoolItems: {len(conflicting_pool_items)}")
if not self.can_replace(conflicting_pool_items, removal_record_dict, fees, fees_per_cost):
log.debug(f"Replace attempted. number of MempoolItems: {len(conflicts)}")
if not can_replace(conflicts, removal_names, potential):
return Err.MEMPOOL_CONFLICT, potential, []
duration = time.time() - start_time
@ -551,36 +513,37 @@ class MempoolManager:
f"Cost: {cost} ({round(100.0 * cost/self.constants.MAX_BLOCK_COST_CLVM, 3)}% of max block cost)",
)
return None, potential, list(conflicting_pool_items.keys())
return None, potential, [item.name for item in conflicts]
def check_removals(self, removals: Dict[bytes32, CoinRecord]) -> Tuple[Optional[Err], List[Coin]]:
def check_removals(self, removals: Dict[bytes32, CoinRecord]) -> Tuple[Optional[Err], Set[MempoolItem]]:
"""
This function checks for double spends, unknown spends and conflicting transactions in mempool.
Returns Error (if any), dictionary of Unspents, list of coins with conflict errors (if any any).
Returns Error (if any), the set of existing MempoolItems with conflicting spends (if any).
Note that additions are not checked for duplicates, because having duplicate additions requires also
having duplicate removals.
"""
assert self.peak is not None
conflicts: List[Coin] = []
conflicts: Set[MempoolItem] = set()
for record in removals.values():
removal = record.coin
# 1. Checks if it's been spent already
if record.spent:
return Err.DOUBLE_SPEND, []
return Err.DOUBLE_SPEND, set()
# 2. Checks if there's a mempool conflict
if removal.name() in self.mempool.removal_coin_id_to_spendbundle_ids:
conflicts.append(removal)
items: List[MempoolItem] = self.mempool.get_spends_by_coin_id(removal.name())
conflicts.update(items)
if len(conflicts) > 0:
return Err.MEMPOOL_CONFLICT, conflicts
# 5. If coins can be spent return list of unspents as we see them in local storage
return None, []
return None, set()
def get_spendbundle(self, bundle_hash: bytes32) -> Optional[SpendBundle]:
"""Returns a full SpendBundle if it's inside one the mempools"""
if bundle_hash in self.mempool.spends:
return self.mempool.spends[bundle_hash].spend_bundle
item: Optional[MempoolItem] = self.mempool.get_spend_by_id(bundle_hash)
if item is not None:
return item.spend_bundle
return None
def get_mempool_item(self, bundle_hash: bytes32, include_pending: bool = False) -> Optional[MempoolItem]:
@ -590,7 +553,7 @@ class MempoolManager:
If include_pending is specified, also check the PENDING cache.
"""
item = self.mempool.spends.get(bundle_hash, None)
item = self.mempool.get_spend_by_id(bundle_hash)
if not item and include_pending:
# no async lock needed since we're not mutating the pending_cache
item = self._pending_cache.get(bundle_hash)
@ -600,20 +563,23 @@ class MempoolManager:
return item
async def new_peak(
self, new_peak: Optional[BlockRecord], last_npc_result: Optional[NPCResult]
self, new_peak: Optional[BlockRecordProtocol], last_npc_result: Optional[NPCResult]
) -> List[Tuple[SpendBundle, NPCResult, bytes32]]:
"""
Called when a new peak is available, we try to recreate a mempool for the new tip.
"""
if new_peak is None:
return []
# we're only interested in transaction blocks
if new_peak.is_transaction_block is False:
return []
if self.peak == new_peak:
return []
assert new_peak.timestamp is not None
self.fee_estimator.new_block_height(new_peak.height)
included_items = []
included_items: List[MempoolItemInfo] = []
self.mempool.new_tx_block(new_peak.height, new_peak.timestamp)
use_optimization: bool = self.peak is not None and new_peak.prev_transaction_block_hash == self.peak.header_hash
self.peak = new_peak
@ -621,22 +587,23 @@ class MempoolManager:
if use_optimization and last_npc_result is not None:
# We don't reinitialize a mempool, just kick removed items
if last_npc_result.conds is not None:
# transactions in the mempool may be spending multiple coins,
# when looking up transactions by all coin IDs, we're likely to
# find the same transaction multiple times. We put them in a set
# to deduplicate
spendbundle_ids_to_remove: Set[bytes32] = set()
for spend in last_npc_result.conds.spends:
if spend.coin_id in self.mempool.removal_coin_id_to_spendbundle_ids:
spendbundle_ids: List[bytes32] = self.mempool.removal_coin_id_to_spendbundle_ids[
bytes32(spend.coin_id)
]
for spendbundle_id in spendbundle_ids:
item = self.mempool.spends.get(spendbundle_id)
if item:
included_items.append(item)
self.remove_seen(spendbundle_id)
self.mempool.remove_from_pool(spendbundle_ids, MempoolRemoveReason.BLOCK_INCLUSION)
items: List[MempoolItem] = self.mempool.get_spends_by_coin_id(bytes32(spend.coin_id))
for item in items:
included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
self.remove_seen(item.name)
spendbundle_ids_to_remove.add(item.name)
self.mempool.remove_from_pool(list(spendbundle_ids_to_remove), MempoolRemoveReason.BLOCK_INCLUSION)
else:
old_pool = self.mempool
self.mempool = Mempool(old_pool.mempool_info, old_pool.fee_estimator)
self.seen_bundle_hashes = {}
for item in old_pool.spends.values():
for item in old_pool.all_spends():
_, result, err = await self.add_spend_bundle(
item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool
)
@ -648,7 +615,7 @@ class MempoolManager:
if result == MempoolInclusionStatus.FAILED and err == Err.DOUBLE_SPEND:
# Item was in mempool, but after the new block it's a double spend.
# Item is most likely included in the block.
included_items.append(item)
included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool))
potential_txs = self._pending_cache.drain(new_peak.height)
potential_txs.update(self._conflict_cache.drain())
@ -660,31 +627,115 @@ class MempoolManager:
if status == MempoolInclusionStatus.SUCCESS:
txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name))
log.info(
f"Size of mempool: {len(self.mempool.spends)} spends, "
f"cost: {self.mempool.total_mempool_cost} "
f"Size of mempool: {self.mempool.size()} spends, "
f"cost: {self.mempool.total_mempool_cost()} "
f"minimum fee rate (in FPC) to get in for 5M cost tx: {self.mempool.get_min_fee_rate(5000000)}"
)
self.mempool.fee_estimator.new_block(FeeBlockInfo(new_peak.height, included_items))
return txs_added
async def get_items_not_in_filter(self, mempool_filter: PyBIP158, limit: int = 100) -> List[MempoolItem]:
items: List[MempoolItem] = []
counter = 0
broke_from_inner_loop = False
def get_items_not_in_filter(self, mempool_filter: PyBIP158, limit: int = 100) -> List[SpendBundle]:
items: List[SpendBundle] = []
assert limit > 0
# Send 100 with the highest fee per cost
for dic in reversed(self.mempool.sorted_spends.values()):
if broke_from_inner_loop:
break
for item in dic.values():
if counter == limit:
broke_from_inner_loop = True
break
if mempool_filter.Match(bytearray(item.spend_bundle_name)):
continue
items.append(item)
counter += 1
for item in self.mempool.spends_by_feerate():
if len(items) >= limit:
return items
if mempool_filter.Match(bytearray(item.spend_bundle_name)):
continue
items.append(item.spend_bundle)
return items
T = TypeVar("T", uint32, uint64)
def optional_min(a: Optional[T], b: Optional[T]) -> Optional[T]:
return min((v for v in [a, b] if v is not None), default=None)
def optional_max(a: Optional[T], b: Optional[T]) -> Optional[T]:
return max((v for v in [a, b] if v is not None), default=None)
def can_replace(
conflicting_items: Set[MempoolItem],
removal_names: Set[bytes32],
new_item: MempoolItem,
) -> bool:
"""
This function implements the mempool replacement rules. Given a Mempool item
we're attempting to insert into the mempool (new_item) and the set of existing
mempool items that conflict with it, this function answers the question whether
the existing items can be replaced by the new one. The removals parameter are
the coin IDs the new mempool item is spending.
"""
conflicting_fees = 0
conflicting_cost = 0
assert_height: Optional[uint32] = None
assert_before_height: Optional[uint32] = None
assert_before_seconds: Optional[uint64] = None
for item in conflicting_items:
conflicting_fees += item.fee
conflicting_cost += item.cost
# All coins spent in all conflicting items must also be spent in the new item. (superset rule). This is
# important because otherwise there exists an attack. A user spends coin A. An attacker replaces the
# bundle with AB with a higher fee. An attacker then replaces the bundle with just B with a higher
# fee than AB therefore kicking out A altogether. The better way to solve this would be to keep a cache
# of booted transactions like A, and retry them after they get removed from mempool due to a conflict.
for coin in item.removals:
if coin.name() not in removal_names:
log.debug(f"Rejecting conflicting tx as it does not spend conflicting coin {coin.name()}")
return False
assert_height = optional_max(assert_height, item.assert_height)
assert_before_height = optional_min(assert_before_height, item.assert_before_height)
assert_before_seconds = optional_min(assert_before_seconds, item.assert_before_seconds)
# New item must have higher fee per cost
conflicting_fees_per_cost = conflicting_fees / conflicting_cost
if new_item.fee_per_cost <= conflicting_fees_per_cost:
log.debug(
f"Rejecting conflicting tx due to not increasing fees per cost "
f"({new_item.fee_per_cost} <= {conflicting_fees_per_cost})"
)
return False
# New item must increase the total fee at least by a certain amount
fee_increase = new_item.fee - conflicting_fees
if fee_increase < MEMPOOL_MIN_FEE_INCREASE:
log.debug(f"Rejecting conflicting tx due to low fee increase ({fee_increase})")
return False
# New item may not have a different effective height/time lock (time-lock rule)
if new_item.assert_height != assert_height:
log.debug(
"Rejecting conflicting tx due to changing ASSERT_HEIGHT constraints %s -> %s",
assert_height,
new_item.assert_height,
)
return False
if new_item.assert_before_height != assert_before_height:
log.debug(
"Rejecting conflicting tx due to changing ASSERT_BEFORE_HEIGHT constraints %s -> %s",
assert_before_height,
new_item.assert_before_height,
)
return False
if new_item.assert_before_seconds != assert_before_seconds:
log.debug(
"Rejecting conflicting tx due to changing ASSERT_BEFORE_SECONDS constraints %s -> %s",
assert_before_seconds,
new_item.assert_before_seconds,
)
return False
log.info(f"Replacing conflicting tx in mempool. New tx fee: {new_item.fee}, old tx fees: {conflicting_fees}")
return True

View File

@ -34,7 +34,6 @@ class ConflictTxCache:
self._cache_cost += item.cost
while self._cache_cost > self._cache_max_total_cost or len(self._txs) > self._cache_max_size:
first_in = list(self._txs.keys())[0]
self._cache_cost -= self._txs[first_in].cost
self._txs.pop(first_in)
@ -77,7 +76,6 @@ class PendingTxCache:
self._by_height.setdefault(item.assert_height, {})[name] = item
while self._cache_cost > self._cache_max_total_cost or len(self._txs) > self._cache_max_size:
# we start removing items with the highest assert_height first
to_evict = self._by_height.items()[-1]
if to_evict[1] == {}:

View File

@ -31,10 +31,20 @@ class PeerSubscriptions:
def has_coin_subscription(self, coin_id: bytes32) -> bool:
return coin_id in self._coin_subscriptions
def add_ph_subscriptions(self, peer_id: bytes32, phs: List[bytes32], max_items: int) -> None:
def add_ph_subscriptions(self, peer_id: bytes32, phs: List[bytes32], max_items: int) -> List[bytes32]:
"""
returns the puzzle hashes that were actually subscribed to. These may be
fewer than requested in case:
* there are duplicate puzzle_hashes
* some puzzle hashes are already subscribed to
* the max_items limit is exceeded
"""
puzzle_hash_peers = self._peer_puzzle_hash.setdefault(peer_id, set())
existing_sub_count = self._peer_sub_counter.setdefault(peer_id, 0)
ret: List[bytes32] = []
# if we've reached the limit on number of subscriptions, just bail
if existing_sub_count >= max_items:
log.info(
@ -42,7 +52,7 @@ class PeerSubscriptions:
"Not all its coin states will be reported",
peer_id,
)
return
return ret
# decrement this counter as we go, to know if we've hit the limit of
# number of subscriptions
@ -53,6 +63,7 @@ class PeerSubscriptions:
if peer_id in ph_sub:
continue
ret.append(ph)
ph_sub.add(peer_id)
puzzle_hash_peers.add(ph)
self._peer_sub_counter[peer_id] += 1
@ -65,6 +76,7 @@ class PeerSubscriptions:
peer_id,
)
break
return ret
def add_coin_subscriptions(self, peer_id: bytes32, coin_ids: List[bytes32], max_items: int) -> None:
coin_id_peers = self._peer_coin_ids.setdefault(peer_id, set())
@ -100,7 +112,6 @@ class PeerSubscriptions:
break
def remove_peer(self, peer_id: bytes32) -> None:
counter = 0
puzzle_hashes = self._peer_puzzle_hash.get(peer_id)
if puzzle_hashes is not None:

View File

@ -54,7 +54,6 @@ def _create_shutdown_file() -> IO:
class WeightProofHandler:
LAMBDA_L = 100
C = 0.5
MAX_SAMPLES = 20
@ -74,7 +73,6 @@ class WeightProofHandler:
self.multiprocessing_context = multiprocessing_context
async def get_proof_of_weight(self, tip: bytes32) -> Optional[WeightProof]:
tip_rec = self.blockchain.try_block_record(tip)
if tip_rec is None:
log.error("unknown tip")
@ -527,7 +525,7 @@ class WeightProofHandler:
assert curr.reward_chain_block.challenge_chain_sp_vdf
cc_sp_vdf_info = curr.reward_chain_block.challenge_chain_sp_vdf
if not curr.challenge_chain_sp_proof.normalized_to_identity:
(_, _, _, _, cc_vdf_iters, _,) = get_signage_point_vdf_info(
(_, _, _, _, cc_vdf_iters, _) = get_signage_point_vdf_info(
self.constants,
curr.finished_sub_slots,
block_record.overflow,
@ -732,7 +730,7 @@ async def _challenge_block_vdfs(
block_rec: BlockRecord,
sub_blocks: Dict[bytes32, BlockRecord],
):
(_, _, _, _, cc_vdf_iters, _,) = get_signage_point_vdf_info(
(_, _, _, _, cc_vdf_iters, _) = get_signage_point_vdf_info(
constants,
header_block.finished_sub_slots,
block_rec.overflow,
@ -859,7 +857,6 @@ def _validate_sub_epoch_summaries(
constants: ConsensusConstants,
weight_proof: WeightProof,
) -> Tuple[Optional[List[SubEpochSummary]], Optional[List[uint128]]]:
last_ses_hash, last_ses_sub_height = _get_last_ses_hash(constants, weight_proof.recent_chain_data)
if last_ses_hash is None:
log.warning("could not find last ses block")
@ -1074,7 +1071,6 @@ def _validate_sub_slot_data(
sub_slots: List[SubSlotData],
ssi: uint64,
) -> Tuple[bool, List[Tuple[VDFProof, ClassgroupElement, VDFInfo]]]:
sub_slot_data = sub_slots[sub_slot_idx]
assert sub_slot_idx > 0
prev_ssd = sub_slots[sub_slot_idx - 1]
@ -1225,7 +1221,7 @@ def validate_recent_blocks(
ses_blocks, sub_slots, transaction_blocks = 0, 0, 0
challenge, prev_challenge = recent_chain.recent_chain_data[0].reward_chain_block.pos_ss_cc_challenge_hash, None
tip_height = recent_chain.recent_chain_data[-1].height
prev_block_record = None
prev_block_record: Optional[BlockRecord] = None
deficit = uint8(0)
adjusted = False
for idx, block in enumerate(recent_chain.recent_chain_data):
@ -1249,10 +1245,10 @@ def validate_recent_blocks(
if (challenge is not None) and (prev_challenge is not None):
overflow = is_overflow_block(constants, block.reward_chain_block.signage_point_index)
if not adjusted:
assert prev_block_record is not None
prev_block_record = dataclasses.replace(
prev_block_record, deficit=deficit % constants.MIN_BLOCKS_PER_CHALLENGE_BLOCK
)
assert prev_block_record is not None
sub_blocks.add_block_record(prev_block_record)
adjusted = True
deficit = get_deficit(constants, deficit, prev_block_record, overflow, len(block.finished_sub_slots))
@ -1381,7 +1377,6 @@ def __get_rc_sub_slot(
summaries: List[SubEpochSummary],
curr_ssi: uint64,
) -> RewardChainSubSlot:
ses = summaries[uint32(segment.sub_epoch_n - 1)]
# find first challenge in sub epoch
first_idx = None
@ -1646,7 +1641,6 @@ def _validate_vdf_batch(
vdf_list: List[Tuple[bytes, bytes, bytes]],
shutdown_file_path: Optional[pathlib.Path] = None,
):
for vdf_proof_bytes, class_group_bytes, info in vdf_list:
vdf = VDFProof.from_bytes(vdf_proof_bytes)
class_group = ClassgroupElement.from_bytes(class_group_bytes)

View File

@ -8,6 +8,8 @@ from concurrent.futures.thread import ThreadPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Literal
from chia.consensus.constants import ConsensusConstants
from chia.plot_sync.sender import Sender
from chia.plotting.manager import PlotManager
@ -35,7 +37,6 @@ class Harvester:
_shut_down: bool
executor: ThreadPoolExecutor
state_changed_callback: Optional[StateChangedProtocol] = None
cached_challenges: List
constants: ConsensusConstants
_refresh_lock: asyncio.Lock
event_loop: asyncio.events.AbstractEventLoop
@ -50,7 +51,7 @@ class Harvester:
return self._server
def __init__(self, root_path: Path, config: Dict, constants: ConsensusConstants):
def __init__(self, root_path: Path, config: Dict[str, Any], constants: ConsensusConstants):
self.log = log
self.root_path = root_path
# TODO, remove checks below later after some versions / time
@ -76,38 +77,37 @@ class Harvester:
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=config["num_threads"])
self._server = None
self.constants = constants
self.cached_challenges = []
self.state_changed_callback: Optional[StateChangedProtocol] = None
self.parallel_read: bool = config.get("parallel_read", True)
async def _start(self):
async def _start(self) -> None:
self._refresh_lock = asyncio.Lock()
self.event_loop = asyncio.get_running_loop()
def _close(self):
def _close(self) -> None:
self._shut_down = True
self.executor.shutdown(wait=True)
self.plot_manager.stop_refreshing()
self.plot_manager.reset()
self.plot_sync_sender.stop()
async def _await_closed(self):
async def _await_closed(self) -> None:
await self.plot_sync_sender.await_closed()
def get_connections(self, request_node_type: Optional[NodeType]) -> List[Dict[str, Any]]:
return default_get_connections(server=self.server, request_node_type=request_node_type)
async def on_connect(self, connection: WSChiaConnection):
async def on_connect(self, connection: WSChiaConnection) -> None:
self.state_changed("add_connection")
def _set_state_changed_callback(self, callback: StateChangedProtocol) -> None:
self.state_changed_callback = callback
def state_changed(self, change: str, change_data: Dict[str, Any] = None):
def state_changed(self, change: str, change_data: Optional[Dict[str, Any]] = None) -> None:
if self.state_changed_callback is not None:
self.state_changed_callback(change, change_data)
def _plot_refresh_callback(self, event: PlotRefreshEvents, update_result: PlotRefreshResult):
def _plot_refresh_callback(self, event: PlotRefreshEvents, update_result: PlotRefreshResult) -> None:
log_function = self.log.debug if event == PlotRefreshEvents.batch_processed else self.log.info
log_function(
f"_plot_refresh_callback: event {event.name}, loaded {len(update_result.loaded)}, "
@ -123,16 +123,16 @@ class Harvester:
if event == PlotRefreshEvents.done:
self.plot_sync_sender.sync_done(update_result.removed, update_result.duration)
def on_disconnect(self, connection: WSChiaConnection):
def on_disconnect(self, connection: WSChiaConnection) -> None:
self.log.info(f"peer disconnected {connection.get_peer_logging()}")
self.state_changed("close_connection")
self.plot_sync_sender.stop()
asyncio.run_coroutine_threadsafe(self.plot_sync_sender.await_closed(), asyncio.get_running_loop())
self.plot_manager.stop_refreshing()
def get_plots(self) -> Tuple[List[Dict], List[str], List[str]]:
def get_plots(self) -> Tuple[List[Dict[str, Any]], List[str], List[str]]:
self.log.debug(f"get_plots prover items: {self.plot_manager.plot_count()}")
response_plots: List[Dict] = []
response_plots: List[Dict[str, Any]] = []
with self.plot_manager:
for path, plot_info in self.plot_manager.plots.items():
prover = plot_info.prover
@ -159,7 +159,7 @@ class Harvester:
[str(s) for s in self.plot_manager.no_key_filenames],
)
def delete_plot(self, str_path: str):
def delete_plot(self, str_path: str) -> Literal[True]:
remove_plot(Path(str_path))
self.plot_manager.trigger_refresh()
self.state_changed("plots")

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import asyncio
import time
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
from blspy import AugSchemeMPL, G1Element, G2Element
@ -14,7 +14,7 @@ from chia.protocols import harvester_protocol
from chia.protocols.farmer_protocol import FarmingInfo
from chia.protocols.harvester_protocol import Plot, PlotSyncResponse
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.server.outbound_message import make_msg
from chia.server.outbound_message import Message, make_msg
from chia.server.ws_connection import WSChiaConnection
from chia.types.blockchain_format.proof_of_space import (
ProofOfSpace,
@ -37,7 +37,7 @@ class HarvesterAPI:
@api_request(peer_required=True)
async def harvester_handshake(
self, harvester_handshake: harvester_protocol.HarvesterHandshake, peer: WSChiaConnection
):
) -> None:
"""
Handshake between the harvester and farmer. The harvester receives the pool public keys,
as well as the farmer pks, which must be put into the plots, before the plotting process begins.
@ -53,7 +53,7 @@ class HarvesterAPI:
@api_request(peer_required=True)
async def new_signage_point_harvester(
self, new_challenge: harvester_protocol.NewSignagePointHarvester, peer: WSChiaConnection
):
) -> None:
"""
The harvester receives a new signage point from the farmer, this happens at the start of each slot.
The harvester does a few things:
@ -246,7 +246,7 @@ class HarvesterAPI:
)
@api_request()
async def request_signatures(self, request: harvester_protocol.RequestSignatures):
async def request_signatures(self, request: harvester_protocol.RequestSignatures) -> Optional[Message]:
"""
The farmer requests a signature on the header hash, for one of the proofs that we found.
A signature is created on the header hash using the harvester private key. This can also
@ -295,7 +295,7 @@ class HarvesterAPI:
return make_msg(ProtocolMessageTypes.respond_signatures, response)
@api_request()
async def request_plots(self, _: harvester_protocol.RequestPlots):
async def request_plots(self, _: harvester_protocol.RequestPlots) -> Message:
plots_response = []
plots, failed_to_open_filenames, no_key_filenames = self.harvester.get_plots()
for plot in plots:
@ -316,5 +316,5 @@ class HarvesterAPI:
return make_msg(ProtocolMessageTypes.respond_plots, response)
@api_request()
async def plot_sync_response(self, response: PlotSyncResponse):
async def plot_sync_response(self, response: PlotSyncResponse) -> None:
self.harvester.plot_sync_sender.set_response(response)

View File

@ -118,8 +118,7 @@ class Sender:
await self.await_closed()
if self._task is None:
self._task = asyncio.create_task(self._run())
# TODO, Add typing in PlotManager
if not self._plot_manager.initial_refresh() or self._sync_id != 0: # type:ignore[no-untyped-call]
if not self._plot_manager.initial_refresh() or self._sync_id != 0:
self._reset()
else:
raise AlreadyStartedError()
@ -173,7 +172,7 @@ class Sender:
return False
if response.identifier.sync_id != self._response.identifier.sync_id:
log.warning(
"set_response unexpected sync-id: " f"{response.identifier.sync_id}/{self._response.identifier.sync_id}"
"set_response unexpected sync-id: {response.identifier.sync_id}/{self._response.identifier.sync_id}"
)
return False
if response.identifier.message_id != self._response.identifier.message_id:
@ -184,7 +183,7 @@ class Sender:
return False
if response.message_type != int16(self._response.message_type.value):
log.warning(
"set_response unexpected message-type: " f"{response.message_type}/{self._response.message_type.value}"
"set_response unexpected message-type: {response.message_type}/{self._response.message_type.value}"
)
return False
log.debug(f"set_response valid {response}")

View File

@ -12,7 +12,7 @@ from typing import Any, Dict, Optional
import pkg_resources
from chia.plotting.create_plots import create_plots, resolve_plot_keys
from chia.plotting.util import add_plot_directory, validate_plot_size
from chia.plotting.util import Params, add_plot_directory, validate_plot_size
log = logging.getLogger(__name__)
@ -22,22 +22,6 @@ def get_chiapos_install_info() -> Optional[Dict[str, Any]]:
return {"display_name": "Chia Proof of Space", "version": chiapos_version, "installed": True}
class Params:
def __init__(self, args):
self.size = args.size
self.num = args.count
self.buffer = args.buffer
self.num_threads = args.threads
self.buckets = args.buckets
self.stripe_size = args.stripes
self.tmp_dir = Path(args.tmpdir)
self.tmp2_dir = Path(args.tmpdir2) if args.tmpdir2 else None
self.final_dir = Path(args.finaldir)
self.plotid = args.id
self.memo = args.memo
self.nobitfield = args.nobitfield
def plot_chia(args, root_path):
try:
validate_plot_size(root_path, args.size, args.override)
@ -56,7 +40,21 @@ def plot_chia(args, root_path):
args.connect_to_daemon,
)
)
asyncio.run(create_plots(Params(args), plot_keys))
params = Params(
size=args.size,
num=args.count,
buffer=args.buffer,
num_threads=args.threads,
buckets=args.buckets,
stripe_size=args.stripes,
tmp_dir=Path(args.tmpdir),
tmp2_dir=Path(args.tmpdir2) if args.tmpdir2 else None,
final_dir=Path(args.finaldir),
plotid=args.id,
memo=args.memo,
nobitfield=args.nobitfield,
)
asyncio.run(create_plots(params, plot_keys))
if not args.exclude_final_dir:
try:
add_plot_directory(root_path, args.finaldir)

View File

@ -4,7 +4,7 @@ import logging
from collections import Counter
from pathlib import Path
from time import sleep, time
from typing import List
from typing import List, Optional
from blspy import G1Element
from chiapos import Verifier
@ -21,20 +21,28 @@ from chia.plotting.util import (
from chia.util.bech32m import encode_puzzle_hash
from chia.util.config import load_config
from chia.util.hash import std_hash
from chia.util.ints import uint32
from chia.util.keychain import Keychain
from chia.wallet.derive_keys import master_sk_to_farmer_sk, master_sk_to_local_sk
log = logging.getLogger(__name__)
def plot_refresh_callback(event: PlotRefreshEvents, refresh_result: PlotRefreshResult):
def plot_refresh_callback(event: PlotRefreshEvents, refresh_result: PlotRefreshResult) -> None:
log.info(f"event: {event.name}, loaded {len(refresh_result.loaded)} plots, {refresh_result.remaining} remaining")
def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, debug_show_memo):
def check_plots(
root_path: Path,
num: Optional[int],
challenge_start: Optional[int],
grep_string: str,
list_duplicates: bool,
debug_show_memo: bool,
) -> None:
config = load_config(root_path, "config.yaml")
address_prefix = config["network_overrides"]["config"][config["selected_network"]]["address_prefix"]
plot_refresh_parameter: PlotsRefreshParameter = PlotsRefreshParameter(batch_sleep_milliseconds=0)
plot_refresh_parameter: PlotsRefreshParameter = PlotsRefreshParameter(batch_sleep_milliseconds=uint32(0))
plot_manager: PlotManager = PlotManager(
root_path,
match_str=grep_string,
@ -97,7 +105,7 @@ def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, d
log.info("")
log.info("")
log.info(f"Starting to test each plot with {num} challenges each\n")
total_good_plots: Counter = Counter()
total_good_plots: Counter[str] = Counter()
total_size = 0
bad_plots_list: List[Path] = []
@ -179,7 +187,7 @@ def check_plots(root_path, num, challenge_start, grep_string, list_duplicates, d
log.info("Summary")
total_plots: int = sum(list(total_good_plots.values()))
log.info(f"Found {total_plots} valid plots, total size {total_size / (1024 * 1024 * 1024 * 1024):.5f} TiB")
for (k, count) in sorted(dict(total_good_plots).items()):
for k, count in sorted(dict(total_good_plots).items()):
log.info(f"{count} plots of size {k}")
grand_total_bad = len(bad_plots_list) + len(plot_manager.failed_to_open_filenames)
if grand_total_bad > 0:

View File

@ -10,7 +10,7 @@ from blspy import AugSchemeMPL, G1Element, PrivateKey
from chiapos import DiskPlotter
from chia.daemon.keychain_proxy import KeychainProxy, connect_to_keychain_and_validate, wrap_local_keychain
from chia.plotting.util import stream_plot_info_ph, stream_plot_info_pk
from chia.plotting.util import Params, stream_plot_info_ph, stream_plot_info_pk
from chia.types.blockchain_format.proof_of_space import (
calculate_plot_id_ph,
calculate_plot_id_pk,
@ -51,8 +51,8 @@ class PlotKeysResolver:
pool_contract_address: Optional[str],
root_path: Path,
log: logging.Logger,
connect_to_daemon=False,
):
connect_to_daemon: bool = False,
) -> None:
self.farmer_public_key = farmer_public_key
self.alt_fingerprint = alt_fingerprint
self.pool_public_key = pool_public_key
@ -66,30 +66,33 @@ class PlotKeysResolver:
if self.resolved_keys is not None:
return self.resolved_keys
if self.connect_to_daemon:
keychain_proxy: Optional[KeychainProxy] = await connect_to_keychain_and_validate(self.root_path, self.log)
else:
keychain_proxy = wrap_local_keychain(Keychain(), log=self.log)
keychain_proxy: Optional[KeychainProxy] = None
try:
if self.connect_to_daemon:
keychain_proxy = await connect_to_keychain_and_validate(self.root_path, self.log)
else:
keychain_proxy = wrap_local_keychain(Keychain(), log=self.log)
farmer_public_key: G1Element
if self.farmer_public_key is not None:
farmer_public_key = G1Element.from_bytes(bytes.fromhex(self.farmer_public_key))
else:
farmer_public_key = await self.get_farmer_public_key(keychain_proxy)
farmer_public_key: G1Element
if self.farmer_public_key is not None:
farmer_public_key = G1Element.from_bytes(bytes.fromhex(self.farmer_public_key))
else:
farmer_public_key = await self.get_farmer_public_key(keychain_proxy)
pool_public_key: Optional[G1Element] = None
if self.pool_public_key is not None:
if self.pool_contract_address is not None:
raise RuntimeError("Choose one of pool_contract_address and pool_public_key")
pool_public_key = G1Element.from_bytes(bytes.fromhex(self.pool_public_key))
else:
if self.pool_contract_address is None:
# If nothing is set, farms to the provided key (or the first key)
pool_public_key = await self.get_pool_public_key(keychain_proxy)
pool_public_key: Optional[G1Element] = None
if self.pool_public_key is not None:
if self.pool_contract_address is not None:
raise RuntimeError("Choose one of pool_contract_address and pool_public_key")
pool_public_key = G1Element.from_bytes(bytes.fromhex(self.pool_public_key))
else:
if self.pool_contract_address is None:
# If nothing is set, farms to the provided key (or the first key)
pool_public_key = await self.get_pool_public_key(keychain_proxy)
self.resolved_keys = PlotKeys(farmer_public_key, pool_public_key, self.pool_contract_address)
if keychain_proxy is not None:
await keychain_proxy.close()
self.resolved_keys = PlotKeys(farmer_public_key, pool_public_key, self.pool_contract_address)
finally:
if keychain_proxy is not None:
await keychain_proxy.close()
return self.resolved_keys
async def get_sk(self, keychain_proxy: Optional[KeychainProxy] = None) -> Optional[Tuple[PrivateKey, bytes]]:
@ -138,7 +141,7 @@ async def resolve_plot_keys(
pool_contract_address: Optional[str],
root_path: Path,
log: logging.Logger,
connect_to_daemon=False,
connect_to_daemon: bool = False,
) -> PlotKeys:
return await PlotKeysResolver(
farmer_public_key, alt_fingerprint, pool_public_key, pool_contract_address, root_path, log, connect_to_daemon
@ -146,10 +149,10 @@ async def resolve_plot_keys(
async def create_plots(
args,
args: Params,
keys: PlotKeys,
use_datetime: bool = True,
test_private_keys: Optional[List] = None,
test_private_keys: Optional[List[PrivateKey]] = None,
) -> Tuple[Dict[bytes32, Path], Dict[bytes32, Path]]:
if args.tmp2_dir is None:
args.tmp2_dir = args.tmp_dir

View File

@ -73,7 +73,7 @@ class PlotManager:
def __exit__(self, exc_type, exc_value, exc_traceback):
self._lock.release()
def reset(self):
def reset(self) -> None:
with self:
self.last_refresh_time = time.time()
self.plots.clear()
@ -89,11 +89,11 @@ class PlotManager:
self.farmer_public_keys = farmer_public_keys
self.pool_public_keys = pool_public_keys
def initial_refresh(self):
def initial_refresh(self) -> bool:
return self._initial
def public_keys_available(self):
return len(self.farmer_public_keys) and len(self.pool_public_keys)
def public_keys_available(self) -> bool:
return len(self.farmer_public_keys) > 0 and len(self.pool_public_keys) > 0
def plot_count(self) -> int:
with self:

View File

@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Tuple, Union
from blspy import G1Element, PrivateKey
from chiapos import DiskProver
from typing_extensions import final
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.config import load_config, lock_and_load_config, save_config
@ -66,6 +67,23 @@ class PlotRefreshResult:
duration: float = 0
@final
@dataclass
class Params:
size: int
num: int
buffer: int
num_threads: int
buckets: int
tmp_dir: Path
tmp2_dir: Optional[Path]
final_dir: Path
plotid: Optional[str]
memo: Optional[str]
nobitfield: bool
stripe_size: int = 65536
def get_plot_directories(root_path: Path, config: Dict = None) -> List[str]:
if config is None:
config = load_config(root_path, "config.yaml")

View File

@ -62,7 +62,7 @@ def load_pool_config(root_path: Path) -> List[PoolWalletConfig]:
# TODO: remove this a few versions after 1.3, since authentication_public_key is deprecated. This is here to support
# downgrading to versions older than 1.3.
def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1Element):
def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1Element) -> None:
with lock_and_load_config(root_path, "config.yaml") as config:
pool_list = config["pool"].get("pool_list", [])
updated = False
@ -82,7 +82,7 @@ def add_auth_key(root_path: Path, config_entry: PoolWalletConfig, auth_key: G1El
save_config(root_path, "config.yaml", config)
async def update_pool_config(root_path: Path, pool_config_list: List[PoolWalletConfig]):
async def update_pool_config(root_path: Path, pool_config_list: List[PoolWalletConfig]) -> None:
with lock_and_load_config(root_path, "config.yaml") as full_config:
full_config["pool"]["pool_list"] = [c.to_json_dict() for c in pool_config_list]
save_config(root_path, "config.yaml", full_config)

View File

@ -45,7 +45,7 @@ from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.coin_spend import CoinSpend, compute_additions
from chia.types.spend_bundle import SpendBundle
from chia.util.ints import uint8, uint32, uint64, uint128
from chia.util.ints import uint32, uint64, uint128
from chia.wallet.derive_keys import find_owner_sk
from chia.wallet.sign_coin_spends import sign_coin_spends
from chia.wallet.transaction_record import TransactionRecord
@ -119,14 +119,14 @@ class PoolWallet:
"""
@classmethod
def type(cls) -> uint8:
return uint8(WalletType.POOLING_WALLET)
def type(cls) -> WalletType:
return WalletType.POOLING_WALLET
def id(self):
def id(self) -> uint32:
return self.wallet_info.id
@classmethod
def _verify_self_pooled(cls, state) -> Optional[str]:
def _verify_self_pooled(cls, state: PoolState) -> Optional[str]:
err = ""
if state.pool_url not in [None, ""]:
err += " Unneeded pool_url for self-pooling"
@ -137,7 +137,7 @@ class PoolWallet:
return None if err == "" else err
@classmethod
def _verify_pooling_state(cls, state) -> Optional[str]:
def _verify_pooling_state(cls, state: PoolState) -> Optional[str]:
err = ""
if state.relative_lock_height < cls.MINIMUM_RELATIVE_LOCK_HEIGHT:
err += (
@ -166,15 +166,18 @@ class PoolWallet:
f"to use this pooling wallet"
)
if state.state == PoolSingletonState.SELF_POOLING:
if state.state == PoolSingletonState.SELF_POOLING.value:
return cls._verify_self_pooled(state)
elif state.state == PoolSingletonState.FARMING_TO_POOL or state.state == PoolSingletonState.LEAVING_POOL:
elif (
state.state == PoolSingletonState.FARMING_TO_POOL.value
or state.state == PoolSingletonState.LEAVING_POOL.value
):
return cls._verify_pooling_state(state)
else:
return "Internal Error"
@classmethod
def _verify_initial_target_state(cls, initial_target_state):
def _verify_initial_target_state(cls, initial_target_state: PoolState) -> None:
err = cls._verify_pool_state(initial_target_state)
if err:
raise ValueError(f"Invalid internal Pool State: {err}: {initial_target_state}")
@ -330,7 +333,7 @@ class PoolWallet:
block_spends: List[CoinSpend],
block_height: uint32,
*,
name: str = None,
name: Optional[str] = None,
) -> PoolWallet:
"""
This creates a new PoolWallet with only one spend: the launcher spend. The DB MUST be committed after calling
@ -357,7 +360,7 @@ class PoolWallet:
await pool_wallet.update_pool_config()
p2_puzzle_hash: bytes32 = (await pool_wallet.get_current_state()).p2_singleton_puzzle_hash
await wallet_state_manager.add_new_wallet(pool_wallet, pool_wallet.wallet_id, create_puzzle_hashes=False)
await wallet_state_manager.add_new_wallet(pool_wallet)
await wallet_state_manager.add_interested_puzzle_hashes([p2_puzzle_hash], [pool_wallet.wallet_id])
return pool_wallet
@ -368,7 +371,7 @@ class PoolWallet:
wallet_state_manager: Any,
wallet: Wallet,
wallet_info: WalletInfo,
name: str = None,
name: Optional[str] = None,
) -> PoolWallet:
"""
This creates a PoolWallet from DB. However, all data is already handled by WalletPoolStore, so we don't need
@ -486,7 +489,11 @@ class PoolWallet:
self.wallet_state_manager.constants.MAX_BLOCK_COST_CLVM,
)
async def generate_fee_transaction(self, fee: uint64, coin_announcements=None) -> TransactionRecord:
async def generate_fee_transaction(
self,
fee: uint64,
coin_announcements: Optional[Set[Announcement]] = None,
) -> TransactionRecord:
fee_tx = await self.standard_wallet.generate_signed_transaction(
uint64(0),
(await self.standard_wallet.get_new_puzzlehash()),
@ -499,7 +506,7 @@ class PoolWallet:
)
return fee_tx
async def publish_transactions(self, travel_tx: TransactionRecord, fee_tx: Optional[TransactionRecord]):
async def publish_transactions(self, travel_tx: TransactionRecord, fee_tx: Optional[TransactionRecord]) -> None:
# We create two transaction records, one for the pool wallet to keep track of the travel TX, and another
# for the standard wallet to keep track of the fee. However, we will only submit the first one to the
# blockchain, and this one has the fee inside it as well.
@ -519,7 +526,7 @@ class PoolWallet:
delayed_seconds, delayed_puzhash = get_delayed_puz_info_from_launcher_spend(spend_history[0][1])
assert pool_wallet_info.target is not None
next_state = pool_wallet_info.target
if pool_wallet_info.current.state in [FARMING_TO_POOL]:
if pool_wallet_info.current.state == FARMING_TO_POOL.value:
next_state = create_pool_state(
LEAVING_POOL,
pool_wallet_info.current.target_puzzle_hash,
@ -589,6 +596,7 @@ class PoolWallet:
fee_tx: Optional[TransactionRecord] = None
if fee > 0:
fee_tx = await self.generate_fee_transaction(fee)
assert fee_tx.spend_bundle is not None
signed_spend_bundle = SpendBundle.aggregate([signed_spend_bundle, fee_tx.spend_bundle])
tx_record = TransactionRecord(
@ -656,9 +664,9 @@ class PoolWallet:
delay_ph,
)
if initial_target_state.state == SELF_POOLING:
if initial_target_state.state == SELF_POOLING.value:
puzzle = escaping_inner_puzzle
elif initial_target_state.state == FARMING_TO_POOL:
elif initial_target_state.state == FARMING_TO_POOL.value:
puzzle = self_pooling_inner_puzzle
else:
raise ValueError("Invalid initial state")
@ -698,7 +706,7 @@ class PoolWallet:
async def join_pool(
self, target_state: PoolState, fee: uint64
) -> Tuple[uint64, TransactionRecord, Optional[TransactionRecord]]:
if target_state.state != FARMING_TO_POOL:
if target_state.state != FARMING_TO_POOL.value:
raise ValueError(f"join_pool must be called with target_state={FARMING_TO_POOL} (FARMING_TO_POOL)")
if self.target_state is not None:
raise ValueError(f"Cannot join a pool while waiting for target state: {self.target_state}")
@ -715,9 +723,9 @@ class PoolWallet:
msg = f"Asked to change to current state. Target = {target_state}"
self.log.info(msg)
raise ValueError(msg)
elif current_state.current.state in [SELF_POOLING, LEAVING_POOL]:
elif current_state.current.state in [SELF_POOLING.value, LEAVING_POOL.value]:
total_fee = fee
elif current_state.current.state == FARMING_TO_POOL:
elif current_state.current.state == FARMING_TO_POOL.value:
total_fee = uint64(fee * 2)
if self.target_state is not None:
@ -725,7 +733,7 @@ class PoolWallet:
f"Cannot change to state {target_state} when already having target state: {self.target_state}"
)
PoolWallet._verify_initial_target_state(target_state)
if current_state.current.state == LEAVING_POOL:
if current_state.current.state == LEAVING_POOL.value:
history: List[Tuple[uint32, CoinSpend]] = await self.get_spend_history()
last_height: uint32 = history[-1][0]
if (
@ -747,7 +755,7 @@ class PoolWallet:
"Cannot self pool due to unconfirmed transaction. If this is stuck, delete the unconfirmed transaction."
)
pool_wallet_info: PoolWalletInfo = await self.get_current_state()
if pool_wallet_info.current.state == SELF_POOLING:
if pool_wallet_info.current.state == SELF_POOLING.value:
raise ValueError("Attempted to self pool when already self pooling")
if self.target_state is not None:
@ -760,7 +768,7 @@ class PoolWallet:
current_state: PoolWalletInfo = await self.get_current_state()
total_fee = uint64(fee * 2)
if current_state.current.state == LEAVING_POOL:
if current_state.current.state == LEAVING_POOL.value:
total_fee = fee
history: List[Tuple[uint32, CoinSpend]] = await self.get_spend_history()
last_height: uint32 = history[-1][0]
@ -856,7 +864,9 @@ class PoolWallet:
fee_tx = None
if fee > 0:
absorb_announce = Announcement(first_coin_record.coin.name(), b"$")
fee_tx = await self.generate_fee_transaction(fee, coin_announcements=[absorb_announce])
assert absorb_announce is not None
fee_tx = await self.generate_fee_transaction(fee, coin_announcements={absorb_announce})
assert fee_tx.spend_bundle is not None
full_spend = SpendBundle.aggregate([fee_tx.spend_bundle, claim_spend])
assert full_spend.fees() == fee
@ -892,13 +902,13 @@ class PoolWallet:
if self.target_state is None:
return
if self.target_state == pool_wallet_info.current.state:
if self.target_state == pool_wallet_info.current:
self.target_state = None
raise ValueError("Internal error")
raise ValueError(f"Internal error. Pool wallet {self.wallet_id} state: {pool_wallet_info.current}")
if (
self.target_state.state in [FARMING_TO_POOL, SELF_POOLING]
and pool_wallet_info.current.state == LEAVING_POOL
self.target_state.state in [FARMING_TO_POOL.value, SELF_POOLING.value]
and pool_wallet_info.current.state == LEAVING_POOL.value
):
leave_height = tip_height + pool_wallet_info.current.relative_lock_height
@ -916,13 +926,13 @@ class PoolWallet:
self.log.info(f"Attempting to leave from\n{pool_wallet_info.current}\nto\n{self.target_state}")
assert self.target_state.version == POOL_PROTOCOL_VERSION
assert pool_wallet_info.current.state == LEAVING_POOL
assert pool_wallet_info.current.state == LEAVING_POOL.value
assert self.target_state.target_puzzle_hash is not None
if self.target_state.state == SELF_POOLING:
if self.target_state.state == SELF_POOLING.value:
assert self.target_state.relative_lock_height == 0
assert self.target_state.pool_url is None
elif self.target_state.state == FARMING_TO_POOL:
elif self.target_state.state == FARMING_TO_POOL.value:
assert self.target_state.relative_lock_height >= self.MINIMUM_RELATIVE_LOCK_HEIGHT
assert self.target_state.pool_url is not None
@ -934,9 +944,9 @@ class PoolWallet:
)
return len(unconfirmed) > 0
async def get_confirmed_balance(self, _=None) -> uint128:
async def get_confirmed_balance(self, _: Optional[object] = None) -> uint128:
amount: uint128 = uint128(0)
if (await self.get_current_state()).current.state == SELF_POOLING:
if (await self.get_current_state()).current.state == SELF_POOLING.value:
unspent_coin_records: List[WalletCoinRecord] = list(
await self.wallet_state_manager.coin_store.get_unspent_coins_for_wallet(self.wallet_id)
)
@ -945,10 +955,10 @@ class PoolWallet:
amount = uint128(amount + record.coin.amount)
return amount
async def get_unconfirmed_balance(self, record_list=None) -> uint128:
async def get_unconfirmed_balance(self, record_list: Optional[object] = None) -> uint128:
return await self.get_confirmed_balance(record_list)
async def get_spendable_balance(self, record_list=None) -> uint128:
async def get_spendable_balance(self, record_list: Optional[object] = None) -> uint128:
return await self.get_confirmed_balance(record_list)
async def get_pending_change_balance(self) -> uint64:

View File

@ -173,5 +173,5 @@ def get_current_authentication_token(timeout: uint8) -> uint64:
# Validate a given authentication token against our local time
def validate_authentication_token(token: uint64, timeout: uint8):
def validate_authentication_token(token: uint64, timeout: uint8) -> bool:
return abs(token - get_current_authentication_token(timeout)) <= timeout

View File

@ -30,7 +30,7 @@ class CrawlerRpcApi:
return payloads
async def get_peer_counts(self, _request: Dict) -> EndpointResult:
async def get_peer_counts(self, _request: Dict[str, Any]) -> EndpointResult:
ipv6_addresses_count = 0
for host in self.service.best_timestamp_per_peer.keys():
try:
@ -54,7 +54,7 @@ class CrawlerRpcApi:
}
return data
async def get_ips_after_timestamp(self, _request: Dict) -> EndpointResult:
async def get_ips_after_timestamp(self, _request: Dict[str, Any]) -> EndpointResult:
after = _request.get("after", None)
if after is None:
raise ValueError("`after` is required and must be a unix timestamp")

View File

@ -11,71 +11,62 @@ from chia.util.ints import uint64
class DataLayerRpcClient(RpcClient):
async def create_data_store(self, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("create_data_store", {"fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_value(self, store_id: bytes32, key: bytes, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex(), "key": key.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_value", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def update_data_store(
self, store_id: bytes32, changelist: List[Dict[str, str]], fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("batch_update", {"id": store_id.hex(), "changelist": changelist, "fee": fee})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_keys_values(self, store_id: bytes32, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_keys_values", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_keys(self, store_id: bytes32, root_hash: Optional[bytes32]) -> Dict[str, Any]:
request: Dict[str, Any] = {"id": store_id.hex()}
if root_hash is not None:
request["root_hash"] = root_hash.hex()
response = await self.fetch("get_keys", request)
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_ancestors(self, store_id: bytes32, hash: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_ancestors", {"id": store_id.hex(), "hash": hash})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_local_root(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_local_root", {"id": store_id.hex()})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def get_roots(self, store_ids: List[bytes32]) -> Dict[str, Any]:
response = await self.fetch("get_roots", {"ids": store_ids})
# TODO: better hinting for .fetch() (probably a TypedDict)
return response # type: ignore[no-any-return]
return response
async def subscribe(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("subscribe", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
return response
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("remove_subscriptions", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
return response
async def unsubscribe(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("unsubscribe", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def add_missing_files(
self, store_ids: Optional[List[bytes32]], overwrite: Optional[bool], foldername: Optional[Path]
@ -88,40 +79,40 @@ class DataLayerRpcClient(RpcClient):
if foldername is not None:
request["foldername"] = str(foldername)
response = await self.fetch("add_missing_files", request)
return response # type: ignore[no-any-return]
return response
async def get_kv_diff(self, store_id: bytes32, hash_1: bytes32, hash_2: bytes32) -> Dict[str, Any]:
response = await self.fetch(
"get_kv_diff", {"id": store_id.hex(), "hash_1": hash_1.hex(), "hash_2": hash_2.hex()}
)
return response # type: ignore[no-any-return]
return response
async def get_root_history(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_root_history", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def add_mirror(
self, store_id: bytes32, urls: List[str], amount: int, fee: Optional[uint64]
) -> Dict[str, Any]:
response = await self.fetch("add_mirror", {"id": store_id.hex(), "urls": urls, "amount": amount, "fee": fee})
return response # type: ignore[no-any-return]
return response
async def delete_mirror(self, coin_id: bytes32, fee: Optional[uint64]) -> Dict[str, Any]:
response = await self.fetch("delete_mirror", {"coin_id": coin_id.hex(), "fee": fee})
return response # type: ignore[no-any-return]
return response
async def get_mirrors(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_mirrors", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response
async def get_subscriptions(self) -> Dict[str, Any]:
response = await self.fetch("subscriptions", {})
return response # type: ignore[no-any-return]
return response
async def get_owned_stores(self) -> Dict[str, Any]:
response = await self.fetch("get_owned_stores", {})
return response # type: ignore[no-any-return]
return response
async def get_sync_status(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("get_sync_status", {"id": store_id.hex()})
return response # type: ignore[no-any-return]
return response

Some files were not shown because too many files have changed in this diff Show More