Refactor DeviceInfo caching

This commit is contained in:
Dain Nilsson 2024-09-13 10:45:41 +02:00
parent 5a24f57e0b
commit d8e9cf34c9
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8

View File

@ -271,9 +271,13 @@ class AbstractDeviceNode(RpcNode):
try:
response = super().__call__(*args, **kwargs)
if "device_info" in response.flags:
# Clear DeviceInfo cache
self._info = None
self._data = None
old_info = self._info
# Refresh data, and close any open child
self._close_child()
self._data = self._refresh_data()
if old_info == self._info:
# No change to DeviceInfo, further propagation not needed.
response.flags.remove("device_info")
return response
@ -318,9 +322,17 @@ class UsbDeviceNode(AbstractDeviceNode):
def _create_connection(self, conn_type):
connection = self._device.open_connection(conn_type)
self._data = self._read_data(connection)
return ConnectionNode(self._device, connection, self._info)
def _refresh_data(self):
# Re-use existing connection if possible
if self._child and not self._child.closed:
# Make sure to close any open session
self._child._close_child()
return self._read_data(self._child._connection)
# New connection
for conn_type in (SmartCardConnection, OtpConnection, FidoConnection):
if self._supports_connection(conn_type):
try:
@ -398,6 +410,9 @@ class ReaderDeviceNode(AbstractDeviceNode):
self._data = None
return super().get_data()
def _read_data(self, conn):
return dict(super()._read_data(conn), present=True)
def _refresh_data(self):
card = self._observer.card
if card is None:
@ -405,7 +420,7 @@ class ReaderDeviceNode(AbstractDeviceNode):
try:
with self._device.open_connection(SmartCardConnection) as conn:
try:
data = dict(self._read_data(conn), present=True)
data = self._read_data(conn)
except ValueError:
# Unknown device, maybe NFC restricted
try:
@ -434,8 +449,8 @@ class ReaderDeviceNode(AbstractDeviceNode):
def ccid(self):
try:
connection = self._device.open_connection(SmartCardConnection)
info = read_info(connection)
return ScpConnectionNode(self._device, connection, info)
self._data = self._read_data(connection)
return ScpConnectionNode(self._device, connection, self._info)
except (ValueError, SmartcardException, EstablishContextException) as e:
logger.warning("Error opening connection", exc_info=True)
raise ConnectionException(self._device.fingerprint, "ccid", e)
@ -444,9 +459,9 @@ class ReaderDeviceNode(AbstractDeviceNode):
def fido(self):
try:
with self._device.open_connection(SmartCardConnection) as conn:
info = read_info(conn)
self._data = self._read_data(conn)
connection = self._device.open_connection(FidoConnection)
return ConnectionNode(self._device, connection, info)
return ConnectionNode(self._device, connection, self._info)
except (ValueError, SmartcardException, EstablishContextException) as e:
logger.warning("Error opening connection", exc_info=True)
raise ConnectionException(self._device.fingerprint, "fido", e)
@ -458,24 +473,11 @@ class ConnectionNode(RpcNode):
self._device = device
self._transport = device.transport
self._connection = connection
self._info = info or read_info(self._connection, device.pid)
self._info = info
def __call__(self, *args, **kwargs):
try:
response = super().__call__(*args, **kwargs)
if "device_info" in response.flags:
# Refresh DeviceInfo
info = read_info(self._connection, self._device.pid)
if self._info != info:
self._info = info
# Make sure any child node is re-opened after this,
# as enabled applications may have changed
self.close()
else:
# No change to DeviceInfo, further propagation not needed.
response.flags.remove("device_info")
return response
return super().__call__(*args, **kwargs)
except (SmartcardException, OSError) as e:
logger.exception("Connection error")
raise ChildResetException(f"{e}")
@ -504,11 +506,6 @@ class ConnectionNode(RpcNode):
logger.warning("Error closing connection", exc_info=True)
def get_data(self):
if (
isinstance(self._connection, SmartCardConnection)
or self._transport == TRANSPORT.USB
):
self._info = read_info(self._connection, self._device.pid)
return dict(version=self._info.version, serial=self._info.serial)
def _init_child_node(self, child_cls, capability=CAPABILITY(0)):