Mirror of https://github.com/xinntao/ESRGAN.git
Synced 2024-10-26 21:12:58 +03:00
39 lines, 1.6 KiB, Python
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import block as B
|
|
|
|
|
|
class RRDB_Net(nn.Module):
    """RRDB-based super-resolution generator.

    The whole network is one ``B.sequential`` pipeline::

        fea_conv -> ShortcutBlock(nb x RRDB -> LR_conv) -> upsampling stage(s)
                 -> HR_conv0 -> HR_conv1

    Args:
        in_nc: Input channels of the first convolution (``fea_conv``).
        out_nc: Output channels of the last convolution (``HR_conv1``).
        nf: Feature channels used by the trunk, upsampler and ``HR_conv0``.
        nb: Number of ``B.RRDB`` blocks in the trunk.
        gc: Forwarded to every ``B.RRDB`` as its ``gc`` argument.
        upscale: Overall scale factor; must be 3 or a positive power of two
            (1, 2, 4, 8, ...). 3 uses a single dedicated x3 stage.
        norm_type: Normalization passed to the RRDBs and ``LR_conv``; all other
            convolutions are built with ``norm_type=None``.
        act_type: Activation passed to the RRDBs, the upsampling stage(s) and
            ``HR_conv0``; ``fea_conv``, ``LR_conv`` and ``HR_conv1`` use none.
        mode: Passed as ``mode`` to ``LR_conv`` only; the RRDBs are always
            built with ``mode='CNA'``.
        res_scale: Not used by this constructor.
        upsample_mode: ``'upconv'`` (uses ``B.upconv_blcok``) or
            ``'pixelshuffle'`` (uses ``B.pixelshuffle_block``).

    Raises:
        ValueError: If ``upscale`` is neither 3 nor a positive power of two.
        NotImplementedError: If ``upsample_mode`` is not one of the above.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, upscale=4, norm_type=None,
                 act_type='leakyrelu', mode='CNA', res_scale=1,
                 upsample_mode='upconv'):
        super(RRDB_Net, self).__init__()
        n_upscale = self._num_upsample_stages(upscale)

        fea_conv = B.conv_block(in_nc, nf, kernel_size=3, norm_type=None, act_type=None)
        # Forward ``gc`` so the constructor argument actually takes effect.
        rb_blocks = [
            B.RRDB(nf, kernel_size=3, gc=gc, stride=1, bias=True, pad_type='zero',
                   norm_type=norm_type, act_type=act_type, mode='CNA')
            for _ in range(nb)
        ]
        LR_conv = B.conv_block(nf, nf, kernel_size=3, norm_type=norm_type,
                               act_type=None, mode=mode)

        if upsample_mode == 'upconv':
            # Spelling matches the helper's name in ``block``; do not "fix" it here.
            upsample_block = B.upconv_blcok
        elif upsample_mode == 'pixelshuffle':
            upsample_block = B.pixelshuffle_block
        else:
            raise NotImplementedError('upsample mode [%s] is not found' % upsample_mode)

        if upscale == 3:
            # NOTE(review): this is a single module rather than a list, so the
            # ``*upsampler`` below splices its children directly into
            # ``self.model`` (assumes the helper returns an iterable container
            # such as nn.Sequential -- TODO confirm). Left unchanged so the
            # module layout / state_dict keys of existing x3 models still match.
            upsampler = upsample_block(nf, nf, 3, act_type=act_type)
        else:
            # One stage per factor of two in ``upscale``.
            upsampler = [upsample_block(nf, nf, act_type=act_type) for _ in range(n_upscale)]

        HR_conv0 = B.conv_block(nf, nf, kernel_size=3, norm_type=None, act_type=act_type)
        HR_conv1 = B.conv_block(nf, out_nc, kernel_size=3, norm_type=None, act_type=None)

        self.model = B.sequential(
            fea_conv,
            B.ShortcutBlock(B.sequential(*rb_blocks, LR_conv)),
            *upsampler,
            HR_conv0,
            HR_conv1,
        )

    @staticmethod
    def _num_upsample_stages(upscale):
        """Return how many upsampling stages the network needs for ``upscale``.

        ``upscale == 3`` maps to one dedicated x3 stage; any other value must be
        a positive power of two and maps to its base-2 exponent (so 1 -> 0).

        Raises:
            ValueError: If ``upscale`` is neither 3 nor a positive power of two.
        """
        if upscale == 3:
            return 1
        if upscale < 1:
            raise ValueError('upscale must be 3 or a positive power of two, got %r' % (upscale,))
        # math.log2 is more accurate than math.log(x, 2); the round-trip check
        # rejects factors such as 6 that would otherwise truncate silently to a
        # smaller power of two and build a network with the wrong scale.
        n_upscale = int(round(math.log2(upscale)))
        if 2 ** n_upscale != upscale:
            raise ValueError('upscale must be 3 or a positive power of two, got %r' % (upscale,))
        return n_upscale

    def forward(self, x):
        """Run ``x`` through the full pipeline and return the result.

        Presumably ``x`` is an image batch shaped (N, in_nc, H, W) -- TODO confirm.
        """
        return self.model(x)
|