diff --git a/chia/consensus/block_body_validation.py b/chia/consensus/block_body_validation.py
index 6d8345ead43a..f0a6022f559f 100644
--- a/chia/consensus/block_body_validation.py
+++ b/chia/consensus/block_body_validation.py
@@ -2,7 +2,7 @@ import collections
 import logging
 from typing import Dict, List, Optional, Set, Tuple, Union, Callable
 
-from blspy import AugSchemeMPL, G1Element
+from blspy import G1Element
 from chiabip158 import PyBIP158
 from clvm.casts import int_from_bytes
 
@@ -27,6 +27,7 @@ from chia.types.full_block import FullBlock
 from chia.types.generator_types import BlockGenerator
 from chia.types.name_puzzle_condition import NPC
 from chia.types.unfinished_block import UnfinishedBlock
+from chia.util import cached_bls
 from chia.util.condition_tools import (
     pkm_pairs_for_conditions_dict,
     coin_announcements_names_for_npc,
@@ -477,8 +478,15 @@ async def validate_block_body(
     if not block.transactions_info.aggregated_signature:
         return Err.BAD_AGGREGATE_SIGNATURE, None
 
-    # noinspection PyTypeChecker
-    if not AugSchemeMPL.aggregate_verify(pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature):
+    # The pairing cache is not useful while syncing as each pairing is seen
+    # only once, so the extra effort of populating it is not justified.
+    # However, we force caching of pairings just for unfinished blocks
+    # as the cache is likely to be useful when validating the corresponding
+    # finished blocks later.
+    force_cache: bool = isinstance(block, UnfinishedBlock)
+    if not cached_bls.aggregate_verify(
+        pairs_pks, pairs_msgs, block.transactions_info.aggregated_signature, force_cache
+    ):
         return Err.BAD_AGGREGATE_SIGNATURE, None
 
     return None, npc_result
diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py
index 810f6278f607..af00abefcbef 100644
--- a/chia/full_node/mempool_manager.py
+++ b/chia/full_node/mempool_manager.py
@@ -5,9 +5,10 @@ import logging
 import time
 from concurrent.futures.process import ProcessPoolExecutor
 from typing import Dict, List, Optional, Set, Tuple
-from blspy import AugSchemeMPL, G1Element
+from blspy import G1Element
 from chiabip158 import PyBIP158
 
+from chia.util import cached_bls
 from chia.consensus.block_record import BlockRecord
 from chia.consensus.constants import ConsensusConstants
 from chia.consensus.cost_calculator import NPCResult, calculate_cost_of_program
@@ -427,7 +428,7 @@ class MempoolManager:
 
         if validate_signature:
             # Verify aggregated signature
-            if not AugSchemeMPL.aggregate_verify(pks, msgs, new_spend.aggregated_signature):
+            if not cached_bls.aggregate_verify(pks, msgs, new_spend.aggregated_signature, True):
                 log.warning(f"Aggsig validation error {pks} {msgs} {new_spend}")
                 return None, MempoolInclusionStatus.FAILED, Err.BAD_AGGREGATE_SIGNATURE
         # Remove all conflicting Coins and SpendBundles
diff --git a/chia/util/cached_bls.py b/chia/util/cached_bls.py
new file mode 100644
index 000000000000..e8712f60b772
--- /dev/null
+++ b/chia/util/cached_bls.py
@@ -0,0 +1,49 @@
+import functools
+from typing import List, Optional
+
+from blspy import AugSchemeMPL, G1Element, G2Element, GTElement
+from chia.util.hash import std_hash
+from chia.util.lru_cache import LRUCache
+
+
+def get_pairings(cache: LRUCache, pks: List[G1Element], msgs: List[bytes], force_cache: bool) -> List[GTElement]:
+    pairings: List[Optional[GTElement]] = []
+    missing_count: int = 0
+    for pk, msg in zip(pks, msgs):
+        aug_msg: bytes = bytes(pk) + msg
+        h: bytes = bytes(std_hash(aug_msg))
+        pairing: Optional[GTElement] = cache.get(h)
+        if not force_cache and pairing is None:
+            missing_count += 1
+            # Heuristic to avoid more expensive sig validation with pairing
+            # cache when it's empty and cached pairings won't be useful later
+            # (e.g. while syncing)
+            if missing_count > len(pks) // 2:
+                return []
+        pairings.append(pairing)
+
+    for i, pairing in enumerate(pairings):
+        if pairing is None:
+            aug_msg = bytes(pks[i]) + msgs[i]
+            aug_hash: G2Element = AugSchemeMPL.g2_from_message(aug_msg)
+            pairing = pks[i].pair(aug_hash)
+
+            h = bytes(std_hash(aug_msg))
+            cache.put(h, pairing)
+            pairings[i] = pairing
+
+    return pairings
+
+
+LOCAL_CACHE: LRUCache = LRUCache(10000)
+
+
+def aggregate_verify(
+    pks: List[G1Element], msgs: List[bytes], sig: G2Element, force_cache: bool = False, cache: LRUCache = LOCAL_CACHE
+):
+    pairings: List[GTElement] = get_pairings(cache, pks, msgs, force_cache)
+    if len(pairings) == 0:
+        return AugSchemeMPL.aggregate_verify(pks, msgs, sig)
+
+    pairings_prod: GTElement = functools.reduce(GTElement.__mul__, pairings)
+    return pairings_prod == sig.pair(G1Element.generator())
diff --git a/tests/core/full_node/test_performance.py b/tests/core/full_node/test_performance.py
index 46f81377a2ff..49d397063605 100644
--- a/tests/core/full_node/test_performance.py
+++ b/tests/core/full_node/test_performance.py
@@ -143,7 +143,6 @@ class TestPerformance:
         start = time.time()
         num_tx: int = 0
         for spend_bundle, spend_bundle_id in zip(spend_bundles, spend_bundle_ids):
-            log.warning(f"Num Tx: {num_tx}")
             num_tx += 1
             respond_transaction = fnp.RespondTransaction(spend_bundle)
 
@@ -154,6 +153,7 @@ class TestPerformance:
             if req is None:
                 break
 
+        log.warning(f"Num Tx: {num_tx}")
        log.warning(f"Time for mempool: {time.time() - start}")
         pr.create_stats()
         pr.dump_stats("./mempool-benchmark.pstats")
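
Usage sketch (not part of the patch; the seed and message below are illustrative placeholders, and the calls assume blspy's AugSchemeMPL key-generation and signing helpers): chia.util.cached_bls.aggregate_verify keeps an LRU cache of GT pairings keyed by the hash of each augmented message (public key bytes + message). With force_cache=False, get_pairings gives up and returns an empty list once more than half of the requested pairings miss the cache, and aggregate_verify falls back to plain AugSchemeMPL.aggregate_verify; with force_cache=True every missing pairing is computed and stored so that a later verification of the same (pk, msg) pairs can reuse it.

    from blspy import AugSchemeMPL, PrivateKey

    from chia.util import cached_bls

    # Illustrative key and message; any 32-byte seed works for key_gen.
    sk: PrivateKey = AugSchemeMPL.key_gen(bytes([1] * 32))
    msg: bytes = b"example message"
    sig = AugSchemeMPL.sign(sk, msg)

    # force_cache=True computes the pairing for this (pk, msg) pair and
    # stores it in cached_bls.LOCAL_CACHE even though it is a cache miss.
    assert cached_bls.aggregate_verify([sk.get_g1()], [msg], sig, force_cache=True)

    # A later verification of the same (pk, msg) pair finds the pairing in
    # the cache and skips the expensive pairing computation.
    assert cached_bls.aggregate_verify([sk.get_g1()], [msg], sig)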