# File-viewer residue kept for provenance (not part of the module): yangheng's picture | init | 9842c28 | raw | history blame | 12.5 kB
import numpy as np
import random
import torch
from basicsr.data.degradations import (
random_add_gaussian_noise_pt,
random_add_poisson_noise_pt,
)
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY
from collections import OrderedDict
from torch.nn import functional as F
@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images in GPU tensors
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        """Build the model plus the helper modules used for on-the-fly degradation.

        Args:
            opt: Option mapping. This constructor reads ``queue_size``
                (default 180); the degradation settings are read later by
                :meth:`feed_data` and the loss switches by
                :meth:`optimize_parameters`.
        """
        super(RealESRGANModel, self).__init__(opt)
        # Place the helpers on ``self.device`` -- the same device ``feed_data``
        # moves every input tensor to -- instead of hard-coding CUDA.
        # NOTE(review): ``self.device`` is expected to be set by the parent
        # constructor; confirm against the base model.
        self.jpeger = DiffJPEG(differentiable=False).to(
            self.device
        )  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().to(self.device)  # do usm sharpening
        self.queue_size = opt.get("queue_size", 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """It is the training pair pool for increasing the diversity in a batch.

        Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a
        batch could not have different resize scaling factors. Therefore, we employ this training pair pool
        to increase the degradation diversity in a batch.

        While the pool is still filling, the current ``self.lq`` / ``self.gt``
        are copied into it and left untouched. Once it is full, the pool is
        shuffled, its first ``b`` pairs replace ``self.lq`` / ``self.gt``, and
        the current batch takes their slots.

        Raises:
            AssertionError: on first use, if ``queue_size`` is not a multiple
                of the batch size.
        """
        b, c, h, w = self.lq.size()

        # Lazily create the pool on the first call.
        if not hasattr(self, "queue_lr"):
            assert (
                self.queue_size % b == 0
            ), f"queue size {self.queue_size} should be divisible by batch size {b}"
            # Allocate directly on the device the batch lives on (no CPU
            # allocation followed by a copy, and no CUDA-only assumption).
            self.queue_lr = torch.zeros(
                self.queue_size, c, h, w, device=self.lq.device
            )
            _, c, h, w = self.gt.size()
            self.queue_gt = torch.zeros(
                self.queue_size, c, h, w, device=self.gt.device
            )
            self.queue_ptr = 0

        if self.queue_ptr == self.queue_size:  # the pool is full
            # Shuffle, hand out the first b pairs, and store the fresh batch
            # in the slots that were just vacated.
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            lq_dequeue = self.queue_lr[0:b, :, :, :].clone()
            gt_dequeue = self.queue_gt[0:b, :, :, :].clone()
            self.queue_lr[0:b, :, :, :] = self.lq.clone()
            self.queue_gt[0:b, :, :, :] = self.gt.clone()
            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # Pool not full yet: only enqueue.
            end = self.queue_ptr + b
            self.queue_lr[self.queue_ptr : end, :, :, :] = self.lq.clone()
            self.queue_gt[self.queue_ptr : end, :, :, :] = self.gt.clone()
            self.queue_ptr = end

    def _sample_resize_scale(self, prob_key, range_key):
        """Draw an up/down/keep decision and return the matching resize factor."""
        updown_type = random.choices(["up", "down", "keep"], self.opt[prob_key])[0]
        if updown_type == "up":
            return np.random.uniform(1, self.opt[range_key][1])
        if updown_type == "down":
            return np.random.uniform(self.opt[range_key][0], 1)
        return 1

    def _add_random_noise(self, out, suffix=""):
        """Add Gaussian or Poisson noise, using the option keys ending in ``suffix``."""
        gray_noise_prob = self.opt["gray_noise_prob" + suffix]
        if np.random.uniform() < self.opt["gaussian_noise_prob" + suffix]:
            return random_add_gaussian_noise_pt(
                out,
                sigma_range=self.opt["noise_range" + suffix],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        return random_add_poisson_noise_pt(
            out,
            scale_range=self.opt["poisson_scale_range" + suffix],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )

    def _random_jpeg(self, out, range_key):
        """Clamp ``out`` to [0, 1] and JPEG-compress it with a per-sample random quality."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt[range_key])
        # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = torch.clamp(out, 0, 1)
        return self.jpeger(out, quality=jpeg_p)

    def _resize_back_and_sinc(self, out, ori_h, ori_w):
        """Resize to the final LQ size (GT size // scale) and apply the sinc kernel."""
        mode = random.choice(["area", "bilinear", "bicubic"])
        out = F.interpolate(
            out,
            size=(ori_h // self.opt["scale"], ori_w // self.opt["scale"]),
            mode=mode,
        )
        return filter2D(out, self.sinc_kernel)

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from dataloader, and then add two-order degradations to obtain LQ images.

        Args:
            data: Mapping from the dataloader. In the synthesis branch it must
                provide ``gt``, ``kernel1``, ``kernel2`` and ``sinc_kernel``;
                otherwise it must provide ``lq`` and may provide ``gt``.

        Sets ``self.lq`` and, when GT is available, ``self.gt`` and
        ``self.gt_usm``.
        """
        if self.is_train and self.opt.get("high_order_degradation", True):
            # training data synthesis
            self.gt = data["gt"].to(self.device)
            self.gt_usm = self.usm_sharpener(self.gt)
            self.kernel1 = data["kernel1"].to(self.device)
            self.kernel2 = data["kernel2"].to(self.device)
            self.sinc_kernel = data["sinc_kernel"].to(self.device)
            ori_h, ori_w = self.gt.size()[2:4]

            # ----------------------- The first degradation process ----------------------- #
            # blur
            out = filter2D(self.gt_usm, self.kernel1)
            # random resize
            scale = self._sample_resize_scale("resize_prob", "resize_range")
            mode = random.choice(["area", "bilinear", "bicubic"])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise
            out = self._add_random_noise(out)
            # JPEG compression
            out = self._random_jpeg(out, "jpeg_range")

            # ----------------------- The second degradation process ----------------------- #
            # blur
            if np.random.uniform() < self.opt["second_blur_prob"]:
                out = filter2D(out, self.kernel2)
            # random resize, expressed relative to the final LQ size
            scale = self._sample_resize_scale("resize_prob2", "resize_range2")
            mode = random.choice(["area", "bilinear", "bicubic"])
            out = F.interpolate(
                out,
                size=(
                    int(ori_h / self.opt["scale"] * scale),
                    int(ori_w / self.opt["scale"] * scale),
                ),
                mode=mode,
            )
            # add noise
            out = self._add_random_noise(out, suffix="2")

            # JPEG compression + the final sinc filter
            # We also need to resize images to desired sizes. We group [resize back + sinc filter] together
            # as one operation.
            # We consider two orders:
            #   1. [resize back + sinc filter] + JPEG compression
            #   2. JPEG compression + [resize back + sinc filter]
            # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines.
            if np.random.uniform() < 0.5:
                out = self._resize_back_and_sinc(out, ori_h, ori_w)
                out = self._random_jpeg(out, "jpeg_range2")
            else:
                out = self._random_jpeg(out, "jpeg_range2")
                out = self._resize_back_and_sinc(out, ori_h, ori_w)

            # clamp and round
            self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.0

            # random crop
            gt_size = self.opt["gt_size"]
            (self.gt, self.gt_usm), self.lq = paired_random_crop(
                [self.gt, self.gt_usm], self.lq, gt_size, self.opt["scale"]
            )

            # training pair pool
            self._dequeue_and_enqueue()
            # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue
            self.gt_usm = self.usm_sharpener(self.gt)
            # for the warning: grad and param do not obey the gradient layout contract
            self.lq = self.lq.contiguous()
        else:
            # for paired training or validation
            self.lq = data["lq"].to(self.device)
            if "gt" in data:
                self.gt = data["gt"].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run the parent's validation with the synthetic degradation switched off.

        ``is_train`` is cleared for the duration of the call so that
        :meth:`feed_data` takes the paired branch, and is put back to its
        previous value afterwards -- even if validation raises.
        """
        was_training = self.is_train
        # do not use the synthetic process during validation
        self.is_train = False
        try:
            super(RealESRGANModel, self).nondist_validation(
                dataloader, current_iter, tb_logger, save_img
            )
        finally:
            self.is_train = was_training

    def optimize_parameters(self, current_iter):
        """Run one training step: an optional generator update, then a discriminator update.

        Args:
            current_iter: Current iteration index. The generator is only
                updated when it is a multiple of ``self.net_d_iters`` and
                greater than ``self.net_d_init_iters``.

        The collected losses are passed through ``self.reduce_loss_dict`` and
        stored in ``self.log_dict``.
        """
        # usm sharpening: each loss uses the sharpened GT unless its option is
        # explicitly False.
        l1_gt = self.gt if self.opt["l1_gt_usm"] is False else self.gt_usm
        percep_gt = self.gt if self.opt["percep_gt_usm"] is False else self.gt_usm
        gan_gt = self.gt if self.opt["gan_gt_usm"] is False else self.gt_usm

        # optimize net_g (discriminator frozen)
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        if (
            current_iter % self.net_d_iters == 0
            and current_iter > self.net_d_init_iters
        ):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict["l_g_pix"] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict["l_g_percep"] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict["l_g_style"] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict["l_g_gan"] = l_g_gan

            # Only valid inside this branch: outside it l_g_total is still the
            # plain int 0 and has no backward().
            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict["l_d_real"] = l_d_real
        loss_dict["out_d_real"] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake (detached so no gradient reaches the generator)
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict["l_d_fake"] = l_d_fake
        loss_dict["out_d_fake"] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)