from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm


class UNetDiscriminatorSN(nn.Module):
    """Defines a U-Net discriminator with spectral normalization (SN).

    It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
        skip_connection (bool): Whether to use skip connections inside the U-Net. Default: True.
    """

    def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
        super(UNetDiscriminatorSN, self).__init__()
        self.skip_connection = skip_connection
        norm = spectral_norm
        # the first convolution keeps the input spatial size
        self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
        # downsampling: each stride-2 convolution halves the spatial size
        self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
        self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
        self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
        # upsampling branch: channel-reducing convolutions, paired with
        # bilinear interpolation in forward()
        self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
        self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
        self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
        # extra convolutions before the final 1-channel prediction
        self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
        self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)

    def forward(self, x):
        # downsample
        x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
        x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
        x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
        x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)

        # upsample, adding skip connections from the downsampling path
        x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
        x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x4 = x4 + x2

        x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
        x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x5 = x5 + x1

        x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
        x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
        if self.skip_connection:
            x6 = x6 + x0

        # extra convolutions and the final 1-channel prediction
        out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
        out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
        out = self.conv9(out)

        return out
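

# Usage sketch (illustrative): the discriminator produces a per-pixel realness
# map with the same spatial size as its input. The input height and width are
# assumed to be multiples of 8 so that the three stride-2 convolutions divide
# them evenly and the skip connections line up after bilinear upsampling.
#
#     import torch
#     disc = UNetDiscriminatorSN(num_in_ch=3)
#     pred = disc(torch.randn(2, 3, 128, 128))  # -> shape (2, 1, 128, 128)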


class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Supported types: 'vanilla', 'lsgan', 'wgan',
            'wgan_softplus' and 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only applied for generators; it is
            always 1.0 for discriminators.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'wgan_softplus':
            self.loss = self._wgan_softplus_loss
        elif self.gan_type == 'hinge':
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')

    def _wgan_loss(self, input, target):
        """WGAN loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: WGAN loss.
        """
        return -input.mean() if target else input.mean()

    def _wgan_softplus_loss(self, input, target):
        """WGAN loss with softplus. Softplus is a smooth approximation to the
        ReLU function.

        In StyleGAN2 it is called the logistic loss for the discriminator and
        the non-saturating loss for the generator.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: WGAN loss.
        """
        return F.softplus(-input).mean() if target else F.softplus(input).mean()
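
    # Math note (standard identity): softplus(x) = log(1 + exp(x))
    # = -log(sigmoid(-x)). Hence for a real target this returns
    # -log(sigmoid(input)).mean(), the non-saturating logistic loss used in
    # StyleGAN2; for a fake target it returns -log(1 - sigmoid(input)).mean().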

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target label. Returns a bool for 'wgan' and
                'wgan_softplus'; otherwise returns a Tensor filled with the
                real or fake label value.
        """
        if self.gan_type in ['wgan', 'wgan_softplus']:
            return target_is_real
        target_val = (self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss is computed for a discriminator
                or not. Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight
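

# A minimal end-to-end usage sketch (illustrative: the random tensors below are
# stand-ins for generator outputs and ground-truth images, and loss_weight=0.1
# is an arbitrary example value, not a recommended setting).
if __name__ == '__main__':
    import torch

    disc = UNetDiscriminatorSN(num_in_ch=3)
    cri_gan = GANLoss(gan_type='vanilla', loss_weight=0.1)

    real = torch.randn(2, 3, 128, 128)
    fake = torch.randn(2, 3, 128, 128)  # would normally come from a generator

    # generator step: try to make fakes look real (loss_weight is applied)
    l_g_gan = cri_gan(disc(fake), target_is_real=True, is_disc=False)

    # discriminator steps: loss_weight is intentionally ignored (is_disc=True)
    l_d_real = cri_gan(disc(real), target_is_real=True, is_disc=True)
    l_d_fake = cri_gan(disc(fake.detach()), target_is_real=False, is_disc=True)

    print(l_g_gan.item(), l_d_real.item(), l_d_fake.item())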