From cb42f7f20550518f16b82cda554e1a80e9e63212 Mon Sep 17 00:00:00 2001 From: Florin Chirica Date: Thu, 28 Jul 2022 18:47:15 +0300 Subject: [PATCH] Add remove_subscription and fix bug. --- chia/cmds/data.py | 14 ++++++++++++++ chia/cmds/data_funcs.py | 15 +++++++++++++++ chia/data_layer/data_layer.py | 11 +++++++++-- chia/data_layer/data_store.py | 18 ++++++++++++++---- chia/data_layer/download_data.py | 6 ++++-- chia/rpc/data_layer_rpc_api.py | 9 +++++++++ chia/rpc/data_layer_rpc_client.py | 4 ++++ 7 files changed, 69 insertions(+), 8 deletions(-) diff --git a/chia/cmds/data.py b/chia/cmds/data.py index 36bd0175cb83..3ac60ef232fb 100644 --- a/chia/cmds/data.py +++ b/chia/cmds/data.py @@ -196,6 +196,20 @@ def subscribe( run(subscribe_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls)) +@data_cmd.command("remove_subscription", short_help="") +@create_data_store_id_option() +@click.option("-u", "--url", "urls", help="", type=str, multiple=True) +@create_rpc_port_option() +def remove_subscription( + id: str, + urls: List[str], + data_rpc_port: int, +) -> None: + from chia.cmds.data_funcs import remove_subscriptions_cmd + + run(remove_subscriptions_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls)) + + @data_cmd.command("unsubscribe", short_help="") @create_data_store_id_option() @create_rpc_port_option() diff --git a/chia/cmds/data_funcs.py b/chia/cmds/data_funcs.py index fbf9b6651f3b..ff8931a860db 100644 --- a/chia/cmds/data_funcs.py +++ b/chia/cmds/data_funcs.py @@ -156,6 +156,21 @@ async def unsubscribe_cmd( print(f"Exception from 'data': {e}") +async def remove_subscriptions_cmd( + rpc_port: Optional[int], + store_id: str, + urls: List[str], +) -> None: + store_id_bytes = bytes32.from_hexstr(store_id) + try: + async with get_client(rpc_port) as (client, rpc_port): + await client.remove_subscriptions(store_id=store_id_bytes, urls=urls) + except aiohttp.ClientConnectorError: + print(f"Connection error. Check if data is running at {rpc_port}") + except Exception as e: + print(f"Exception from 'data': {e}") + + async def get_kv_diff_cmd( rpc_port: Optional[int], store_id: str, diff --git a/chia/data_layer/data_layer.py b/chia/data_layer/data_layer.py index 96821f84dff6..cdd74dc58ba9 100644 --- a/chia/data_layer/data_layer.py +++ b/chia/data_layer/data_layer.py @@ -260,9 +260,10 @@ class DataLayer: await self.data_store.create_tree(tree_id=tree_id) while True: - server_info = await self.data_store.maybe_get_server_for_store(tree_id) + timestamp = int(time.time()) + server_info = await self.data_store.maybe_get_server_for_store(tree_id, timestamp) if server_info is None: - self.log.info(f"No server available for {tree_id}.") + self.log.info(f"No server available for {tree_id}") return url = server_info.url root = await self.data_store.get_tree_root(tree_id=tree_id) @@ -355,6 +356,12 @@ class DataLayer: await self.data_store.subscribe(subscription) self.log.info(f"Done adding subscription: {subscription.tree_id}") + async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> None: + parsed_urls = [url.rstrip("/") for url in urls] + async with self.subscription_lock: + await self.data_store.remove_subscriptions(store_id, parsed_urls) + subscriptions = await self.get_subscriptions() + async def unsubscribe(self, tree_id: bytes32) -> None: subscriptions = await self.get_subscriptions() if tree_id not in (subscription.tree_id for subscription in subscriptions): diff --git a/chia/data_layer/data_store.py b/chia/data_layer/data_store.py index 98d246cd06dd..d58160d33016 100644 --- a/chia/data_layer/data_store.py +++ b/chia/data_layer/data_store.py @@ -1,6 +1,5 @@ import logging import aiosqlite -import time from collections import defaultdict from random import Random from dataclasses import dataclass, replace @@ -1260,6 +1259,17 @@ class DataStore: }, ) + async def remove_subscriptions(self, tree_id: bytes32, urls: List[str], *, lock: bool = True) -> None: + async with self.db_wrapper.locked_transaction(lock=lock): + for url in urls: + await self.db.execute( + "DELETE FROM subscriptions WHERE tree_id == :tree_id AND url == :url", + { + "tree_id": tree_id.hex(), + "url": url, + }, + ) + async def unsubscribe(self, tree_id: bytes32, *, lock: bool = True) -> None: async with self.db_wrapper.locked_transaction(lock=lock): await self.db.execute( @@ -1292,7 +1302,7 @@ class DataStore: ) async def received_incorrect_file( - self, tree_id: bytes32, server_info: ServerInfo, timestamp: int = int(time.monotonic()), *, lock: bool = True + self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True ) -> None: SEVEN_DAYS_BAN = 7 * 24 * 60 * 60 new_server_info = ServerInfo( @@ -1311,7 +1321,7 @@ class DataStore: await self.update_server_info(tree_id, new_server_info, lock=lock) async def server_misses_file( - self, tree_id: bytes32, server_info: ServerInfo, timestamp: int = int(time.monotonic()), *, lock: bool = True + self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True ) -> None: BAN_TIME_BY_MISSING_COUNT = [5 * 60] * 3 + [15 * 60] * 3 + [60 * 60] * 2 + [240 * 60] index = min(server_info.num_consecutive_failures, len(BAN_TIME_BY_MISSING_COUNT) - 1) @@ -1323,7 +1333,7 @@ class DataStore: await self.update_server_info(tree_id, new_server_info, lock=lock) async def maybe_get_server_for_store( - self, tree_id: bytes32, timestamp: int = int(time.monotonic()), *, lock: bool = True + self, tree_id: bytes32, timestamp: int, *, lock: bool = True ) -> Optional[ServerInfo]: subscriptions = await self.get_subscriptions(lock=lock) subscription = next((subscription for subscription in subscriptions if subscription.tree_id == tree_id), None) diff --git a/chia/data_layer/download_data.py b/chia/data_layer/download_data.py index 332a2ba6617d..bc258040c9bb 100644 --- a/chia/data_layer/download_data.py +++ b/chia/data_layer/download_data.py @@ -1,6 +1,7 @@ import aiohttp import asyncio import os +import time import logging from pathlib import Path from typing import List, Optional @@ -132,6 +133,7 @@ async def insert_from_delta_file( log: logging.Logger, ) -> bool: for root_hash in root_hashes: + timestamp = int(time.time()) existing_generation += 1 filename = get_delta_filename(tree_id, root_hash, existing_generation) @@ -144,7 +146,7 @@ async def insert_from_delta_file( text = await resp.read() target_filename.write_bytes(text) except Exception: - await data_store.server_misses_file(tree_id, server_info) + await data_store.server_misses_file(tree_id, server_info, timestamp) raise log.info(f"Successfully downloaded delta file {filename}.") @@ -173,7 +175,7 @@ async def insert_from_delta_file( except Exception: target_filename = client_foldername.joinpath(filename) os.remove(target_filename) - await data_store.received_incorrect_file(tree_id, server_info) + await data_store.received_incorrect_file(tree_id, server_info, timestamp) await data_store.rollback_to_generation(tree_id, existing_generation - 1) raise diff --git a/chia/rpc/data_layer_rpc_api.py b/chia/rpc/data_layer_rpc_api.py index 16b5b87cc8de..9b604d799f72 100644 --- a/chia/rpc/data_layer_rpc_api.py +++ b/chia/rpc/data_layer_rpc_api.py @@ -66,6 +66,7 @@ class DataLayerRpcApi: "/insert": self.insert, "/subscribe": self.subscribe, "/unsubscribe": self.unsubscribe, + "/remove_subscriptions": self.remove_subscriptions, "/subscriptions": self.subscriptions, "/get_kv_diff": self.get_kv_diff, "/get_root_history": self.get_root_history, @@ -238,6 +239,14 @@ class DataLayerRpcApi: subscriptions: List[Subscription] = await self.service.get_subscriptions() return {"store_ids": [sub.tree_id.hex() for sub in subscriptions]} + async def remove_subscriptions(self, request: Dict[str, Any]) -> Dict[str, Any]: + if self.service is None: + raise Exception("Data layer not created") + store_id = request.get("id") + store_id_bytes = bytes32.from_hexstr(store_id) + urls = request["urls"] + await self.service.remove_subscriptions(store_id=store_id_bytes, urls=urls) + async def add_missing_files(self, request: Dict[str, Any]) -> Dict[str, Any]: """ complete the data server files. diff --git a/chia/rpc/data_layer_rpc_client.py b/chia/rpc/data_layer_rpc_client.py index ef2c18692eb1..19424f97e1d3 100644 --- a/chia/rpc/data_layer_rpc_client.py +++ b/chia/rpc/data_layer_rpc_client.py @@ -53,6 +53,10 @@ class DataLayerRpcClient(RpcClient): response = await self.fetch("subscribe", {"id": store_id.hex(), "urls": urls}) return response # type: ignore[no-any-return] + async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]: + response = await self.fetch("remove_subscriptions", {"id": store_id.hex(), "urls": urls}) + return response # type: ignore[no-any-return] + async def unsubscribe(self, store_id: bytes32) -> Dict[str, Any]: response = await self.fetch("unsubscribe", {"id": store_id.hex()}) return response # type: ignore[no-any-return]