Clear cache of DeviceInfo when needed

This commit is contained in:
Dain Nilsson 2024-07-09 11:30:02 +02:00
parent e7a047c9b9
commit 0b7d6736cb
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8
7 changed files with 100 additions and 65 deletions

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import RpcException, encode_bytes
from .base import RpcResponse, RpcException, encode_bytes
from .device import RootNode
from queue import Queue
@ -80,7 +80,7 @@ def _handle_incoming(event, recv, error, cmd_queue):
def process(
send: Callable[[Dict], None],
recv: Callable[[], Dict],
handler: Callable[[str, List, Dict, Event, Callable[[str], None]], Dict],
handler: Callable[[str, List, Dict, Event, Callable[[str], None]], RpcResponse],
) -> None:
def error(status: str, message: str, body: Dict = {}):
send(dict(kind="error", status=status, message=message, body=body))
@ -88,8 +88,8 @@ def process(
def signal(status: str, body: Dict = {}):
send(dict(kind="signal", status=status, body=body))
def success(body: Dict):
send(dict(kind="success", body=body))
def success(response: RpcResponse):
send(dict(kind="success", body=response.body))
event = Event()
cmd_queue: Queue = Queue(1)

View File

@ -27,6 +27,12 @@ def encode_bytes(value: bytes) -> str:
decode_bytes = bytes.fromhex
class RpcResponse:
def __init__(self, body, side_effects=None):
self.body = body
self.side_effects = side_effects or []
class RpcException(Exception):
"""An exception that is returned as the result of an RPC command.i
@ -116,16 +122,20 @@ class RpcNode:
try:
if target:
traversed += [target[0]]
return self.get_child(target[0])(
response = self.get_child(target[0])(
action, target[1:], params, event, signal, traversed
)
if action in self.list_actions():
return self.get_action(action)(params, event, signal)
if action in self.list_children():
elif action in self.list_actions():
response = self.get_action(action)(params, event, signal)
elif action in self.list_children():
traversed += [action]
return self.get_child(action)(
response = self.get_child(action)(
"get", [], params, event, signal, traversed
)
if isinstance(response, RpcResponse):
return response
return RpcResponse(response)
except ChildResetException as e:
self._close_child()
raise StateResetException(e.message, traversed)

View File

@ -258,12 +258,24 @@ class AbstractDeviceNode(RpcNode):
super().__init__()
self._device = device
self._info = info
self._data = None
def __call__(self, *args, **kwargs):
try:
return super().__call__(*args, **kwargs)
response = super().__call__(*args, **kwargs)
if "device_info" in response.side_effects:
# Clear DeviceInfo cache
self._info = None
self._data = None
# Make sure any child node is re-opened after this,
# as enabled applications may have changed
super().close()
return response
except (SmartcardException, OSError):
logger.exception("Device error")
self._child = None
name = self._child_name
self._child_name = None
@ -276,6 +288,14 @@ class AbstractDeviceNode(RpcNode):
logger.exception(f"Unable to create child {name}")
raise NoSuchNodeException(name)
def get_data(self):
if not self._data:
self._data = self._refresh_data()
return self._data
def _refresh_data(self):
...
def _read_data(self, conn):
pid = self._device.pid
self._info = read_info(conn, pid)
@ -296,7 +316,7 @@ class UsbDeviceNode(AbstractDeviceNode):
connection = self._device.open_connection(conn_type)
return ConnectionNode(self._device, connection, self._info)
def get_data(self):
def _refresh_data(self):
for conn_type in (SmartCardConnection, OtpConnection, FidoConnection):
if self._supports_connection(conn_type):
try:
@ -335,7 +355,7 @@ class _ReaderObserver(CardObserver):
def __init__(self, device):
self.device = device
self.card = None
self.data = None
self.needs_refresh = True
def update(self, observable, actions):
added, removed = actions
@ -346,7 +366,7 @@ class _ReaderObserver(CardObserver):
break
else:
self.card = None
self.data = None
self.needs_refresh = True
logger.debug(f"NFC card: {self.card}")
@ -357,35 +377,29 @@ class ReaderDeviceNode(AbstractDeviceNode):
self._monitor = CardMonitor()
self._monitor.addObserver(self._observer)
def __call__(self, *args, **kwargs):
result = super().__call__(*args, **kwargs)
# Clear DeviceInfo cache on configure command
if ("configure", ["ccid", "management"]) == args[:2]:
self._observer.data = None
# Make sure any child node is re-opened after this,
# as enabled applications may have changed
super().close()
return result
def close(self):
self._monitor.deleteObserver(self._observer)
super().close()
def get_data(self):
if self._observer.data is None:
card = self._observer.card
if card is None:
return dict(present=False, status="no-card")
try:
with self._device.open_connection(SmartCardConnection) as conn:
self._observer.data = dict(self._read_data(conn), present=True)
except NoCardException:
return dict(present=False, status="no-card")
except ValueError:
self._observer.data = dict(present=False, status="unknown-device")
return self._observer.data
if self._observer.needs_refresh:
self._data = None
return super().get_data()
def _refresh_data(self):
card = self._observer.card
if card is None:
return dict(present=False, status="no-card")
try:
with self._device.open_connection(SmartCardConnection) as conn:
data = dict(self._read_data(conn), present=True)
self._observer.needs_refresh = False
return data
except NoCardException:
return dict(present=False, status="no-card")
except ValueError:
self._observer.needs_refresh = False
return dict(present=False, status="unknown-device")
@action(closes_child=False)
def get(self, params, event, signal):

View File

@ -13,6 +13,7 @@
# limitations under the License.
from .base import (
RpcResponse,
RpcNode,
action,
child,
@ -189,7 +190,7 @@ class Ctap2Node(RpcNode):
raise InactivityException()
self._info = self.ctap.get_info()
self._token = None
return dict()
return RpcResponse(dict(), ["device_info"])
@action(condition=lambda self: self._info.options["clientPin"])
def unlock(self, params, event, signal):
@ -224,7 +225,7 @@ class Ctap2Node(RpcNode):
params.pop("new_pin"),
)
self._info = self.ctap.get_info()
return dict()
return RpcResponse(dict(), ["device_info"])
except CtapError as e:
return _handle_pin_error(e, self.client_pin)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import RpcNode, action
from .base import RpcResponse, RpcNode, action
from yubikit.core import require_version, NotSupportedError, TRANSPORT, Connection
from yubikit.core.smartcard import SmartCardConnection
from yubikit.core.otp import OtpConnection
@ -90,7 +90,7 @@ class ManagementNode(RpcNode):
if reboot:
enabled = config.enabled_capabilities.get(TRANSPORT.USB)
self._await_reboot(serial, enabled)
return dict()
return RpcResponse(dict(), ["device_info"])
@action
def set_mode(self, params, event, signal):
@ -106,4 +106,4 @@ class ManagementNode(RpcNode):
)
def device_reset(self, params, event, signal):
self.session.device_reset()
return dict()
return RpcResponse(dict(), ["device_info"])

View File

@ -13,6 +13,7 @@
# limitations under the License.
from .base import (
RpcResponse,
RpcNode,
action,
child,
@ -193,7 +194,7 @@ class OathNode(RpcNode):
self.session.set_key(key)
self._set_key_verifier(key)
remember &= self._remember_key(key if remember else None)
return dict(remembered=remember)
return RpcResponse(dict(remembered=remember), ["device_info"])
@action(condition=lambda self: self.session.has_key)
def unset_key(self, params, event, signal):
@ -207,7 +208,7 @@ class OathNode(RpcNode):
self.session.reset()
self._key_verifier = None
self._remember_key(None)
return dict()
return RpcResponse(dict(), ["device_info"])
@child
def accounts(self):

View File

@ -13,6 +13,7 @@
# limitations under the License.
from .base import (
RpcResponse,
RpcNode,
action,
child,
@ -212,7 +213,7 @@ class PivNode(RpcNode):
store_key = params.pop("store_key", False)
pivman_set_mgm_key(self.session, key, key_type, False, store_key)
self._pivman_data = get_pivman_data(self.session)
return dict()
return RpcResponse(dict(), ["device_info"])
@action
def change_pin(self, params, event, signal):
@ -220,9 +221,9 @@ class PivNode(RpcNode):
new_pin = params.pop("new_pin")
try:
pivman_change_pin(self.session, old_pin, new_pin)
return RpcResponse(dict(), ["device_info"])
except Exception as e:
_handle_pin_puk_error(e)
return dict()
@action
def change_puk(self, params, event, signal):
@ -230,9 +231,9 @@ class PivNode(RpcNode):
new_puk = params.pop("new_puk")
try:
self.session.change_puk(old_puk, new_puk)
return RpcResponse(dict(), ["device_info"])
except Exception as e:
_handle_pin_puk_error(e)
return dict()
@action
def unblock_pin(self, params, event, signal):
@ -240,16 +241,16 @@ class PivNode(RpcNode):
new_pin = params.pop("new_pin")
try:
self.session.unblock_pin(puk, new_pin)
return RpcResponse(dict(), ["device_info"])
except Exception as e:
_handle_pin_puk_error(e)
return dict()
@action
def reset(self, params, event, signal):
self.session.reset()
self._authenticated = False
self._pivman_data = get_pivman_data(self.session)
return dict()
return RpcResponse(dict(), ["device_info"])
@child
def slots(self):
@ -266,9 +267,11 @@ class PivNode(RpcNode):
return dict(
status=True,
password=password is not None,
key_type=KEY_TYPE.from_public_key(private_key.public_key())
if private_key
else None,
key_type=(
KEY_TYPE.from_public_key(private_key.public_key())
if private_key
else None
),
cert_info=_get_cert_info(certificate),
)
except InvalidPasswordError:
@ -413,9 +416,11 @@ class SlotNode(RpcNode):
id=f"{int(self.slot):02x}",
name=self.slot.name,
metadata=_metadata_dict(self.metadata),
certificate=self.certificate.public_bytes(encoding=Encoding.PEM).decode()
if self.certificate
else None,
certificate=(
self.certificate.public_bytes(encoding=Encoding.PEM).decode()
if self.certificate
else None
),
)
@action(condition=lambda self: self.certificate or self.metadata)
@ -492,16 +497,20 @@ class SlotNode(RpcNode):
return dict(
metadata=_metadata_dict(metadata),
public_key=private_key.public_key()
.public_bytes(
encoding=Encoding.PEM, format=PublicFormat.SubjectPublicKeyInfo
)
.decode()
if private_key
else None,
certificate=self.certificate.public_bytes(encoding=Encoding.PEM).decode()
if certs
else None,
public_key=(
private_key.public_key()
.public_bytes(
encoding=Encoding.PEM, format=PublicFormat.SubjectPublicKeyInfo
)
.decode()
if private_key
else None
),
certificate=(
self.certificate.public_bytes(encoding=Encoding.PEM).decode()
if certs
else None
),
)
@action