diff --git a/modules/devices.py b/modules/devices.py index 4c63f465..058a5e00 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -39,12 +39,9 @@ def torch_gc(): def enable_tf32(): if torch.cuda.is_available(): - #TODO: make this better; find a way to check if it is a turing card - turing = ["1630","1650","1660","Quadro RTX 3000","Quadro RTX 4000","Quadro RTX 4000","Quadro RTX 5000","Quadro RTX 5000","Quadro RTX 6000","Quadro RTX 6000","Quadro RTX 8000","Quadro RTX T400","Quadro RTX T400","Quadro RTX T600","Quadro RTX T1000","Quadro RTX T1000","2060","2070","2080","Titan RTX","Tesla T4","MX450","MX550"] for devid in range(0,torch.cuda.device_count()): - for i in turing: - if i in torch.cuda.get_device_name(devid): - shd = True + if torch.cuda.get_device_capability(devid) == (7, 5): + shd = True if shd: torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = True