DL upsert. (#17011)

This commit is contained in:
Florin Chirica 2023-12-19 19:48:49 +02:00 committed by GitHub
parent ebf0c00a21
commit 42d9048a79
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 348 additions and 96 deletions

View File

@ -855,6 +855,85 @@ class DataStore:
return keys
async def get_ancestors_common(
    self,
    node_hash: bytes32,
    tree_id: bytes32,
    root_hash: Optional[bytes32],
    generation: Optional[int] = None,
    use_optimized: bool = True,
) -> List[InternalNode]:
    """
    Return the chain of internal-node ancestors of ``node_hash`` in tree
    ``tree_id``, ordered from immediate parent upward.

    Both code paths use the optimized lookup; when ``use_optimized`` is False
    the result is additionally cross-checked against the slow reference
    implementation (``get_ancestors``) and a mismatch raises ``RuntimeError``.

    Raises:
        RuntimeError: if the verification fails, or if the ancestor chain
            reaches the maximum supported tree height of 62.
    """
    # The original implementation called get_ancestors_optimized identically
    # in both branches; the call is hoisted out so it appears only once.
    ancestors: List[InternalNode] = await self.get_ancestors_optimized(
        node_hash=node_hash,
        tree_id=tree_id,
        generation=generation,
        root_hash=root_hash,
    )
    if not use_optimized:
        # Verification mode: compare against the unoptimized traversal.
        ancestors_2: List[InternalNode] = await self.get_ancestors(
            node_hash=node_hash, tree_id=tree_id, root_hash=root_hash
        )
        if ancestors != ancestors_2:
            raise RuntimeError("Ancestors optimized didn't produce the expected result.")
    if len(ancestors) >= 62:
        raise RuntimeError("Tree exceeds max height of 62.")
    return ancestors
async def update_ancestor_hashes_on_insert(
    self,
    tree_id: bytes32,
    left: bytes32,
    right: bytes32,
    traversal_node_hash: bytes32,
    ancestors: List[InternalNode],
    status: Status,
    root: Root,
) -> Root:
    """
    Rebuild the internal-node chain above a newly inserted node pair and
    commit a new root at ``root.generation + 1``.

    ``ancestors`` is walked bottom-up; each ancestor is cloned with the child
    hash on the matching side replaced by the freshly created node.  Returns
    the newly inserted ``Root``.
    """
    next_generation = root.generation + 1
    # Ancestor-table rows are deferred until after the root row exists,
    # to keep table constraints satisfied.
    pending_ancestor_rows: List[Tuple[bytes32, bytes32, bytes32]] = []

    # First replacement node joins the supplied left/right children.
    current_hash = await self._insert_internal_node(left_hash=left, right_hash=right)
    pending_ancestor_rows.append((left, right, tree_id))

    # Clone each ancestor, substituting the updated child hash on the side
    # that pointed at the node being traversed.
    for ancestor in ancestors:
        if not isinstance(ancestor, InternalNode):
            raise Exception(f"Expected an internal node but got: {type(ancestor).__name__}")
        if ancestor.left_hash == traversal_node_hash:
            left = current_hash
            right = ancestor.right_hash
        elif ancestor.right_hash == traversal_node_hash:
            left = ancestor.left_hash
            right = current_hash
        traversal_node_hash = ancestor.hash
        current_hash = await self._insert_internal_node(left_hash=left, right_hash=right)
        pending_ancestor_rows.append((left, right, tree_id))

    new_root = await self._insert_root(
        tree_id=tree_id,
        node_hash=current_hash,
        status=status,
        generation=next_generation,
    )

    if status == Status.COMMITTED:
        for row_left, row_right, row_tree_id in pending_ancestor_rows:
            await self._insert_ancestor_table(row_left, row_right, row_tree_id, next_generation)

    return new_root
async def insert(
self,
key: bytes,
@ -880,7 +959,7 @@ class DataStore:
if any(key == node.key for node in pairs):
raise Exception(f"Key already present: {key.hex()}")
else:
if bytes(key) in hint_keys_values:
if key in hint_keys_values:
raise Exception(f"Key already present: {key.hex()}")
if reference_node_hash is None:
@ -911,26 +990,6 @@ class DataStore:
if root.node_hash is None:
raise Exception("Internal error.")
if use_optimized:
ancestors: List[InternalNode] = await self.get_ancestors_optimized(
node_hash=reference_node_hash,
tree_id=tree_id,
generation=root.generation,
root_hash=root.node_hash,
)
else:
ancestors = await self.get_ancestors_optimized(
node_hash=reference_node_hash,
tree_id=tree_id,
generation=root.generation,
root_hash=root.node_hash,
)
ancestors_2: List[InternalNode] = await self.get_ancestors(
node_hash=reference_node_hash, tree_id=tree_id, root_hash=root.node_hash
)
if ancestors != ancestors_2:
raise RuntimeError("Ancestors optimized didn't produce the expected result.")
if side == Side.LEFT:
left = new_terminal_node_hash
right = reference_node_hash
@ -938,48 +997,26 @@ class DataStore:
left = reference_node_hash
right = new_terminal_node_hash
if len(ancestors) >= 62:
raise RuntimeError("Tree exceeds max height of 62.")
# update ancestors after inserting root, to keep table constraints.
insert_ancestors_cache: List[Tuple[bytes32, bytes32, bytes32]] = []
new_generation = root.generation + 1
# create first new internal node
new_hash = await self._insert_internal_node(left_hash=left, right_hash=right)
insert_ancestors_cache.append((left, right, tree_id))
traversal_node_hash = reference_node_hash
# create updated replacements for the rest of the internal nodes
for ancestor in ancestors:
if not isinstance(ancestor, InternalNode):
raise Exception(f"Expected an internal node but got: {type(ancestor).__name__}")
if ancestor.left_hash == traversal_node_hash:
left = new_hash
right = ancestor.right_hash
elif ancestor.right_hash == traversal_node_hash:
left = ancestor.left_hash
right = new_hash
traversal_node_hash = ancestor.hash
new_hash = await self._insert_internal_node(left_hash=left, right_hash=right)
insert_ancestors_cache.append((left, right, tree_id))
new_root = await self._insert_root(
ancestors = await self.get_ancestors_common(
node_hash=reference_node_hash,
tree_id=tree_id,
node_hash=new_hash,
root_hash=root.node_hash,
generation=root.generation,
use_optimized=use_optimized,
)
new_root = await self.update_ancestor_hashes_on_insert(
tree_id=tree_id,
left=left,
right=right,
traversal_node_hash=reference_node_hash,
ancestors=ancestors,
status=status,
generation=new_generation,
root=root,
)
if status == Status.COMMITTED:
for left_hash, right_hash, tree_id in insert_ancestors_cache:
await self._insert_ancestor_table(left_hash, right_hash, tree_id, new_generation)
if hint_keys_values is not None:
hint_keys_values[key] = value
return InsertResult(node_hash=new_terminal_node_hash, root=new_root)
if hint_keys_values is not None:
hint_keys_values[key] = value
return InsertResult(node_hash=new_terminal_node_hash, root=new_root)
async def delete(
self,
@ -1002,22 +1039,14 @@ class DataStore:
node_hash = leaf_hash(key=key, value=value)
node = TerminalNode(node_hash, key, value)
del hint_keys_values[key]
if use_optimized:
ancestors: List[InternalNode] = await self.get_ancestors_optimized(
node_hash=node.hash, tree_id=tree_id, root_hash=root_hash
)
else:
ancestors = await self.get_ancestors_optimized(
node_hash=node.hash, tree_id=tree_id, root_hash=root_hash
)
ancestors_2: List[InternalNode] = await self.get_ancestors(
node_hash=node.hash, tree_id=tree_id, root_hash=root_hash
)
if ancestors != ancestors_2:
raise RuntimeError("Ancestors optimized didn't produce the expected result.")
if len(ancestors) > 62:
raise RuntimeError("Tree exceeded max height of 62.")
ancestors: List[InternalNode] = await self.get_ancestors_common(
node_hash=node.hash,
tree_id=tree_id,
root_hash=root_hash,
use_optimized=use_optimized,
)
if len(ancestors) == 0:
# the only node is being deleted
return await self._insert_root(
@ -1072,6 +1101,100 @@ class DataStore:
return new_root
async def upsert(
    self,
    key: bytes,
    new_value: bytes,
    tree_id: bytes32,
    hint_keys_values: Optional[Dict[bytes, bytes]] = None,
    use_optimized: bool = True,
    status: Status = Status.PENDING,
    root: Optional[Root] = None,
) -> InsertResult:
    """
    Set ``key`` to ``new_value`` in the tree identified by ``tree_id``.

    Falls back to ``autoinsert`` when the key is not present, and
    short-circuits (returning the current root) when the stored value already
    equals ``new_value``.  When ``hint_keys_values`` is provided it is used as
    a key/value cache instead of querying the node table, and it is kept in
    sync with the change.

    Returns an ``InsertResult`` carrying the new terminal node hash and the
    resulting root.
    """
    async with self.db_wrapper.writer():
        if root is None:
            root = await self.get_tree_root(tree_id=tree_id)
        if hint_keys_values is None:
            # No cache available: look the key up directly in the store.
            try:
                old_node = await self.get_node_by_key(key=key, tree_id=tree_id)
            except KeyNotFoundError:
                log.debug(f"Key not found: {key.hex()}. Doing an autoinsert instead")
                return await self.autoinsert(
                    key=key,
                    value=new_value,
                    tree_id=tree_id,
                    hint_keys_values=hint_keys_values,
                    use_optimized=use_optimized,
                    status=status,
                    root=root,
                )
            if old_node.value == new_value:
                log.debug(f"New value matches old value in upsert operation: {key.hex()}. Ignoring upsert")
                return InsertResult(leaf_hash(key, new_value), root)
            old_node_hash = old_node.hash
        else:
            # Cache available: resolve the old value from the hint dictionary.
            if key not in hint_keys_values:
                log.debug(f"Key not found: {key.hex()}. Doing an autoinsert instead")
                return await self.autoinsert(
                    key=key,
                    value=new_value,
                    tree_id=tree_id,
                    hint_keys_values=hint_keys_values,
                    use_optimized=use_optimized,
                    status=status,
                    root=root,
                )
            value = hint_keys_values[key]
            if value == new_value:
                log.debug(f"New value matches old value in upsert operation: {key.hex()}")
                return InsertResult(leaf_hash(key, new_value), root)
            old_node_hash = leaf_hash(key=key, value=value)
            # Drop the stale cache entry; it is re-added with the new value
            # after the tree update succeeds.
            del hint_keys_values[key]

        # create new terminal node
        new_terminal_node_hash = await self._insert_terminal_node(key=key, value=new_value)

        ancestors = await self.get_ancestors_common(
            node_hash=old_node_hash,
            tree_id=tree_id,
            root_hash=root.node_hash,
            generation=root.generation,
            use_optimized=use_optimized,
        )

        # Store contains only the old root, replace it with a new root having the terminal node.
        if len(ancestors) == 0:
            new_root = await self._insert_root(
                tree_id=tree_id,
                node_hash=new_terminal_node_hash,
                status=status,
            )
        else:
            # Swap the old leaf for the new one under its parent, then
            # rebuild the hash chain up to a new root.
            parent = ancestors[0]
            if parent.left_hash == old_node_hash:
                left = new_terminal_node_hash
                right = parent.right_hash
            elif parent.right_hash == old_node_hash:
                left = parent.left_hash
                right = new_terminal_node_hash
            else:
                raise Exception("Internal error.")
            new_root = await self.update_ancestor_hashes_on_insert(
                tree_id=tree_id,
                left=left,
                right=right,
                traversal_node_hash=parent.hash,
                ancestors=ancestors[1:],
                status=status,
                root=root,
            )

        if hint_keys_values is not None:
            # Keep the cache consistent with the store.
            hint_keys_values[key] = new_value
        return InsertResult(node_hash=new_terminal_node_hash, root=new_root)
async def clean_node_table(self, writer: aiosqlite.Connection) -> None:
await writer.execute(
"""
@ -1140,6 +1263,13 @@ class DataStore:
intermediate_root = await self.delete(
key, tree_id, hint_keys_values, True, Status.COMMITTED, root=intermediate_root
)
elif change["action"] == "upsert":
key = change["key"]
new_value = change["value"]
insert_result = await self.upsert(
key, new_value, tree_id, hint_keys_values, True, Status.COMMITTED, root=intermediate_root
)
intermediate_root = insert_result.root
else:
raise Exception(f"Operation in batch is not insert or delete: {change}")

View File

@ -241,11 +241,40 @@ async def test_create_insert_get(
with pytest.raises(ValueError, match="Changelist resulted in no change to tree data"):
await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
# test delete
changelist = [{"action": "delete", "key": key.hex()}]
# test upsert
new_value = b"\x00\x02"
changelist = [{"action": "upsert", "key": key.hex(), "value": new_value.hex()}]
res = await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
update_tx_rec1 = res["tx_id"]
await farm_block_with_spend(full_node_api, ph, update_tx_rec1, wallet_rpc_api)
res = await data_rpc_api.get_value({"id": store_id.hex(), "key": key.hex()})
assert hexstr_to_bytes(res["value"]) == new_value
wallet_root = await data_rpc_api.get_root({"id": store_id.hex()})
upsert_wallet_root = wallet_root["hash"]
# test upsert unknown key acts as insert
new_value = b"\x00\x02"
changelist = [{"action": "upsert", "key": unknown_key.hex(), "value": new_value.hex()}]
res = await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
update_tx_rec2 = res["tx_id"]
await farm_block_with_spend(full_node_api, ph, update_tx_rec2, wallet_rpc_api)
res = await data_rpc_api.get_value({"id": store_id.hex(), "key": unknown_key.hex()})
assert hexstr_to_bytes(res["value"]) == new_value
# test delete
changelist = [{"action": "delete", "key": unknown_key.hex()}]
res = await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
update_tx_rec3 = res["tx_id"]
await farm_block_with_spend(full_node_api, ph, update_tx_rec3, wallet_rpc_api)
with pytest.raises(Exception):
await data_rpc_api.get_value({"id": store_id.hex(), "key": unknown_key.hex()})
wallet_root = await data_rpc_api.get_root({"id": store_id.hex()})
assert wallet_root["hash"] == upsert_wallet_root
changelist = [{"action": "delete", "key": key.hex()}]
res = await data_rpc_api.batch_update({"id": store_id.hex(), "changelist": changelist})
update_tx_rec4 = res["tx_id"]
await farm_block_with_spend(full_node_api, ph, update_tx_rec4, wallet_rpc_api)
with pytest.raises(Exception):
await data_rpc_api.get_value({"id": store_id.hex(), "key": key.hex()})
wallet_root = await data_rpc_api.get_root({"id": store_id.hex()})

View File

@ -8,7 +8,7 @@ import statistics
from dataclasses import dataclass
from pathlib import Path
from random import Random
from typing import Any, Awaitable, Callable, Dict, List, Set, Tuple, cast
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Tuple, cast
import aiosqlite
import pytest
@ -382,38 +382,67 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
random.seed(100, version=2)
batch: List[Dict[str, Any]] = []
keys: List[bytes] = []
hint_keys_values: Dict[bytes, bytes] = {}
keys_values: Dict[bytes, bytes] = {}
hint_keys_values: Optional[Dict[bytes, bytes]] = {} if use_optimized else None
for operation in range(num_batches * num_ops_per_batch):
if random.randint(0, 4) > 0 or len(keys) == 0:
[op_type] = random.choices(
["insert", "upsert-insert", "upsert-update", "delete"],
[0.4, 0.2, 0.2, 0.2],
k=1,
)
if op_type == "insert" or op_type == "upsert-insert" or len(keys_values) == 0:
if len(keys_values) == 0:
op_type = "insert"
key = operation.to_bytes(4, byteorder="big")
value = (2 * operation).to_bytes(4, byteorder="big")
if use_optimized:
if op_type == "insert":
await single_op_data_store.autoinsert(
key=key,
value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
use_optimized=use_optimized,
status=Status.COMMITTED,
)
else:
await single_op_data_store.autoinsert(
key=key, value=value, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
)
batch.append({"action": "insert", "key": key, "value": value})
keys.append(key)
else:
key = random.choice(keys)
keys.remove(key)
if use_optimized:
await single_op_data_store.delete(
key=key, tree_id=tree_id, hint_keys_values=hint_keys_values, status=Status.COMMITTED
)
else:
await single_op_data_store.delete(
key=key, tree_id=tree_id, use_optimized=False, status=Status.COMMITTED
await single_op_data_store.upsert(
key=key,
new_value=value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
use_optimized=use_optimized,
status=Status.COMMITTED,
)
action = "insert" if op_type == "insert" else "upsert"
batch.append({"action": action, "key": key, "value": value})
keys_values[key] = value
elif op_type == "delete":
key = random.choice(list(keys_values.keys()))
del keys_values[key]
await single_op_data_store.delete(
key=key,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
use_optimized=use_optimized,
status=Status.COMMITTED,
)
batch.append({"action": "delete", "key": key})
else:
assert op_type == "upsert-update"
key = random.choice(list(keys_values.keys()))
old_value = keys_values[key]
new_value_int = int.from_bytes(old_value, byteorder="big") + 1
new_value = new_value_int.to_bytes(4, byteorder="big")
await single_op_data_store.upsert(
key=key,
new_value=new_value,
tree_id=tree_id,
hint_keys_values=hint_keys_values,
use_optimized=use_optimized,
status=Status.COMMITTED,
)
keys_values[key] = new_value
batch.append({"action": "upsert", "key": key, "value": new_value})
if (operation + 1) % num_ops_per_batch == 0:
saved_batches.append(batch)
batch = []
@ -445,6 +474,70 @@ async def test_batch_update(data_store: DataStore, tree_id: bytes32, use_optimiz
ancestors[node.left_hash] = node_hash
ancestors[node.right_hash] = node_hash
all_kv = await data_store.get_keys_values(tree_id)
assert {node.key: node.value for node in all_kv} == keys_values
@pytest.mark.anyio
@pytest.mark.parametrize(
    "use_optimized",
    [True, False],
)
async def test_upsert_ignores_existing_arguments(
    data_store: DataStore,
    tree_id: bytes32,
    use_optimized: bool,
) -> None:
    """Upsert updates existing keys, tolerates no-op updates, and inserts unknown keys."""
    hint_keys_values: Optional[Dict[bytes, bytes]] = {} if use_optimized else None

    async def upsert_and_check(target_key: bytes, target_value: bytes) -> None:
        # Perform an upsert and verify the stored value afterwards.
        await data_store.upsert(
            key=target_key,
            new_value=target_value,
            tree_id=tree_id,
            hint_keys_values=hint_keys_values,
            use_optimized=use_optimized,
            status=Status.COMMITTED,
        )
        node = await data_store.get_node_by_key(target_key, tree_id)
        assert node.value == target_value

    key = b"key"
    value = b"value1"
    await data_store.autoinsert(
        key=key,
        value=value,
        tree_id=tree_id,
        hint_keys_values=hint_keys_values,
        use_optimized=use_optimized,
        status=Status.COMMITTED,
    )
    node = await data_store.get_node_by_key(key, tree_id)
    assert node.value == value

    new_value = b"value2"
    # Update an existing key.
    await upsert_and_check(key, new_value)
    # Upsert with an unchanged value is a no-op.
    await upsert_and_check(key, new_value)
    # Upsert of an unknown key behaves like an insert.
    await upsert_and_check(b"key2", value)
@pytest.mark.parametrize(argnames="side", argvalues=list(Side))
@pytest.mark.anyio