add run on CPU

This commit is contained in:
XintaoWang 2018-10-03 20:23:59 +08:00
parent fa82b59aab
commit 6f106de05c

View File

@ -7,6 +7,8 @@ import torch
import architecture as arch
model_path = sys.argv[1] # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
device = torch.device('cuda') # if you want to run on CPU, change 'cuda' -> cpu
# device = torch.device('cpu')
test_img_folder = 'LR/*'
@ -16,7 +18,7 @@ model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.cuda()
model = model.to(device)
print('Model path {:s}. \nTesting...'.format(model_path))
@ -30,7 +32,7 @@ for path in glob.glob(test_img_folder):
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.cuda()
img_LR = img_LR.to(device)
output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))