diff --git a/modules/devices.py b/modules/devices.py index 046460fa..1325569c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -53,10 +53,17 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): + for devid in range(0,torch.cuda.device_count()): + if torch.cuda.get_device_capability(devid) == (7, 5): + shd = True + if shd: + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.enabled = True torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True + errors.run(enable_tf32, "Enabling TF32") cpu = torch.device("cpu")