add test codes

This commit is contained in:
xinntao 2018-09-01 17:20:42 +08:00
parent b64b650df5
commit 140595ab58
6 changed files with 342 additions and 0 deletions

BIN
LR/baboon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 33 KiB

BIN
LR/comic.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

38
architecture.py Normal file
View File

@ -0,0 +1,38 @@
import math
import torch
import torch.nn as nn
import block as B
class RRDB_Net(nn.Module):
def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
mode='CNA', res_scale=1, upsample_mode='upconv'):
super(RRDB_Net, self).__init__()
n_upscale = int(math.log(upscale, 2))
if upscale == 3:
n_upscale = 1
fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
rb_blocks = [B.RRDB(nf, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=norm_type, act_type=act_type, mode='CNA') for _ in range(nb)]
LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type, act_type=None, mode=mode)
if upsample_mode == 'upconv':
upsample_block = B.upconv_blcok
elif upsample_mode == 'pixelshuffle':
upsample_block = B.pixelshuffle_block
else:
raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)
if upscale == 3:
upsampler = upsample_block(nf, nf, 3, act_type=act_type)
else:
upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]
HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)
self.model = B.sequential(fea_conv, B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),\
*upsampler, HR_conv0, HR_conv1)
def forward(self, x):
x = self.model(x)
return x

261
block.py Normal file
View File

@ -0,0 +1,261 @@
from collections import OrderedDict
import torch
import torch.nn as nn
####################
# Basic blocks
####################
def act(act_type, inplace=True, neg_slope=0.2, n_prelu=1):
# helper selecting activation
# neg_slope: for leakyrelu and init of prelu
# n_prelu: for p_relu num_parameters
act_type = act_type.lower()
if act_type == 'relu':
layer = nn.ReLU(inplace)
elif act_type == 'leakyrelu':
layer = nn.LeakyReLU(neg_slope, inplace)
elif act_type == 'prelu':
layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
else:
raise NotImplementedError('activation layer [%s] is not found' % act_type)
return layer
def norm(norm_type, nc):
# helper selecting normalization layer
norm_type = norm_type.lower()
if norm_type == 'batch':
layer = nn.BatchNorm2d(nc, affine=True)
elif norm_type == 'instance':
layer = nn.InstanceNorm2d(nc, affine=False)
else:
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
return layer
def pad(pad_type, padding):
# helper selecting padding layer
# if padding is 'zero', do by conv layers
pad_type = pad_type.lower()
if padding == 0:
return None
if pad_type == 'reflect':
layer = nn.ReflectionPad2d(padding)
elif pad_type == 'replicate':
layer = nn.ReplicationPad2d(padding)
else:
raise NotImplementedError('padding layer [%s] is not implemented' % pad_type)
return layer
def get_valid_padding(kernel_size, dilation):
kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
padding = (kernel_size - 1) // 2
return padding
class ConcatBlock(nn.Module):
# Concat the output of a submodule to its input
def __init__(self, submodule):
super(ConcatBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = torch.cat((x, self.sub(x)), dim=1)
return output
def __repr__(self):
tmpstr = 'Identity .. \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
class ShortcutBlock(nn.Module):
#Elementwise sum the output of a submodule to its input
def __init__(self, submodule):
super(ShortcutBlock, self).__init__()
self.sub = submodule
def forward(self, x):
output = x + self.sub(x)
return output
def __repr__(self):
tmpstr = 'Identity + \n|'
modstr = self.sub.__repr__().replace('\n', '\n|')
tmpstr = tmpstr + modstr
return tmpstr
def sequential(*args):
# Flatten Sequential. It unwraps nn.Sequential.
if len(args) == 1:
if isinstance(args[0], OrderedDict):
raise NotImplementedError('sequential does not support OrderedDict input.')
return args[0] # No sequential is needed.
modules = []
for module in args:
if isinstance(module, nn.Sequential):
for submodule in module.children():
modules.append(submodule)
elif isinstance(module, nn.Module):
modules.append(module)
return nn.Sequential(*modules)
def conv_block(in_nc, out_nc, kernel_size, stride=1, dilation=1, groups=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='CNA'):
"""
Conv layer with padding, normalization, activation
mode: CNA --> Conv -> Norm -> Act
NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
"""
assert mode in ['CNA', 'NAC', 'CNAC'], 'Wong conv mode [%s]' % mode
padding = get_valid_padding(kernel_size, dilation)
p = pad(pad_type, padding) if pad_type and pad_type != 'zero' else None
padding = padding if pad_type == 'zero' else 0
c = nn.Conv2d(in_nc, out_nc, kernel_size=kernel_size, stride=stride, padding=padding, \
dilation=dilation, bias=bias, groups=groups)
a = act(act_type) if act_type else None
if 'CNA' in mode:
n = norm(norm_type, out_nc) if norm_type else None
return sequential(p, c, n, a)
elif mode == 'NAC':
if norm_type is None and act_type is not None:
a = act(act_type, inplace=False)
# Important!
# input----ReLU(inplace)----Conv--+----output
# |________________________|
# inplace ReLU will modify the input, therefore wrong output
n = norm(norm_type, in_nc) if norm_type else None
return sequential(n, a, p, c)
####################
# Useful blocks
####################
class ResNetBlock(nn.Module):
"""
ResNet Block, 3-3 style
with extra residual scaling used in EDSR
(Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
"""
def __init__(self, in_nc, mid_nc, out_nc, kernel_size=3, stride=1, dilation=1, groups=1, \
bias=True, pad_type='zero', norm_type=None, act_type='relu', mode='CNA', res_scale=1):
super(ResNetBlock, self).__init__()
conv0 = conv_block(in_nc, mid_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
if mode == 'CNA':
act_type = None
if mode == 'CNAC': # Residual path: |-CNAC-|
act_type = None
norm_type = None
conv1 = conv_block(mid_nc, out_nc, kernel_size, stride, dilation, groups, bias, pad_type, \
norm_type, act_type, mode)
# if in_nc != out_nc:
# self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
# None, None)
# print('Need a projecter in ResNetBlock.')
# else:
# self.project = lambda x:x
self.res = sequential(conv0, conv1)
self.res_scale = res_scale
def forward(self, x):
res = self.res(x).mul(self.res_scale)
return x + res
class ResidualDenseBlock_5C(nn.Module):
"""
Residual Dense Block
style: 5 convs
The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
"""
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(ResidualDenseBlock_5C, self).__init__()
# gc: growth channel, i.e. intermediate channels
self.conv1 = conv_block(nc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv2 = conv_block(nc+gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv3 = conv_block(nc+2*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
self.conv4 = conv_block(nc+3*gc, gc, kernel_size, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=act_type, mode=mode)
if mode == 'CNA':
last_act = None
else:
last_act = act_type
self.conv5 = conv_block(nc+4*gc, nc, 3, stride, bias=bias, pad_type=pad_type, \
norm_type=norm_type, act_type=last_act, mode=mode)
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(torch.cat((x, x1), 1))
x3 = self.conv3(torch.cat((x, x1, x2), 1))
x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
return x5.mul(0.2) + x
class RRDB(nn.Module):
"""
Residual in Residual Dense Block
"""
def __init__(self, nc, kernel_size=3, gc=32, stride=1, bias=True, pad_type='zero', \
norm_type=None, act_type='leakyrelu', mode='CNA'):
super(RRDB, self).__init__()
self.RDB1 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB2 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
self.RDB3 = ResidualDenseBlock_5C(nc, kernel_size, gc, stride, bias, pad_type, \
norm_type, act_type, mode)
def forward(self, x):
out = self.RDB1(x)
out = self.RDB2(out)
out = self.RDB3(out)
return out.mul(0.2) + x
####################
# Upsampler
####################
def pixelshuffle_block(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu'):
"""
Pixel shuffle layer
(Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
Neural Network, CVPR17)
"""
conv = conv_block(in_nc, out_nc * (upscale_factor ** 2), kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=None, act_type=None)
pixel_shuffle = nn.PixelShuffle(upscale_factor)
n = norm(norm_type, out_nc) if norm_type else None
a = act(act_type) if act_type else None
return sequential(conv, pixel_shuffle, n, a)
def upconv_blcok(in_nc, out_nc, upscale_factor=2, kernel_size=3, stride=1, bias=True,
pad_type='zero', norm_type=None, act_type='relu', mode='nearest'):
# Up conv
# described in https://distill.pub/2016/deconv-checkerboard/
upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
conv = conv_block(in_nc, out_nc, kernel_size, stride, bias=bias,
pad_type=pad_type, norm_type=norm_type, act_type=act_type)
return sequential(upsample, conv)

BIN
results/baboon_ESRGAN.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 551 KiB

43
test.py Normal file
View File

@ -0,0 +1,43 @@
import os.path
import glob
import cv2
import numpy as np
import torch
import architecture as arch
mode = 'ESRGAN' # ESRGAN or RRDB_PSNR
if mode == 'ESRGAN':
model_path = './models/RRDB_ESRGAN_x4.pth'
elif mode == 'RRDB_PSNR':
model_path = './models/RRDB_PSNR_DF2K_x4.pth'
test_img_folder = 'LR/*'
model = arch.RRDB_Net(3, 3, 64, 23, gc=32, upscale=4, norm_type=None, act_type='leakyrelu', \
mode='CNA', res_scale=1, upsample_mode='upconv')
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
for k, v in model.named_parameters():
v.requires_grad = False
model = model.cuda()
print('Mode {:s}. \nTesting...'.format(mode))
idx = 0
for path in glob.glob(test_img_folder):
idx += 1
base = os.path.splitext(os.path.basename(path))[0]
print(idx, base)
# read image
img = cv2.imread(path, cv2.IMREAD_COLOR)
img = img * 1.0 / 255
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img_LR = img.unsqueeze(0)
img_LR = img_LR.cuda()
output = model(img_LR).data.squeeze().float().cpu().clamp_(0, 1).numpy()
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
output = (output * 255.0).round()
cv2.imwrite('results/{:s}_{}.png'.format(base, mode), output)