Automatically deserialize bytes in actions

This commit is contained in:
Dain Nilsson 2024-10-07 16:12:33 +02:00
parent f8b4eff328
commit 654ce4a6cc
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8
5 changed files with 46 additions and 42 deletions

View File

@ -130,6 +130,13 @@ class RpcNode:
elif action in self.list_actions(): elif action in self.list_actions():
action_f = self.get_action(action) action_f = self.get_action(action)
args = inspect.signature(action_f).parameters args = inspect.signature(action_f).parameters
# Decode any serialized bytes parameters
for key, param in args.items():
if param.annotation in (bytes, bytes | None):
value = params.get(key, None)
if value is not None:
params[key] = decode_bytes(value)
# Add event and signal if requested
if "event" in args: if "event" in args:
params["event"] = event params["event"] = event
if "signal" in args: if "signal" in args:

View File

@ -90,15 +90,13 @@ class ManagementNode(RpcNode):
def configure( def configure(
self, self,
reboot: bool = False, reboot: bool = False,
cur_lock_code: str = "", cur_lock_code: bytes | None = None,
new_lock_code: str = "", new_lock_code: bytes | None = None,
enabled_capabilities: dict = {}, enabled_capabilities: dict = {},
auto_eject_timeout: int | None = None, auto_eject_timeout: int | None = None,
challenge_response_timeout: int | None = None, challenge_response_timeout: int | None = None,
device_flags: int | None = None, device_flags: int | None = None,
): ):
cur_code = bytes.fromhex(cur_lock_code) or None
new_code = bytes.fromhex(new_lock_code) or None
config = DeviceConfig( config = DeviceConfig(
enabled_capabilities, enabled_capabilities,
auto_eject_timeout, auto_eject_timeout,
@ -106,7 +104,7 @@ class ManagementNode(RpcNode):
DEVICE_FLAG(device_flags) if device_flags else None, DEVICE_FLAG(device_flags) if device_flags else None,
) )
serial = self.session.read_device_info().serial serial = self.session.read_device_info().serial
self.session.write_device_config(config, reboot, cur_code, new_code) self.session.write_device_config(config, reboot, cur_lock_code, new_lock_code)
flags = ["device_info"] flags = ["device_info"]
if reboot: if reboot:
enabled = config.enabled_capabilities.get(TRANSPORT.USB) enabled = config.enabled_capabilities.get(TRANSPORT.USB)

View File

@ -142,13 +142,13 @@ class OathNode(RpcNode):
else: else:
return False return False
def _get_key(self, key: str | None, password: str | None): def _get_key(self, key: bytes | None, password: str | None):
if key and password: if key and password:
raise ValueError("Only one of 'key' and 'password' can be provided.") raise ValueError("Only one of 'key' and 'password' can be provided.")
if password: if password:
return self.session.derive_key(password) return self.session.derive_key(password)
if key: if key:
return decode_bytes(key) return key
raise ValueError("One of 'key' and 'password' must be provided.") raise ValueError("One of 'key' and 'password' must be provided.")
def _set_key_verifier(self, key): def _set_key_verifier(self, key):
@ -163,7 +163,7 @@ class OathNode(RpcNode):
@action @action
def validate( def validate(
self, self,
key: str | None = None, key: bytes | None = None,
password: str | None = None, password: str | None = None,
remember: bool = False, remember: bool = False,
): ):
@ -192,7 +192,7 @@ class OathNode(RpcNode):
@action @action
def set_key( def set_key(
self, self,
key: str | None = None, key: bytes | None = None,
password: str | None = None, password: str | None = None,
remember: bool = False, remember: bool = False,
): ):

View File

@ -201,9 +201,9 @@ class PivNode(RpcNode):
return dict(status=True, authenticated=self._authenticated) return dict(status=True, authenticated=self._authenticated)
@action @action
def authenticate(self, signal, key: str): def authenticate(self, signal, key: bytes):
try: try:
self._authenticate(bytes.fromhex(key), signal) self._authenticate(key, signal)
return dict(status=True) return dict(status=True)
except ApduError as e: except ApduError as e:
if e.sw == SW.SECURITY_CONDITION_NOT_SATISFIED: if e.sw == SW.SECURITY_CONDITION_NOT_SATISFIED:
@ -213,14 +213,13 @@ class PivNode(RpcNode):
@action(condition=lambda self: self._authenticated) @action(condition=lambda self: self._authenticated)
def set_key( def set_key(
self, self,
params, key: bytes,
key: str,
key_type: int = MANAGEMENT_KEY_TYPE.TDES, key_type: int = MANAGEMENT_KEY_TYPE.TDES,
store_key: bool = False, store_key: bool = False,
): ):
pivman_set_mgm_key( pivman_set_mgm_key(
self.session, self.session,
bytes.fromhex(key), key,
MANAGEMENT_KEY_TYPE(key_type), MANAGEMENT_KEY_TYPE(key_type),
False, False,
store_key, store_key,
@ -264,9 +263,9 @@ class PivNode(RpcNode):
return SlotsNode(self.session) return SlotsNode(self.session)
@action(closes_child=False) @action(closes_child=False)
def examine_file(self, data: str, password: str | None = None): def examine_file(self, data: bytes, password: str | None = None):
try: try:
private_key, certs = _parse_file(bytes.fromhex(data), password) private_key, certs = _parse_file(data, password)
certificate = _choose_cert(certs) certificate = _choose_cert(certs)
return dict( return dict(
@ -461,9 +460,9 @@ class SlotNode(RpcNode):
return dict() return dict()
@action @action
def import_file(self, data: str, password: str | None = None, **kwargs): def import_file(self, data: bytes, password: str | None = None, **kwargs):
try: try:
private_key, certs = _parse_file(bytes.fromhex(data), password) private_key, certs = _parse_file(data, password)
except InvalidPasswordError: except InvalidPasswordError:
logger.debug("Invalid or missing password", exc_info=True) logger.debug("Invalid or missing password", exc_info=True)
raise ValueError("Wrong/Missing password") raise ValueError("Wrong/Missing password")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .base import RpcNode, action, child from .base import RpcNode, action, child, decode_bytes
from yubikit.core import NotSupportedError, CommandError from yubikit.core import NotSupportedError, CommandError
from yubikit.core.otp import modhex_encode, modhex_decode from yubikit.core.otp import modhex_encode, modhex_decode
@ -95,15 +95,15 @@ class YubiOtpNode(RpcNode):
self, self,
serial: int, serial: int,
public_id: str, public_id: str,
private_id: str, private_id: bytes,
key: str, key: bytes,
): ):
return dict( return dict(
csv=format_csv( csv=format_csv(
serial, serial,
modhex_decode(public_id), modhex_decode(public_id),
bytes.fromhex(private_id), private_id,
bytes.fromhex(key), key,
) )
) )
@ -145,19 +145,16 @@ class SlotNode(RpcNode):
return False return False
@action(condition=lambda self: self._maybe_configured(self.slot)) @action(condition=lambda self: self._maybe_configured(self.slot))
def delete(self, curr_acc_code: str | None = None): def delete(self, curr_acc_code: bytes | None = None):
try: try:
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None self.session.delete_slot(self.slot, curr_acc_code)
self.session.delete_slot(self.slot, access_code)
return dict() return dict()
except CommandError: except CommandError:
raise ValueError(_FAIL_MSG) raise ValueError(_FAIL_MSG)
@action(condition=lambda self: self._can_calculate(self.slot)) @action(condition=lambda self: self._can_calculate(self.slot))
def calculate(self, event, challenge: str): def calculate(self, event, challenge: bytes):
response = self.session.calculate_hmac_sha1( response = self.session.calculate_hmac_sha1(self.slot, challenge, event)
self.slot, bytes.fromhex(challenge), event
)
return dict(response=response) return dict(response=response)
@staticmethod @staticmethod
@ -189,13 +186,13 @@ class SlotNode(RpcNode):
if "token_id" in options: if "token_id" in options:
token_id, *args = options.pop("token_id") token_id, *args = options.pop("token_id")
config.token_id(bytes.fromhex(token_id), *args) config.token_id(decode_bytes(token_id), *args)
@staticmethod @staticmethod
def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration: def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration:
match cfg_type: match cfg_type:
case "hmac_sha1": case "hmac_sha1":
return HmacSha1SlotConfiguration(bytes.fromhex(kwargs["key"])) return HmacSha1SlotConfiguration(decode_bytes(kwargs["key"]))
case "hotp": case "hotp":
return HotpSlotConfiguration(parse_b32_key(kwargs["key"])) return HotpSlotConfiguration(parse_b32_key(kwargs["key"]))
case "static_password": case "static_password":
@ -207,8 +204,8 @@ class SlotNode(RpcNode):
case "yubiotp": case "yubiotp":
return YubiOtpSlotConfiguration( return YubiOtpSlotConfiguration(
fixed=modhex_decode(kwargs["public_id"]), fixed=modhex_decode(kwargs["public_id"]),
uid=bytes.fromhex(kwargs["private_id"]), uid=decode_bytes(kwargs["private_id"]),
key=bytes.fromhex(kwargs["key"]), key=decode_bytes(kwargs["key"]),
) )
case unsupported: case unsupported:
raise ValueError( raise ValueError(
@ -217,17 +214,20 @@ class SlotNode(RpcNode):
@action @action
def put( def put(
self, type: str, options: dict = {}, curr_acc_code: str | None = None, **kwargs self,
type: str,
options: dict = {},
curr_acc_code: bytes | None = None,
**kwargs,
): ):
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None
config = self._get_config(type, **kwargs) config = self._get_config(type, **kwargs)
self._apply_options(config, options) self._apply_options(config, options)
try: try:
self.session.put_configuration( self.session.put_configuration(
self.slot, self.slot,
config, config,
access_code, curr_acc_code,
access_code, curr_acc_code,
) )
return dict() return dict()
except CommandError: except CommandError:
@ -240,8 +240,8 @@ class SlotNode(RpcNode):
def update( def update(
self, self,
params, params,
acc_code: str | None = None, acc_code: bytes | None = None,
curr_acc_code: str | None = None, curr_acc_code: bytes | None = None,
**kwargs, **kwargs,
): ):
config = UpdateConfiguration() config = UpdateConfiguration()
@ -249,7 +249,7 @@ class SlotNode(RpcNode):
self.session.update_configuration( self.session.update_configuration(
self.slot, self.slot,
config, config,
bytes.fromhex(acc_code) if acc_code else None, acc_code,
bytes.fromhex(curr_acc_code) if curr_acc_code else None, curr_acc_code,
) )
return dict() return dict()