Automatically deserialize bytes in actions

This commit is contained in:
Dain Nilsson 2024-10-07 16:12:33 +02:00
parent f8b4eff328
commit 654ce4a6cc
No known key found for this signature in database
GPG Key ID: F04367096FBA95E8
5 changed files with 46 additions and 42 deletions

View File

@ -130,6 +130,13 @@ class RpcNode:
elif action in self.list_actions():
action_f = self.get_action(action)
args = inspect.signature(action_f).parameters
# Decode any serialized bytes parameters
for key, param in args.items():
if param.annotation in (bytes, bytes | None):
value = params.get(key, None)
if value is not None:
params[key] = decode_bytes(value)
# Add event and signal if requested
if "event" in args:
params["event"] = event
if "signal" in args:

View File

@ -90,15 +90,13 @@ class ManagementNode(RpcNode):
def configure(
self,
reboot: bool = False,
cur_lock_code: str = "",
new_lock_code: str = "",
cur_lock_code: bytes | None = None,
new_lock_code: bytes | None = None,
enabled_capabilities: dict = {},
auto_eject_timeout: int | None = None,
challenge_response_timeout: int | None = None,
device_flags: int | None = None,
):
cur_code = bytes.fromhex(cur_lock_code) or None
new_code = bytes.fromhex(new_lock_code) or None
config = DeviceConfig(
enabled_capabilities,
auto_eject_timeout,
@ -106,7 +104,7 @@ class ManagementNode(RpcNode):
DEVICE_FLAG(device_flags) if device_flags else None,
)
serial = self.session.read_device_info().serial
self.session.write_device_config(config, reboot, cur_code, new_code)
self.session.write_device_config(config, reboot, cur_lock_code, new_lock_code)
flags = ["device_info"]
if reboot:
enabled = config.enabled_capabilities.get(TRANSPORT.USB)

View File

@ -142,13 +142,13 @@ class OathNode(RpcNode):
else:
return False
def _get_key(self, key: str | None, password: str | None):
def _get_key(self, key: bytes | None, password: str | None):
if key and password:
raise ValueError("Only one of 'key' and 'password' can be provided.")
if password:
return self.session.derive_key(password)
if key:
return decode_bytes(key)
return key
raise ValueError("One of 'key' and 'password' must be provided.")
def _set_key_verifier(self, key):
@ -163,7 +163,7 @@ class OathNode(RpcNode):
@action
def validate(
self,
key: str | None = None,
key: bytes | None = None,
password: str | None = None,
remember: bool = False,
):
@ -192,7 +192,7 @@ class OathNode(RpcNode):
@action
def set_key(
self,
key: str | None = None,
key: bytes | None = None,
password: str | None = None,
remember: bool = False,
):

View File

@ -201,9 +201,9 @@ class PivNode(RpcNode):
return dict(status=True, authenticated=self._authenticated)
@action
def authenticate(self, signal, key: str):
def authenticate(self, signal, key: bytes):
try:
self._authenticate(bytes.fromhex(key), signal)
self._authenticate(key, signal)
return dict(status=True)
except ApduError as e:
if e.sw == SW.SECURITY_CONDITION_NOT_SATISFIED:
@ -213,14 +213,13 @@ class PivNode(RpcNode):
@action(condition=lambda self: self._authenticated)
def set_key(
self,
params,
key: str,
key: bytes,
key_type: int = MANAGEMENT_KEY_TYPE.TDES,
store_key: bool = False,
):
pivman_set_mgm_key(
self.session,
bytes.fromhex(key),
key,
MANAGEMENT_KEY_TYPE(key_type),
False,
store_key,
@ -264,9 +263,9 @@ class PivNode(RpcNode):
return SlotsNode(self.session)
@action(closes_child=False)
def examine_file(self, data: str, password: str | None = None):
def examine_file(self, data: bytes, password: str | None = None):
try:
private_key, certs = _parse_file(bytes.fromhex(data), password)
private_key, certs = _parse_file(data, password)
certificate = _choose_cert(certs)
return dict(
@ -461,9 +460,9 @@ class SlotNode(RpcNode):
return dict()
@action
def import_file(self, data: str, password: str | None = None, **kwargs):
def import_file(self, data: bytes, password: str | None = None, **kwargs):
try:
private_key, certs = _parse_file(bytes.fromhex(data), password)
private_key, certs = _parse_file(data, password)
except InvalidPasswordError:
logger.debug("Invalid or missing password", exc_info=True)
raise ValueError("Wrong/Missing password")

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base import RpcNode, action, child
from .base import RpcNode, action, child, decode_bytes
from yubikit.core import NotSupportedError, CommandError
from yubikit.core.otp import modhex_encode, modhex_decode
@ -95,15 +95,15 @@ class YubiOtpNode(RpcNode):
self,
serial: int,
public_id: str,
private_id: str,
key: str,
private_id: bytes,
key: bytes,
):
return dict(
csv=format_csv(
serial,
modhex_decode(public_id),
bytes.fromhex(private_id),
bytes.fromhex(key),
private_id,
key,
)
)
@ -145,19 +145,16 @@ class SlotNode(RpcNode):
return False
@action(condition=lambda self: self._maybe_configured(self.slot))
def delete(self, curr_acc_code: str | None = None):
def delete(self, curr_acc_code: bytes | None = None):
try:
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None
self.session.delete_slot(self.slot, access_code)
self.session.delete_slot(self.slot, curr_acc_code)
return dict()
except CommandError:
raise ValueError(_FAIL_MSG)
@action(condition=lambda self: self._can_calculate(self.slot))
def calculate(self, event, challenge: str):
response = self.session.calculate_hmac_sha1(
self.slot, bytes.fromhex(challenge), event
)
def calculate(self, event, challenge: bytes):
response = self.session.calculate_hmac_sha1(self.slot, challenge, event)
return dict(response=response)
@staticmethod
@ -189,13 +186,13 @@ class SlotNode(RpcNode):
if "token_id" in options:
token_id, *args = options.pop("token_id")
config.token_id(bytes.fromhex(token_id), *args)
config.token_id(decode_bytes(token_id), *args)
@staticmethod
def _get_config(cfg_type: str, **kwargs) -> SlotConfiguration:
match cfg_type:
case "hmac_sha1":
return HmacSha1SlotConfiguration(bytes.fromhex(kwargs["key"]))
return HmacSha1SlotConfiguration(decode_bytes(kwargs["key"]))
case "hotp":
return HotpSlotConfiguration(parse_b32_key(kwargs["key"]))
case "static_password":
@ -207,8 +204,8 @@ class SlotNode(RpcNode):
case "yubiotp":
return YubiOtpSlotConfiguration(
fixed=modhex_decode(kwargs["public_id"]),
uid=bytes.fromhex(kwargs["private_id"]),
key=bytes.fromhex(kwargs["key"]),
uid=decode_bytes(kwargs["private_id"]),
key=decode_bytes(kwargs["key"]),
)
case unsupported:
raise ValueError(
@ -217,17 +214,20 @@ class SlotNode(RpcNode):
@action
def put(
self, type: str, options: dict = {}, curr_acc_code: str | None = None, **kwargs
self,
type: str,
options: dict = {},
curr_acc_code: bytes | None = None,
**kwargs,
):
access_code = bytes.fromhex(curr_acc_code) if curr_acc_code else None
config = self._get_config(type, **kwargs)
self._apply_options(config, options)
try:
self.session.put_configuration(
self.slot,
config,
access_code,
access_code,
curr_acc_code,
curr_acc_code,
)
return dict()
except CommandError:
@ -240,8 +240,8 @@ class SlotNode(RpcNode):
def update(
self,
params,
acc_code: str | None = None,
curr_acc_code: str | None = None,
acc_code: bytes | None = None,
curr_acc_code: bytes | None = None,
**kwargs,
):
config = UpdateConfiguration()
@ -249,7 +249,7 @@ class SlotNode(RpcNode):
self.session.update_configuration(
self.slot,
config,
bytes.fromhex(acc_code) if acc_code else None,
bytes.fromhex(curr_acc_code) if curr_acc_code else None,
acc_code,
curr_acc_code,
)
return dict()