import random
from collections import OrderedDict

import numpy as np
import torch
from torch.nn import functional as F

from basicsr.data.degradations import (
    random_add_gaussian_noise_pt,
    random_add_poisson_noise_pt,
)
from basicsr.data.transforms import paired_random_crop
from basicsr.models.srgan_model import SRGANModel
from basicsr.utils import DiffJPEG, USMSharp
from basicsr.utils.img_process_util import filter2D
from basicsr.utils.registry import MODEL_REGISTRY


@MODEL_REGISTRY.register()
class RealESRGANModel(SRGANModel):
    """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.

    It mainly performs:
    1. randomly synthesize LQ images as tensors on the model device
    2. optimize the networks with GAN training.
    """

    def __init__(self, opt):
        super(RealESRGANModel, self).__init__(opt)
        # Place the degradation helpers on the model's device (``self.device`` is
        # already used by feed_data) instead of unconditionally calling .cuda(),
        # so they always live where the GT tensors are moved.
        self.jpeger = DiffJPEG(differentiable=False).to(
            self.device
        )  # simulate JPEG compression artifacts
        self.usm_sharpener = USMSharp().to(self.device)  # do usm sharpening
        self.queue_size = opt.get("queue_size", 180)

    @torch.no_grad()
    def _dequeue_and_enqueue(self):
        """Training-pair pool that increases degradation diversity within a batch.

        Batch processing limits the diversity of synthetic degradations in a
        batch; for example, samples in a batch cannot have different resize
        scaling factors. This pool stores past (lq, gt) pairs and, once it is
        full, swaps the current batch for ``b`` randomly chosen stored pairs.

        Reads and may replace ``self.lq`` / ``self.gt``. Until the pool is full,
        the current batch is only enqueued and is used unchanged.

        Raises:
            ValueError: if ``queue_size`` is not divisible by the batch size.
        """
        b, c, h, w = self.lq.size()
        if not hasattr(self, "queue_lr"):
            # Lazily allocate the pool on first use, once batch/patch shapes are known.
            # A plain ``assert`` would be stripped under ``python -O``.
            if self.queue_size % b != 0:
                raise ValueError(
                    f"queue size {self.queue_size} should be divisible by batch size {b}"
                )
            # new_zeros inherits device/dtype from the source tensor.
            self.queue_lr = self.lq.new_zeros(self.queue_size, c, h, w)
            _, c_gt, h_gt, w_gt = self.gt.size()
            self.queue_gt = self.gt.new_zeros(self.queue_size, c_gt, h_gt, w_gt)
            self.queue_ptr = 0

        if self.queue_ptr == self.queue_size:
            # The pool is full: shuffle it, take the first b pairs out and put
            # the current batch in their place.
            idx = torch.randperm(self.queue_size)
            self.queue_lr = self.queue_lr[idx]
            self.queue_gt = self.queue_gt[idx]
            lq_dequeue = self.queue_lr[0:b].clone()
            gt_dequeue = self.queue_gt[0:b].clone()
            # Slice assignment copies the data, so no extra clone is needed.
            self.queue_lr[0:b] = self.lq
            self.queue_gt[0:b] = self.gt
            self.lq = lq_dequeue
            self.gt = gt_dequeue
        else:
            # The pool is not full yet: only enqueue; the batch is used as-is.
            self.queue_lr[self.queue_ptr : self.queue_ptr + b] = self.lq
            self.queue_gt[self.queue_ptr : self.queue_ptr + b] = self.gt
            self.queue_ptr = self.queue_ptr + b

    def _sample_resize_scale(self, prob_key, range_key):
        """Draw a random resize factor for one degradation stage.

        Chooses "up"/"down"/"keep" with weights ``opt[prob_key]``; "up" draws
        from ``[1, opt[range_key][1])``, "down" from ``[opt[range_key][0], 1)``,
        and "keep" returns 1.
        """
        updown_type = random.choices(["up", "down", "keep"], self.opt[prob_key])[0]
        if updown_type == "up":
            return np.random.uniform(1, self.opt[range_key][1])
        if updown_type == "down":
            return np.random.uniform(self.opt[range_key][0], 1)
        return 1

    def _add_random_noise(self, out, suffix):
        """Add Gaussian or Poisson noise using the option keys ending in ``suffix``.

        ``suffix`` is "" for the first degradation stage and "2" for the second.
        Gaussian noise is chosen with probability ``gaussian_noise_prob{suffix}``.
        """
        gray_noise_prob = self.opt[f"gray_noise_prob{suffix}"]
        if np.random.uniform() < self.opt[f"gaussian_noise_prob{suffix}"]:
            return random_add_gaussian_noise_pt(
                out,
                sigma_range=self.opt[f"noise_range{suffix}"],
                clip=True,
                rounds=False,
                gray_prob=gray_noise_prob,
            )
        return random_add_poisson_noise_pt(
            out,
            scale_range=self.opt[f"poisson_scale_range{suffix}"],
            gray_prob=gray_noise_prob,
            clip=True,
            rounds=False,
        )

    def _jpeg_compress(self, out, range_key):
        """Apply JPEG compression with a per-sample quality drawn from ``opt[range_key]``."""
        jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt[range_key])
        # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
        out = torch.clamp(out, 0, 1)
        return self.jpeger(out, quality=jpeg_p)

    def _resize_back_and_sinc(self, out, ori_h, ori_w):
        """Resize to (ori_h // scale, ori_w // scale) with a random mode, then apply the sinc filter."""
        mode = random.choice(["area", "bilinear", "bicubic"])
        out = F.interpolate(
            out,
            size=(ori_h // self.opt["scale"], ori_w // self.opt["scale"]),
            mode=mode,
        )
        return filter2D(out, self.sinc_kernel)

    @torch.no_grad()
    def feed_data(self, data):
        """Accept data from the dataloader; in training, synthesize LQ images on the fly.

        When ``self.is_train`` and ``high_order_degradation`` (default True),
        ``data`` must provide ``gt``, ``kernel1``, ``kernel2`` and ``sinc_kernel``.
        LQ images are then produced by two rounds of
        blur -> random resize -> noise -> JPEG (the second blur is applied with
        probability ``second_blur_prob``), a final resize back to GT size / ``scale``
        plus a sinc filter, a paired random crop, and the training-pair pool.

        Otherwise (paired training or validation) ``data["lq"]`` is used as-is,
        together with ``data["gt"]`` when present.
        """
        if not (self.is_train and self.opt.get("high_order_degradation", True)):
            # for paired training or validation
            self.lq = data["lq"].to(self.device)
            if "gt" in data:
                self.gt = data["gt"].to(self.device)
                self.gt_usm = self.usm_sharpener(self.gt)
            return

        # training data synthesis
        self.gt = data["gt"].to(self.device)
        self.gt_usm = self.usm_sharpener(self.gt)
        self.kernel1 = data["kernel1"].to(self.device)
        self.kernel2 = data["kernel2"].to(self.device)
        self.sinc_kernel = data["sinc_kernel"].to(self.device)
        ori_h, ori_w = self.gt.size()[2:4]
        sr_scale = self.opt["scale"]

        # NOTE: the order of random draws below (random / np.random / torch)
        # matches the original implementation; keep it if you refactor further.

        # ----------------------- The first degradation process ----------------------- #
        out = filter2D(self.gt_usm, self.kernel1)
        scale = self._sample_resize_scale("resize_prob", "resize_range")
        mode = random.choice(["area", "bilinear", "bicubic"])
        out = F.interpolate(out, scale_factor=scale, mode=mode)
        out = self._add_random_noise(out, "")
        out = self._jpeg_compress(out, "jpeg_range")

        # ----------------------- The second degradation process ----------------------- #
        if np.random.uniform() < self.opt["second_blur_prob"]:
            out = filter2D(out, self.kernel2)
        scale = self._sample_resize_scale("resize_prob2", "resize_range2")
        mode = random.choice(["area", "bilinear", "bicubic"])
        # max(1, ...) prevents a zero-sized target (which F.interpolate rejects)
        # for small patches combined with strong downscaling.
        out = F.interpolate(
            out,
            size=(
                max(1, int(ori_h / sr_scale * scale)),
                max(1, int(ori_w / sr_scale * scale)),
            ),
            mode=mode,
        )
        out = self._add_random_noise(out, "2")

        # JPEG compression + the final sinc filter.
        # We also need to resize images to desired sizes. We group
        # [resize back + sinc filter] together as one operation and consider two orders:
        #   1. [resize back + sinc filter] + JPEG compression
        #   2. JPEG compression + [resize back + sinc filter]
        # Empirically, other combinations (sinc + JPEG + Resize) introduce twisted lines.
        if np.random.uniform() < 0.5:
            out = self._resize_back_and_sinc(out, ori_h, ori_w)
            out = self._jpeg_compress(out, "jpeg_range2")
        else:
            out = self._jpeg_compress(out, "jpeg_range2")
            out = self._resize_back_and_sinc(out, ori_h, ori_w)

        # clamp and round to 8-bit levels
        self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.0

        # random crop
        gt_size = self.opt["gt_size"]
        (self.gt, self.gt_usm), self.lq = paired_random_crop(
            [self.gt, self.gt_usm], self.lq, gt_size, sr_scale
        )

        # training pair pool
        self._dequeue_and_enqueue()
        # sharpen self.gt again, as self._dequeue_and_enqueue may have replaced self.gt
        self.gt_usm = self.usm_sharpener(self.gt)
        # for the warning: grad and param do not obey the gradient layout contract
        self.lq = self.lq.contiguous()

    def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
        """Run the base validation with the synthetic degradation pipeline disabled."""
        # do not use the synthetic process during validation
        self.is_train = False
        try:
            super(RealESRGANModel, self).nondist_validation(
                dataloader, current_iter, tb_logger, save_img
            )
        finally:
            # Restore training mode even if validation raises.
            self.is_train = True

    def optimize_parameters(self, current_iter):
        """Run one optimization iteration for net_g (when scheduled) and net_d.

        The pixel / perceptual / GAN targets are the USM-sharpened GT unless the
        option ``l1_gt_usm`` / ``percep_gt_usm`` / ``gan_gt_usm`` is exactly
        False; a missing option defaults to the sharpened GT.
        """
        # usm sharpening
        l1_gt = self.gt if self.opt.get("l1_gt_usm", True) is False else self.gt_usm
        percep_gt = (
            self.gt if self.opt.get("percep_gt_usm", True) is False else self.gt_usm
        )
        gan_gt = self.gt if self.opt.get("gan_gt_usm", True) is False else self.gt_usm

        # optimize net_g
        for p in self.net_d.parameters():
            p.requires_grad = False

        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        l_g_total = 0
        loss_dict = OrderedDict()
        # net_g is updated only every net_d_iters iterations and only after
        # net_d_init_iters iterations have passed.
        if (
            current_iter % self.net_d_iters == 0
            and current_iter > self.net_d_init_iters
        ):
            # pixel loss
            if self.cri_pix:
                l_g_pix = self.cri_pix(self.output, l1_gt)
                l_g_total += l_g_pix
                loss_dict["l_g_pix"] = l_g_pix
            # perceptual loss
            if self.cri_perceptual:
                l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt)
                if l_g_percep is not None:
                    l_g_total += l_g_percep
                    loss_dict["l_g_percep"] = l_g_percep
                if l_g_style is not None:
                    l_g_total += l_g_style
                    loss_dict["l_g_style"] = l_g_style
            # gan loss
            fake_g_pred = self.net_d(self.output)
            l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False)
            l_g_total += l_g_gan
            loss_dict["l_g_gan"] = l_g_gan

            l_g_total.backward()
            self.optimizer_g.step()

        # optimize net_d
        for p in self.net_d.parameters():
            p.requires_grad = True

        self.optimizer_d.zero_grad()
        # real
        real_d_pred = self.net_d(gan_gt)
        l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
        loss_dict["l_d_real"] = l_d_real
        loss_dict["out_d_real"] = torch.mean(real_d_pred.detach())
        l_d_real.backward()
        # fake
        fake_d_pred = self.net_d(self.output.detach().clone())  # clone for pt1.9
        l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)
        loss_dict["l_d_fake"] = l_d_fake
        loss_dict["out_d_fake"] = torch.mean(fake_d_pred.detach())
        l_d_fake.backward()
        self.optimizer_d.step()

        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)

        self.log_dict = self.reduce_loss_dict(loss_dict)