test.py with model path

This commit is contained in:
xinntao 2018-09-05 22:41:42 +08:00
parent 587d8f529b
commit 5cff4ca5f0

11
test.py
View File

@ -6,12 +6,7 @@ import numpy as np
import torch
import architecture as arch
mode = sys.argv[1] # ESRGAN or RRDB_PSNR
if mode == 'ESRGAN':
model_path = './models/RRDB_ESRGAN_x4.pth'
elif mode == 'RRDB_PSNR':
model_path = './models/RRDB_PSNR_x4.pth'
model_path = sys.argv[1] # models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
test_img_folder = 'LR/*'
@ -23,7 +18,7 @@ for k, v in model.named_parameters():
v.requires_grad = False
model = model.cuda()
print('Mode {:s}. \nTesting...'.format(mode))
print('Model path {:s}. \nTesting...'.format(model_path))
idx = 0
for path in glob.glob(test_img_folder):
@ -40,4 +35,4 @@ for path in glob.glob(test_img_folder):
output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round()
cv2.imwrite('results/{:s}_{}.png'.format(base, mode), output)
cv2.imwrite('results/{:s}_rlt.png'.format(base), output)