Add remove_subscription and fix bug.

This commit is contained in:
Florin Chirica 2022-07-28 18:47:15 +03:00
parent f34638b406
commit cb42f7f205
No known key found for this signature in database
GPG Key ID: 1805593F7B529698
7 changed files with 69 additions and 8 deletions

View File

@ -196,6 +196,20 @@ def subscribe(
run(subscribe_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls))
@data_cmd.command("remove_subscription", short_help="")
@create_data_store_id_option()
@click.option("-u", "--url", "urls", help="", type=str, multiple=True)
@create_rpc_port_option()
def remove_subscription(
id: str,
urls: List[str],
data_rpc_port: int,
) -> None:
from chia.cmds.data_funcs import remove_subscriptions_cmd
run(remove_subscriptions_cmd(rpc_port=data_rpc_port, store_id=id, urls=urls))
@data_cmd.command("unsubscribe", short_help="")
@create_data_store_id_option()
@create_rpc_port_option()

View File

@ -156,6 +156,21 @@ async def unsubscribe_cmd(
print(f"Exception from 'data': {e}")
async def remove_subscriptions_cmd(
rpc_port: Optional[int],
store_id: str,
urls: List[str],
) -> None:
store_id_bytes = bytes32.from_hexstr(store_id)
try:
async with get_client(rpc_port) as (client, rpc_port):
await client.remove_subscriptions(store_id=store_id_bytes, urls=urls)
except aiohttp.ClientConnectorError:
print(f"Connection error. Check if data is running at {rpc_port}")
except Exception as e:
print(f"Exception from 'data': {e}")
async def get_kv_diff_cmd(
rpc_port: Optional[int],
store_id: str,

View File

@ -260,9 +260,10 @@ class DataLayer:
await self.data_store.create_tree(tree_id=tree_id)
while True:
server_info = await self.data_store.maybe_get_server_for_store(tree_id)
timestamp = int(time.time())
server_info = await self.data_store.maybe_get_server_for_store(tree_id, timestamp)
if server_info is None:
self.log.info(f"No server available for {tree_id}.")
self.log.info(f"No server available for {tree_id}")
return
url = server_info.url
root = await self.data_store.get_tree_root(tree_id=tree_id)
@ -355,6 +356,12 @@ class DataLayer:
await self.data_store.subscribe(subscription)
self.log.info(f"Done adding subscription: {subscription.tree_id}")
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> None:
parsed_urls = [url.rstrip("/") for url in urls]
async with self.subscription_lock:
await self.data_store.remove_subscriptions(store_id, parsed_urls)
subscriptions = await self.get_subscriptions()
async def unsubscribe(self, tree_id: bytes32) -> None:
subscriptions = await self.get_subscriptions()
if tree_id not in (subscription.tree_id for subscription in subscriptions):

View File

@ -1,6 +1,5 @@
import logging
import aiosqlite
import time
from collections import defaultdict
from random import Random
from dataclasses import dataclass, replace
@ -1260,6 +1259,17 @@ class DataStore:
},
)
async def remove_subscriptions(self, tree_id: bytes32, urls: List[str], *, lock: bool = True) -> None:
async with self.db_wrapper.locked_transaction(lock=lock):
for url in urls:
await self.db.execute(
"DELETE FROM subscriptions WHERE tree_id == :tree_id AND url == :url",
{
"tree_id": tree_id.hex(),
"url": url,
},
)
async def unsubscribe(self, tree_id: bytes32, *, lock: bool = True) -> None:
async with self.db_wrapper.locked_transaction(lock=lock):
await self.db.execute(
@ -1292,7 +1302,7 @@ class DataStore:
)
async def received_incorrect_file(
self, tree_id: bytes32, server_info: ServerInfo, timestamp: int = int(time.monotonic()), *, lock: bool = True
self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True
) -> None:
SEVEN_DAYS_BAN = 7 * 24 * 60 * 60
new_server_info = ServerInfo(
@ -1311,7 +1321,7 @@ class DataStore:
await self.update_server_info(tree_id, new_server_info, lock=lock)
async def server_misses_file(
self, tree_id: bytes32, server_info: ServerInfo, timestamp: int = int(time.monotonic()), *, lock: bool = True
self, tree_id: bytes32, server_info: ServerInfo, timestamp: int, *, lock: bool = True
) -> None:
BAN_TIME_BY_MISSING_COUNT = [5 * 60] * 3 + [15 * 60] * 3 + [60 * 60] * 2 + [240 * 60]
index = min(server_info.num_consecutive_failures, len(BAN_TIME_BY_MISSING_COUNT) - 1)
@ -1323,7 +1333,7 @@ class DataStore:
await self.update_server_info(tree_id, new_server_info, lock=lock)
async def maybe_get_server_for_store(
self, tree_id: bytes32, timestamp: int = int(time.monotonic()), *, lock: bool = True
self, tree_id: bytes32, timestamp: int, *, lock: bool = True
) -> Optional[ServerInfo]:
subscriptions = await self.get_subscriptions(lock=lock)
subscription = next((subscription for subscription in subscriptions if subscription.tree_id == tree_id), None)

View File

@ -1,6 +1,7 @@
import aiohttp
import asyncio
import os
import time
import logging
from pathlib import Path
from typing import List, Optional
@ -132,6 +133,7 @@ async def insert_from_delta_file(
log: logging.Logger,
) -> bool:
for root_hash in root_hashes:
timestamp = int(time.time())
existing_generation += 1
filename = get_delta_filename(tree_id, root_hash, existing_generation)
@ -144,7 +146,7 @@ async def insert_from_delta_file(
text = await resp.read()
target_filename.write_bytes(text)
except Exception:
await data_store.server_misses_file(tree_id, server_info)
await data_store.server_misses_file(tree_id, server_info, timestamp)
raise
log.info(f"Successfully downloaded delta file {filename}.")
@ -173,7 +175,7 @@ async def insert_from_delta_file(
except Exception:
target_filename = client_foldername.joinpath(filename)
os.remove(target_filename)
await data_store.received_incorrect_file(tree_id, server_info)
await data_store.received_incorrect_file(tree_id, server_info, timestamp)
await data_store.rollback_to_generation(tree_id, existing_generation - 1)
raise

View File

@ -66,6 +66,7 @@ class DataLayerRpcApi:
"/insert": self.insert,
"/subscribe": self.subscribe,
"/unsubscribe": self.unsubscribe,
"/remove_subscriptions": self.remove_subscriptions,
"/subscriptions": self.subscriptions,
"/get_kv_diff": self.get_kv_diff,
"/get_root_history": self.get_root_history,
@ -238,6 +239,14 @@ class DataLayerRpcApi:
subscriptions: List[Subscription] = await self.service.get_subscriptions()
return {"store_ids": [sub.tree_id.hex() for sub in subscriptions]}
async def remove_subscriptions(self, request: Dict[str, Any]) -> Dict[str, Any]:
if self.service is None:
raise Exception("Data layer not created")
store_id = request.get("id")
store_id_bytes = bytes32.from_hexstr(store_id)
urls = request["urls"]
await self.service.remove_subscriptions(store_id=store_id_bytes, urls=urls)
async def add_missing_files(self, request: Dict[str, Any]) -> Dict[str, Any]:
"""
complete the data server files.

View File

@ -53,6 +53,10 @@ class DataLayerRpcClient(RpcClient):
response = await self.fetch("subscribe", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
async def remove_subscriptions(self, store_id: bytes32, urls: List[str]) -> Dict[str, Any]:
response = await self.fetch("remove_subscriptions", {"id": store_id.hex(), "urls": urls})
return response # type: ignore[no-any-return]
async def unsubscribe(self, store_id: bytes32) -> Dict[str, Any]:
response = await self.fetch("unsubscribe", {"id": store_id.hex()})
return response # type: ignore[no-any-return]