EditGuard / models /discrim.py
Ricoooo's picture
'folder'
5d21dd2
from torch import nn as nn
from torch.nn import functional as F
from torch.nn.utils import spectral_norm
class UNetDiscriminatorSN(nn.Module):
"""Defines a U-Net discriminator with spectral normalization (SN)
It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
Arg:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
skip_connection (bool): Whether to use skip connections between U-Net. Default: True.
"""
def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
super(UNetDiscriminatorSN, self).__init__()
self.skip_connection = skip_connection
norm = spectral_norm
# the first convolution
self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
# downsample
self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
# upsample
self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
# extra convolutions
self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
def forward(self, x):
# downsample
x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
# upsample
x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
if self.skip_connection:
x4 = x4 + x2
x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
if self.skip_connection:
x5 = x5 + x1
x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
if self.skip_connection:
x6 = x6 + x0
# extra convolutions
out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
out = self.conv9(out)
return out
class GANLoss(nn.Module):
"""Define GAN loss.
Args:
gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
real_label_val (float): The value for real label. Default: 1.0.
fake_label_val (float): The value for fake label. Default: 0.0.
loss_weight (float): Loss weight. Default: 1.0.
Note that loss_weight is only for generators; and it is always 1.0
for discriminators.
"""
def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0):
super(GANLoss, self).__init__()
self.gan_type = gan_type
self.loss_weight = loss_weight
self.real_label_val = real_label_val
self.fake_label_val = fake_label_val
if self.gan_type == 'vanilla':
self.loss = nn.BCEWithLogitsLoss()
elif self.gan_type == 'lsgan':
self.loss = nn.MSELoss()
elif self.gan_type == 'wgan':
self.loss = self._wgan_loss
elif self.gan_type == 'wgan_softplus':
self.loss = self._wgan_softplus_loss
elif self.gan_type == 'hinge':
self.loss = nn.ReLU()
else:
raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.')
def _wgan_loss(self, input, target):
"""wgan loss.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return -input.mean() if target else input.mean()
def _wgan_softplus_loss(self, input, target):
"""wgan loss with soft plus. softplus is a smooth approximation to the
ReLU function.
In StyleGAN2, it is called:
Logistic loss for discriminator;
Non-saturating loss for generator.
Args:
input (Tensor): Input tensor.
target (bool): Target label.
Returns:
Tensor: wgan loss.
"""
return F.softplus(-input).mean() if target else F.softplus(input).mean()
def get_target_label(self, input, target_is_real):
"""Get target label.
Args:
input (Tensor): Input tensor.
target_is_real (bool): Whether the target is real or fake.
Returns:
(bool | Tensor): Target tensor. Return bool for wgan, otherwise,
return Tensor.
"""
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
target_val = (self.real_label_val if target_is_real else self.fake_label_val)
return input.new_ones(input.size()) * target_val
def forward(self, input, target_is_real, is_disc=False):
"""
Args:
input (Tensor): The input for the loss module, i.e., the network
prediction.
target_is_real (bool): Whether the targe is real or fake.
is_disc (bool): Whether the loss for discriminators or not.
Default: False.
Returns:
Tensor: GAN loss value.
"""
target_label = self.get_target_label(input, target_is_real)
if self.gan_type == 'hinge':
if is_disc: # for discriminators in hinge-gan
input = -input if target_is_real else input
loss = self.loss(1 + input).mean()
else: # for generators in hinge-gan
loss = -input.mean()
else: # other gan types
loss = self.loss(input, target_label)
# loss_weight is always 1.0 for discriminators
return loss if is_disc else loss * self.loss_weight