chia-blockchain/chia/daemon/keychain_proxy.py
Kyle Altendorf b5a58e4998
enable ssl_context deprecation warning and fix it (#17127)
* enable ssl_context deprecation warning and fix it

* one case

* a few more

* yup

* one more
2024-01-02 17:07:31 -08:00

452 lines
18 KiB
Python

from __future__ import annotations
import asyncio
import logging
import ssl
import traceback
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from aiohttp import ClientConnectorError, ClientSession
from chia_rs import AugSchemeMPL, PrivateKey
from chia.cmds.init_funcs import check_keys
from chia.daemon.client import DaemonProxy
from chia.daemon.keychain_server import (
KEYCHAIN_ERR_KEY_NOT_FOUND,
KEYCHAIN_ERR_KEYERROR,
KEYCHAIN_ERR_LOCKED,
KEYCHAIN_ERR_MALFORMED_REQUEST,
KEYCHAIN_ERR_NO_KEYS,
)
from chia.server.server import ssl_context_for_client
from chia.util.config import load_config
from chia.util.errors import (
KeychainIsEmpty,
KeychainIsLocked,
KeychainKeyNotFound,
KeychainMalformedRequest,
KeychainMalformedResponse,
KeychainProxyConnectionTimeout,
)
from chia.util.keychain import Keychain, KeyData, bytes_to_mnemonic, mnemonic_to_seed
from chia.util.ws_message import WsRpcMessage
class KeychainProxy(DaemonProxy):
    """
    KeychainProxy can act on behalf of a local or remote keychain. In the case of
    wrapping a local keychain, the proxy object simply forwards-along the calls to
    the underlying local keychain. In the remote case, calls are made to the daemon
    over the RPC interface, allowing the daemon to act as the keychain authority.
    """

    def __init__(
        self,
        log: logging.Logger,
        uri: str = "",
        ssl_context: Optional[ssl.SSLContext] = None,
        local_keychain: Optional[Keychain] = None,
        user: Optional[str] = None,
        service: Optional[str] = None,
        heartbeat: int = 300,
    ):
        """
        :param log: logger used for connection and response diagnostics
        :param uri: daemon websocket URI used by connect_to_keychain()
        :param ssl_context: client SSL context used for the daemon websocket
        :param local_keychain: when given, calls are forwarded to it instead of the daemon
        :param user: optional keychain user, sent as ``kc_user`` with every RPC request
        :param service: optional keychain service, sent as ``kc_service`` with every RPC request
        :param heartbeat: websocket heartbeat forwarded to DaemonProxy
        """
        super().__init__(uri, ssl_context, heartbeat=heartbeat)
        self.log = log
        if local_keychain:
            self.keychain = local_keychain
        else:
            # None selects remote (daemon RPC) mode; see use_local_keychain().
            self.keychain = None  # type: ignore
        self.keychain_user = user
        self.keychain_service = service
        # these are used to track and close the keychain connection
        self.keychain_connection_task: Optional[asyncio.Task[None]] = None
        self.shut_down: bool = False
        self.connection_established: asyncio.Event = asyncio.Event()

    def use_local_keychain(self) -> bool:
        """
        Indicates whether the proxy forwards calls to a local keychain
        """
        return self.keychain is not None

    def format_request(self, command: str, data: Optional[Dict[str, Any]]) -> WsRpcMessage:
        """
        Overrides DaemonProxy.format_request() to add keychain-specific RPC params.

        ``data`` may be None (treated as an empty dict). It is copied before the
        ``kc_user``/``kc_service`` fields are added, so the caller's dict is never mutated.
        """
        params: Dict[str, Any] = {} if data is None else dict(data)
        if self.keychain_user or self.keychain_service:
            params["kc_user"] = self.keychain_user
            params["kc_service"] = self.keychain_service
        return super().format_request(command, params)

    async def _get(self, request: WsRpcMessage) -> WsRpcMessage:
        """
        Overrides DaemonProxy._get() to handle the connection state.

        Waits up to 30s for the websocket connection before sending the request.

        :raises KeychainProxyConnectionTimeout: when an asyncio.TimeoutError escapes
            the wait or the underlying request.
        """
        try:
            if not self.shut_down:  # if we are shut down, and we send a request we should throw original error.
                await asyncio.wait_for(self.connection_established.wait(), timeout=30)  # in case of heavy swap usage.
            else:
                self.log.error("Attempting to send request to a keychain-proxy that has shut down.")
            self.log.debug(f"Sending request to keychain command: {request['command']} from {request['origin']}.")
            return await super()._get(request)
        except asyncio.TimeoutError:
            raise KeychainProxyConnectionTimeout()

    async def start(self) -> None:
        """
        Spawn the reconnecting connection task and wait until a connection is established.
        """
        self.keychain_connection_task = asyncio.create_task(self.connect_to_keychain())
        await self.connection_established.wait()  # wait until connection is established.

    async def connect_to_keychain(self) -> None:
        """
        Keep (re)connecting to the daemon websocket until close() sets ``shut_down``.

        Each iteration opens a fresh ClientSession and websocket, listens until the
        socket closes or an error occurs, then tears both down and retries after 2s.
        """
        while not self.shut_down:
            try:
                self.client_session = ClientSession()
                self.websocket = await self.client_session.ws_connect(
                    self._uri,
                    autoclose=True,
                    autoping=True,
                    heartbeat=self.heartbeat,
                    ssl=self.ssl_context,
                    max_msg_size=self.max_message_size,
                )
                await self.listener()
            except ClientConnectorError:
                self.log.warning(f"Can not connect to keychain at {self._uri}.")
            except Exception as e:
                tb = traceback.format_exc()
                self.log.warning(f"Exception: {tb} {type(e)}")
            self.log.info(f"Reconnecting to keychain at {self._uri}.")
            # Requests issued from now on must wait for the next successful connection.
            self.connection_established.clear()
            if self.websocket is not None:
                await self.websocket.close()
            if self.client_session is not None:
                await self.client_session.close()
            self.websocket = None
            self.client_session = None
            await asyncio.sleep(2)

    async def listener(self) -> None:
        """
        Mark the connection as active, then run DaemonProxy's listener until it returns.
        """
        self.connection_established.set()  # mark connection as active.
        await super().listener()
        self.log.info("Close signal received from keychain, we probably timed out.")

    async def close(self) -> None:
        """
        Stop the reconnect loop, delegate to DaemonProxy.close(), and wait for the
        connection task (if one was started) to finish.
        """
        self.shut_down = True
        await super().close()
        if self.keychain_connection_task is not None:
            await self.keychain_connection_task

    async def get_response_for_request(self, request_name: str, data: Dict[str, Any]) -> Tuple[WsRpcMessage, bool]:
        """
        Send ``request_name`` with ``data`` to the daemon.

        :returns: the raw response and its ``data["success"]`` flag (False when absent).
        """
        request = self.format_request(request_name, data)
        response = await self._get(request)
        success = response["data"].get("success", False)
        return response, success

    def handle_error(self, response: WsRpcMessage) -> None:
        """
        Common error handling for RPC responses.

        Maps known keychain error codes to their exception types and raises a
        generic Exception for any other error. Returns without raising when the
        response carries no ``error`` field.
        """
        error = response["data"].get("error", None)
        if error:
            error_details = response["data"].get("error_details", {})
            if error == KEYCHAIN_ERR_LOCKED:
                raise KeychainIsLocked()
            elif error == KEYCHAIN_ERR_NO_KEYS:
                raise KeychainIsEmpty()
            elif error == KEYCHAIN_ERR_KEY_NOT_FOUND:
                raise KeychainKeyNotFound()
            elif error == KEYCHAIN_ERR_MALFORMED_REQUEST:
                message = error_details.get("message", "")
                raise KeychainMalformedRequest(message)
            else:
                # Try to construct a more informative error message including the call that failed
                if "command" in response["data"]:
                    err = f"{response['data'].get('command')} failed with error: {error}"
                    raise Exception(f"{err}")
                raise Exception(f"{error}")

    def _private_key_from_entropy(self, entropy: bytes, expected_pk: str) -> Optional[PrivateKey]:
        """
        Re-derive a private key from its entropy and verify it against a public key.

        :param entropy: key entropy as returned by the daemon (already hex-decoded)
        :param expected_pk: hex of the G1 public key the daemon reported
        :returns: the derived key if its G1 hex matches ``expected_pk``; otherwise
            logs the mismatch and returns None.
        """
        mnemonic = bytes_to_mnemonic(entropy)
        seed = mnemonic_to_seed(mnemonic)
        key = AugSchemeMPL.key_gen(seed)
        if bytes(key.get_g1()).hex() == expected_pk:
            return key
        err = "G1Elements don't match"
        self.log.error(err)
        return None

    async def add_private_key(self, mnemonic: str, label: Optional[str] = None) -> PrivateKey:
        """
        Forwards to Keychain.add_private_key()

        :raises KeyError: with the offending word when the daemon rejects the mnemonic
        :raises KeychainMalformedResponse: when the daemon reports failure without an error
        """
        if self.use_local_keychain():
            return self.keychain.add_private_key(mnemonic, label)

        response, success = await self.get_response_for_request(
            "add_private_key", {"mnemonic": mnemonic, "label": label}
        )
        if success:
            # The daemon does not return the key; derive it locally from the mnemonic.
            seed = mnemonic_to_seed(mnemonic)
            return AugSchemeMPL.key_gen(seed)

        error = response["data"].get("error", None)
        if error == KEYCHAIN_ERR_KEYERROR:
            error_details = response["data"].get("error_details", {})
            word = error_details.get("word", "")
            raise KeyError(word)
        self.handle_error(response)
        # handle_error() returns silently when the response has no "error" field.
        # Previously that fell through to returning an unbound local (UnboundLocalError).
        err = f"Unsuccessful {response.get('command')} response did not include an error"
        self.log.error(err)
        raise KeychainMalformedResponse(err)

    async def check_keys(self, root_path: Path) -> None:
        """
        Forwards to init_funcs.check_keys()
        """
        if self.use_local_keychain():
            # Resolves to the module-level init_funcs.check_keys, not this method.
            check_keys(root_path, self.keychain)
        else:
            response, success = await self.get_response_for_request("check_keys", {"root_path": str(root_path)})
            if not success:
                self.handle_error(response)

    async def delete_all_keys(self) -> None:
        """
        Forwards to Keychain.delete_all_keys()
        """
        if self.use_local_keychain():
            self.keychain.delete_all_keys()
        else:
            response, success = await self.get_response_for_request("delete_all_keys", {})
            if not success:
                self.handle_error(response)

    async def delete_key_by_fingerprint(self, fingerprint: int) -> None:
        """
        Forwards to Keychain.delete_key_by_fingerprint()
        """
        if self.use_local_keychain():
            self.keychain.delete_key_by_fingerprint(fingerprint)
        else:
            response, success = await self.get_response_for_request(
                "delete_key_by_fingerprint", {"fingerprint": fingerprint}
            )
            if not success:
                self.handle_error(response)

    async def get_all_private_keys(self) -> List[Tuple[PrivateKey, bytes]]:
        """
        Forwards to Keychain.get_all_private_keys()

        In remote mode, entries missing ``pk``/``entropy`` or whose derived key does
        not match ``pk`` are logged and skipped.

        :raises KeychainMalformedResponse: when the response has no ``private_keys`` list
        """
        if self.use_local_keychain():
            return self.keychain.get_all_private_keys()

        keys: List[Tuple[PrivateKey, bytes]] = []
        response, success = await self.get_response_for_request("get_all_private_keys", {})
        if not success:
            self.handle_error(response)
            return keys

        private_keys = response["data"].get("private_keys", None)
        if private_keys is None:
            err = f"Missing private_keys in {response.get('command')} response"
            self.log.error(err)
            raise KeychainMalformedResponse(err)

        for key_dict in private_keys:
            pk = key_dict.get("pk", None)
            ent_str = key_dict.get("entropy", None)
            if pk is None or ent_str is None:
                err = f"Missing pk and/or ent in {response.get('command')} response"
                self.log.error(err)
                continue  # We'll skip the incomplete key entry
            ent = bytes.fromhex(ent_str)
            key = self._private_key_from_entropy(ent, pk)
            if key is not None:
                keys.append((key, ent))
        return keys

    async def get_first_private_key(self) -> Optional[PrivateKey]:
        """
        Forwards to Keychain.get_first_private_key()

        :returns: the key, or None when none exists or the derived key does not match ``pk``
        :raises KeychainMalformedResponse: when required fields are missing from the response
        """
        if self.use_local_keychain():
            sk_ent = self.keychain.get_first_private_key()
            return sk_ent[0] if sk_ent else None

        response, success = await self.get_response_for_request("get_first_private_key", {})
        if not success:
            self.handle_error(response)
            return None

        private_key = response["data"].get("private_key", None)
        if private_key is None:
            err = f"Missing private_key in {response.get('command')} response"
            self.log.error(err)
            raise KeychainMalformedResponse(err)

        pk = private_key.get("pk", None)
        ent_str = private_key.get("entropy", None)
        if pk is None or ent_str is None:
            err = f"Missing pk and/or ent in {response.get('command')} response"
            self.log.error(err)
            raise KeychainMalformedResponse(err)
        return self._private_key_from_entropy(bytes.fromhex(ent_str), pk)

    async def get_key_for_fingerprint(self, fingerprint: Optional[int]) -> Optional[PrivateKey]:
        """
        Locates and returns a private key matching the provided fingerprint.

        With ``fingerprint=None`` the local path returns the first key.

        :raises KeychainIsEmpty: (local) when the keychain has no keys
        :raises KeychainKeyNotFound: (local) when no key matches ``fingerprint``
        :raises KeychainMalformedResponse: (remote) when ``pk``/``entropy`` are missing
        """
        if self.use_local_keychain():
            private_keys = self.keychain.get_all_private_keys()
            if not private_keys:
                raise KeychainIsEmpty()
            if fingerprint is None:
                return private_keys[0][0]
            for sk, _ in private_keys:
                if sk.get_g1().get_fingerprint() == fingerprint:
                    return sk
            raise KeychainKeyNotFound(fingerprint)

        response, success = await self.get_response_for_request(
            "get_key_for_fingerprint", {"fingerprint": fingerprint}
        )
        if not success:
            self.handle_error(response)
            return None

        pk = response["data"].get("pk", None)
        ent = response["data"].get("entropy", None)
        if pk is None or ent is None:
            err = f"Missing pk and/or ent in {response.get('command')} response"
            self.log.error(err)
            raise KeychainMalformedResponse(err)
        return self._private_key_from_entropy(bytes.fromhex(ent), pk)

    async def get_key(self, fingerprint: int, include_secrets: bool = False) -> Optional[KeyData]:
        """
        Locates and returns KeyData matching the provided fingerprint
        """
        key_data: Optional[KeyData] = None
        if self.use_local_keychain():
            key_data = self.keychain.get_key(fingerprint, include_secrets)
        else:
            response, success = await self.get_response_for_request(
                "get_key", {"fingerprint": fingerprint, "include_secrets": include_secrets}
            )
            if success:
                key_data = KeyData.from_json_dict(response["data"]["key"])
            else:
                self.handle_error(response)
        return key_data

    async def get_keys(self, include_secrets: bool = False) -> List[KeyData]:
        """
        Returns all KeyData
        """
        keys: List[KeyData] = []
        if self.use_local_keychain():
            keys = self.keychain.get_keys(include_secrets)
        else:
            response, success = await self.get_response_for_request("get_keys", {"include_secrets": include_secrets})
            if success:
                keys = [KeyData.from_json_dict(key) for key in response["data"]["keys"]]
            else:
                self.handle_error(response)
        return keys
def wrap_local_keychain(keychain: Keychain, log: logging.Logger) -> KeychainProxy:
    """
    Adapt an in-process Keychain to the KeychainProxy interface.

    The returned proxy forwards its keychain calls straight to ``keychain``
    instead of going through the daemon, so callers can treat local and
    daemon-backed keychains the same way.
    """
    proxy = KeychainProxy(log=log, local_keychain=keychain)
    return proxy
async def connect_to_keychain(
    self_hostname: str,
    daemon_port: int,
    daemon_heartbeat: int,
    ssl_context: Optional[ssl.SSLContext],
    log: logging.Logger,
    user: Optional[str] = None,
    service: Optional[str] = None,
) -> KeychainProxy:
    """
    Build a KeychainProxy for the local daemon's ``wss://`` endpoint and start it.

    Starting the proxy waits until its websocket connection has been established.
    """
    daemon_uri = f"wss://{self_hostname}:{daemon_port}"
    proxy = KeychainProxy(
        log=log,
        uri=daemon_uri,
        ssl_context=ssl_context,
        user=user,
        service=service,
        heartbeat=daemon_heartbeat,
    )
    # Only a daemon-backed proxy needs a live connection; a local one is ready as-is.
    if proxy.use_local_keychain():
        return proxy
    await proxy.start()
    return proxy
async def connect_to_keychain_and_validate(
    root_path: Path,
    log: logging.Logger,
    user: Optional[str] = None,
    service: Optional[str] = None,
) -> Optional[KeychainProxy]:
    """
    Connect to the local daemon and do a ping to ensure that something is really
    there and running.

    Daemon SSL cert/key, private CA cert/key, hostname, port and heartbeat are
    read from ``config.yaml`` under ``root_path``.

    :returns: the connected proxy, or None if setup raised or the ping did not
        answer "pong". On failure, any proxy that was already created is closed
        so its background reconnect task and websocket are not leaked.
    """
    connection: Optional[KeychainProxy] = None
    try:
        net_config = load_config(root_path, "config.yaml")
        crt_path = root_path / net_config["daemon_ssl"]["private_crt"]
        key_path = root_path / net_config["daemon_ssl"]["private_key"]
        ca_crt_path = root_path / net_config["private_ssl_ca"]["crt"]
        ca_key_path = root_path / net_config["private_ssl_ca"]["key"]
        ssl_context = ssl_context_for_client(ca_crt_path, ca_key_path, crt_path, key_path, log=log)
        daemon_heartbeat = net_config.get("daemon_heartbeat", 300)
        connection = await connect_to_keychain(
            net_config["self_hostname"], net_config["daemon_port"], daemon_heartbeat, ssl_context, log, user, service
        )
        # If proxying to a local keychain, don't attempt to ping
        if connection.use_local_keychain():
            return connection
        # ping() is inherited from DaemonProxy
        r = await connection.ping()
        if r["data"].get("value") == "pong":
            return connection
    except Exception as e:
        print(f"Keychain(daemon) not started yet: {e}")

    # Validation failed. Previously the started proxy was simply dropped, leaving
    # its reconnect loop running; shut it down before reporting failure.
    if connection is not None:
        try:
            await connection.close()
        except Exception as e:
            log.warning(f"Failed to close keychain proxy after failed validation: {e}")
    return None