ESRGAN/test.py

38 lines
1.1 KiB
Python
Raw Permalink Normal View History

2019-06-02 12:27:17 +03:00
import os.path as osp
2018-09-01 12:20:42 +03:00
import glob
import cv2
import numpy as np
import torch
2019-06-02 12:27:17 +03:00
import RRDBNet_arch as arch
2018-09-01 12:20:42 +03:00
2019-06-02 12:27:17 +03:00
# Checkpoint to evaluate: models/RRDB_ESRGAN_x4.pth OR models/RRDB_PSNR_x4.pth
model_path = 'models/RRDB_ESRGAN_x4.pth'
# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing later when the model/tensors are moved to an unavailable CUDA device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Glob pattern (not a bare directory) selecting the low-resolution inputs.
test_img_folder = 'LR/*'
2019-06-02 12:27:17 +03:00
# Build the generator. The positional arguments must match the architecture
# the checkpoint was trained with: strict=True below rejects missing or
# unexpected parameter keys (and mismatched shapes raise as well).
model = arch.RRDBNet(3, 3, 64, 23, gc=32)
# Deserialize onto the CPU so that checkpoints saved from a GPU also load on
# CPU-only machines; the model is moved to `device` right afterwards.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(model_path, map_location='cpu')
model.load_state_dict(state_dict, strict=True)
# Evaluation mode (affects layers such as dropout/batch-norm, if present).
model.eval()
model = model.to(device)

print('Model path {:s}. \nTesting...'.format(model_path))
2018-09-01 12:20:42 +03:00
import os  # for os.makedirs; the file only imports os.path (as osp)

# cv2.imwrite does not create missing directories (it just returns False),
# so make sure the output directory exists before the loop.
os.makedirs('results', exist_ok=True)

# sorted() makes the processing order (and the printed index) deterministic;
# glob.glob returns paths in arbitrary order.
for idx, path in enumerate(sorted(glob.glob(test_img_folder)), start=1):
    base = osp.splitext(osp.basename(path))[0]
    print(idx, base)

    # read image: IMREAD_COLOR yields an HxWx3 array in BGR channel order,
    # or None when the file cannot be decoded (non-image file, directory...).
    img = cv2.imread(path, cv2.IMREAD_COLOR)
    if img is None:
        print('Skipping {:s}: not a readable image'.format(path))
        continue

    # BGR HWC -> RGB CHW, scaled to [0, 1], plus a leading batch dimension.
    # NOTE(review): the /255 scaling assumes 8-bit input, which IMREAD_COLOR
    # produces by default.
    img = img[:, :, [2, 1, 0]].astype(np.float32) / 255.0
    img = torch.from_numpy(np.transpose(img, (2, 0, 1)))
    img_LR = img.unsqueeze(0).to(device)

    with torch.no_grad():
        # Drop only the batch dimension, move to CPU and clamp to [0, 1].
        output = model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()

    # RGB CHW in [0, 1] -> BGR HWC uint8 for OpenCV. The clamp above keeps
    # the rounded values within the uint8 range.
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)
    if not cv2.imwrite('results/{:s}_rlt.png'.format(base), output):
        print('Failed to write result for {:s}'.format(base))