diff --git a/helper/helper/device.py b/helper/helper/device.py index e059c03d..b431487b 100644 --- a/helper/helper/device.py +++ b/helper/helper/device.py @@ -271,9 +271,13 @@ class AbstractDeviceNode(RpcNode): try: response = super().__call__(*args, **kwargs) if "device_info" in response.flags: - # Clear DeviceInfo cache - self._info = None - self._data = None + old_info = self._info + # Refresh data, and close any open child + self._close_child() + self._data = self._refresh_data() + if old_info == self._info: + # No change to DeviceInfo, further propagation not needed. + response.flags.remove("device_info") return response @@ -318,9 +322,17 @@ class UsbDeviceNode(AbstractDeviceNode): def _create_connection(self, conn_type): connection = self._device.open_connection(conn_type) + self._data = self._read_data(connection) return ConnectionNode(self._device, connection, self._info) def _refresh_data(self): + # Re-use existing connection if possible + if self._child and not self._child.closed: + # Make sure to close any open session + self._child._close_child() + return self._read_data(self._child._connection) + + # New connection for conn_type in (SmartCardConnection, OtpConnection, FidoConnection): if self._supports_connection(conn_type): try: @@ -398,6 +410,9 @@ class ReaderDeviceNode(AbstractDeviceNode): self._data = None return super().get_data() + def _read_data(self, conn): + return dict(super()._read_data(conn), present=True) + def _refresh_data(self): card = self._observer.card if card is None: @@ -405,7 +420,7 @@ class ReaderDeviceNode(AbstractDeviceNode): try: with self._device.open_connection(SmartCardConnection) as conn: try: - data = dict(self._read_data(conn), present=True) + data = self._read_data(conn) except ValueError: # Unknown device, maybe NFC restricted try: @@ -434,8 +449,8 @@ class ReaderDeviceNode(AbstractDeviceNode): def ccid(self): try: connection = self._device.open_connection(SmartCardConnection) - info = read_info(connection) - return ScpConnectionNode(self._device, connection, info) + self._data = self._read_data(connection) + return ScpConnectionNode(self._device, connection, self._info) except (ValueError, SmartcardException, EstablishContextException) as e: logger.warning("Error opening connection", exc_info=True) raise ConnectionException(self._device.fingerprint, "ccid", e) @@ -444,9 +459,9 @@ class ReaderDeviceNode(AbstractDeviceNode): def fido(self): try: with self._device.open_connection(SmartCardConnection) as conn: - info = read_info(conn) + self._data = self._read_data(conn) connection = self._device.open_connection(FidoConnection) - return ConnectionNode(self._device, connection, info) + return ConnectionNode(self._device, connection, self._info) except (ValueError, SmartcardException, EstablishContextException) as e: logger.warning("Error opening connection", exc_info=True) raise ConnectionException(self._device.fingerprint, "fido", e) @@ -458,24 +473,11 @@ class ConnectionNode(RpcNode): self._device = device self._transport = device.transport self._connection = connection - self._info = info or read_info(self._connection, device.pid) + self._info = info def __call__(self, *args, **kwargs): try: - response = super().__call__(*args, **kwargs) - if "device_info" in response.flags: - # Refresh DeviceInfo - info = read_info(self._connection, self._device.pid) - if self._info != info: - self._info = info - # Make sure any child node is re-opened after this, - # as enabled applications may have changed - self.close() - else: - # No change to DeviceInfo, further propagation not needed. - response.flags.remove("device_info") - - return response + return super().__call__(*args, **kwargs) except (SmartcardException, OSError) as e: logger.exception("Connection error") raise ChildResetException(f"{e}") @@ -504,11 +506,6 @@ class ConnectionNode(RpcNode): logger.warning("Error closing connection", exc_info=True) def get_data(self): - if ( - isinstance(self._connection, SmartCardConnection) - or self._transport == TRANSPORT.USB - ): - self._info = read_info(self._connection, self._device.pid) return dict(version=self._info.version, serial=self._info.serial) def _init_child_node(self, child_cls, capability=CAPABILITY(0)):