Test get_bucket_index (#14294)

* Test get_bucket_index

* Apply epsilon to test constants

* Simplify get_bucket_index
This commit is contained in:
Adam Kelly 2023-01-13 09:16:23 -08:00 committed by GitHub
parent 745ae8496c
commit 06eb18217e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 85 additions and 45 deletions

View File

@ -39,7 +39,7 @@ class SmartFeeEstimator:
# get_bucket_index returns left (-1) bucket (-1). Start value is already -1
# We want +1 from the lowest bucket it failed at. Thus +3
max_val = len(self.fee_tracker.buckets) - 1
start_index = min(get_bucket_index(self.fee_tracker.sorted_buckets, fail_bucket.start) + 3, max_val)
start_index = min(get_bucket_index(self.fee_tracker.buckets, fail_bucket.start) + 3, max_val)
fee_val: float = self.fee_tracker.buckets[start_index]
return fee_val

View File

@ -2,9 +2,9 @@
from __future__ import annotations
MIN_FEE_RATE = 0 # Value of first bucket
INITIAL_STEP = 5 # First bucket after zero value
MAX_FEE_RATE = 40000000 # Mojo per 1000 cost unit
INFINITE_FEE_RATE = 1000000000
INITIAL_STEP = 5.0 # First bucket after zero value
MAX_FEE_RATE = 40000000.0 # Mojo per 1000 cost unit
INFINITE_FEE_RATE = 1000000000.0
STEP_SIZE = 1.05 # bucket increase by 1.05

View File

@ -1,11 +1,10 @@
from __future__ import annotations
import logging
from bisect import bisect_left
from dataclasses import dataclass
from typing import List, Optional, Tuple
from sortedcontainers import SortedDict
from chia.full_node.fee_estimate_store import FeeStore
from chia.full_node.fee_estimator_constants import (
FEE_ESTIMATOR_VERSION,
@ -61,21 +60,10 @@ def get_estimate_time_intervals() -> List[uint64]:
return [uint64(blocks * SECONDS_PER_BLOCK) for blocks in get_estimate_block_intervals()]
def get_bucket_index(sorted_buckets: SortedDict, fee_rate: float) -> int:
if fee_rate in sorted_buckets:
bucket_index = sorted_buckets[fee_rate]
else:
# Choose the bucket to the left if we do not have exactly this fee rate
bucket_index = sorted_buckets.bisect_left(fee_rate) - 1
return int(bucket_index)
# Implementation of bitcoin core fee estimation algorithm
# https://gist.github.com/morcos/d3637f015bc4e607e1fd10d8351e9f41
class FeeStat: # TxConfirmStats
buckets: List[float]
sorted_buckets: SortedDict # key is upper bound of bucket, val is index in buckets
buckets: List[float] # These elements represent the upper-bound of the range for the bucket
# For each bucket xL
# Count the total number of txs in each bucket
@ -111,7 +99,6 @@ class FeeStat: # TxConfirmStats
def __init__(
self,
buckets: List[float],
sorted_buckets: SortedDict,
max_periods: int,
decay: float,
scale: int,
@ -119,7 +106,6 @@ class FeeStat: # TxConfirmStats
my_type: str,
):
self.buckets = buckets
self.sorted_buckets = sorted_buckets
self.confirmed_average = [[] for _ in range(0, max_periods)]
self.failed_average = [[] for _ in range(0, max_periods)]
self.decay = decay
@ -150,7 +136,7 @@ class FeeStat: # TxConfirmStats
periods_to_confirm = int((blocks_to_confirm + self.scale - 1) / self.scale)
fee_rate = item.fee_per_cost * 1000
bucket_index = get_bucket_index(self.sorted_buckets, fee_rate)
bucket_index = get_bucket_index(self.buckets, fee_rate)
for i in range(periods_to_confirm, len(self.confirmed_average)):
self.confirmed_average[i - 1][bucket_index] += 1
@ -173,7 +159,7 @@ class FeeStat: # TxConfirmStats
self.unconfirmed_txs[block_height % len(self.unconfirmed_txs)][i] = 0
def new_mempool_tx(self, block_height: uint32, fee_rate: float) -> int:
bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate)
bucket_index: int = get_bucket_index(self.buckets, fee_rate)
block_index = block_height % len(self.unconfirmed_txs)
self.unconfirmed_txs[block_index][bucket_index] += 1
return bucket_index
@ -400,8 +386,32 @@ class FeeStat: # TxConfirmStats
return result
def clamp(n: int, smallest: int, largest: int) -> int:
return max(smallest, min(n, largest))
def get_bucket_index(buckets: List[float], fee_rate: float) -> int:
if len(buckets) < 1:
raise RuntimeError("get_bucket_index: buckets is invalid ({buckets})")
# Choose the bucket to the left if we do not have exactly this fee rate
# Python's list.bisect_left returns the index to insert a new element into a sorted list
bucket_index = bisect_left(buckets, fee_rate) - 1
return clamp(bucket_index, 0, len(buckets) - 1)
def init_buckets() -> List[float]:
fee_rate = INITIAL_STEP
buckets: List[float] = []
while fee_rate < MAX_FEE_RATE:
buckets.append(fee_rate)
fee_rate = fee_rate * STEP_SIZE
buckets.append(INFINITE_FEE_RATE)
return buckets
class FeeTracker:
sorted_buckets: SortedDict
short_horizon: FeeStat
med_horizon: FeeStat
long_horizon: FeeStat
@ -413,30 +423,13 @@ class FeeTracker:
def __init__(self, fee_store: FeeStore):
self.log = logging.Logger(__name__)
self.sorted_buckets = SortedDict()
self.buckets = []
self.latest_seen_height = uint32(0)
self.first_recorded_height = uint32(0)
self.fee_store = fee_store
fee_rate = 0.0
index = 0
while fee_rate < MAX_FEE_RATE:
self.buckets.append(fee_rate)
self.sorted_buckets[fee_rate] = index
if fee_rate == 0:
fee_rate = INITIAL_STEP
else:
fee_rate = fee_rate * STEP_SIZE
index += 1
self.buckets.append(INFINITE_FEE_RATE)
self.sorted_buckets[INFINITE_FEE_RATE] = index
assert len(self.sorted_buckets.keys()) == len(self.buckets)
self.buckets = init_buckets()
self.short_horizon = FeeStat(
self.buckets,
self.sorted_buckets,
SHORT_BLOCK_PERIOD,
SHORT_DECAY,
SHORT_SCALE,
@ -445,7 +438,6 @@ class FeeTracker:
)
self.med_horizon = FeeStat(
self.buckets,
self.sorted_buckets,
MED_BLOCK_PERIOD,
MED_DECAY,
MED_SCALE,
@ -454,7 +446,6 @@ class FeeTracker:
)
self.long_horizon = FeeStat(
self.buckets,
self.sorted_buckets,
LONG_BLOCK_PERIOD,
LONG_DECAY,
LONG_SCALE,
@ -525,14 +516,14 @@ class FeeTracker:
self.log.info(f"Processing Item from pending pool: cost={item.cost} fee={item.fee}")
fee_rate = item.fee_per_cost * 1000
bucket_index: int = get_bucket_index(self.sorted_buckets, fee_rate)
bucket_index: int = get_bucket_index(self.buckets, fee_rate)
self.short_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
self.med_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
self.long_horizon.new_mempool_tx(self.latest_seen_height, bucket_index)
def remove_tx(self, item: MempoolItem) -> None:
bucket_index = get_bucket_index(self.sorted_buckets, item.fee_per_cost * 1000)
bucket_index = get_bucket_index(self.buckets, item.fee_per_cost * 1000)
self.short_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.med_horizon.remove_tx(self.latest_seen_height, item, bucket_index)
self.long_horizon.remove_tx(self.latest_seen_height, item, bucket_index)

View File

@ -3,12 +3,15 @@ from __future__ import annotations
import logging
from typing import List
import pytest
from chia_rs import Coin
from chia.consensus.cost_calculator import NPCResult
from chia.full_node.bitcoin_fee_estimator import create_bitcoin_fee_estimator
from chia.full_node.fee_estimation import FeeBlockInfo
from chia.full_node.fee_estimator_constants import INFINITE_FEE_RATE, INITIAL_STEP
from chia.full_node.fee_estimator_interface import FeeEstimatorInterface
from chia.full_node.fee_tracker import get_bucket_index, init_buckets
from chia.simulator.block_tools import test_constants
from chia.simulator.wallet_tools import WalletTool
from chia.types.clvm_cost import CLVMCost
@ -145,3 +148,49 @@ def test_fee_estimation_inception() -> None:
# Confirm that estimates start after block 4
assert e1 == [0, 0, 0, 2, 2, 2, 2]
def test_init_buckets() -> None:
buckets = init_buckets()
assert len(buckets) > 1
assert buckets[0] == INITIAL_STEP
assert buckets[-1] == INFINITE_FEE_RATE
def test_get_bucket_index_empty_buckets() -> None:
buckets: List[float] = []
for rate in [0.5, 1.0, 2.0]:
with pytest.raises(RuntimeError):
a = get_bucket_index(buckets, rate)
log.warning(a)
def test_get_bucket_index_fee_rate_too_high() -> None:
buckets = [0.5, 1.0, 2.0]
index = get_bucket_index(buckets, 3.0)
assert index == len(buckets) - 1
def test_get_bucket_index_single_entry() -> None:
"""Test single entry with low, equal and high keys"""
from sys import float_info
e = float_info.epsilon * 10
buckets = [1.0]
print()
print(buckets)
for rate, expected_index in ((0.5, 0), (1.0 - e, 0), (1.5, 0)):
result_index = get_bucket_index(buckets, rate)
print(rate, expected_index, result_index)
assert expected_index == result_index
def test_get_bucket_index() -> None:
from sys import float_info
e = float_info.epsilon * 10
buckets = [1.0, 2.0]
for rate, expected_index in ((0.5, 0), (1.0 - e, 0), (1.5, 0), (2.0 - e, 0), (2.0 + e, 1), (2.1, 1)):
result_index = get_bucket_index(buckets, rate)
assert result_index == expected_index