add net_interp.py

This commit is contained in:
xinntao 2018-09-05 23:21:30 +08:00
parent 1c38e686fe
commit 69f0c7e2b1

20
net_interp.py Normal file
View File

@ -0,0 +1,20 @@
import sys
import torch
from collections import OrderedDict
alpha = float(sys.argv[1])
net_PSNR_path = './models/RRDB_PSNR_x4.pth'
net_ESRGAN_path = './models/RRDB_ESRGAN_x4.pth'
net_interp_path = './models/interp_{:02d}.pth'.format(alpha*10)
net_PSNR = torch.load(net_PSNR_path)
net_ESRGAN = torch.load(net_ESRGAN_path)
net_interp = OrderedDict()
print('Interpolating with alphs = ', alpha)
for k, v_PSNR in net_PSNR.items():
v_ESRGAN = net_ESRGAN[k]
net_interp[k] = alpha * v_PSNR + (1 - alpha) * v_ESRGAN
torch.save(net_interp_path, net_interp)