# Source: lambdanet/backup/deblur/src/model/discriminator.py
# Uploaded by hyliu via huggingface_hub (commit 8ec10cf, verified).
import torch.nn as nn
class Discriminator(nn.Module):
    """Fully-convolutional discriminator that outputs a single-channel score map.

    The network is ten bias-free convolutions, each followed by a shared
    LeakyReLU (negative slope 0.2, in-place), and then a bias-free 1x1
    convolution that maps ``n_feats * 8`` channels to one channel. No sigmoid
    is applied, so the output is an unnormalized score.

    Args:
        args: configuration object. Only ``args.n_feats`` (base channel width)
            and ``args.kernel_size`` (kernel size of the first nine
            convolutions) are read.
    """

    def __init__(self, args):
        super().__init__()

        width = args.n_feats
        k = args.kernel_size

        # One row per convolution: (kernel, in_ch, out_ch, stride, padding).
        # padding=None selects (kernel - 1) // 2. The trailing numbers are the
        # original annotations. They match the feature-map side length for a
        # 256x256 input when kernel_size is odd.
        layer_specs = [
            (k, 3, width // 2, 1, None),           # 256
            (k, width // 2, width // 2, 2, None),  # 128
            (k, width // 2, width, 1, None),
            (k, width, width, 2, None),            # 64
            (k, width, width * 2, 1, None),
            (k, width * 2, width * 2, 4, None),    # 16
            (k, width * 2, width * 4, 1, None),
            (k, width * 4, width * 4, 4, None),    # 4
            (k, width * 4, width * 8, 1, None),
            (4, width * 8, width * 8, 4, 0),       # 1
        ]

        # Layers are built in table order, so parameter initialization
        # consumes the RNG in the same sequence as a layer-by-layer build.
        self.conv_layers = nn.ModuleList([
            nn.Conv2d(
                c_in,
                c_out,
                ksize,
                stride=step,
                padding=(ksize - 1) // 2 if pad is None else pad,
                bias=False,
            )
            for ksize, c_in, c_out, step, pad in layer_specs
        ])

        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        # 1x1 projection from the widest feature map down to a single channel.
        self.dense = nn.Conv2d(width * 8, 1, 1, bias=False)

    def forward(self, x):
        """Return the discriminator score map for ``x``.

        Args:
            x: image tensor whose channel count must be 3, because the first
                convolution is built with 3 input channels.

        Returns:
            The output of the final 1x1 convolution. It has one channel, and no
            activation is applied after it.
        """
        features = x
        for conv_layer in self.conv_layers:
            features = self.act(conv_layer(features))
        return self.dense(features)