Spaces: Running on Zero
import os | |
# os.environ["CUDA_VISIBLE_DEVICES"] = "4" | |
import re | |
import cv2 | |
import einops | |
import numpy as np | |
import torch | |
import random | |
import math | |
from PIL import Image, ImageDraw, ImageFont | |
import shutil | |
import glob | |
from tqdm import tqdm | |
import subprocess as sp | |
import argparse | |
import imageio | |
import sys | |
import json | |
import datetime | |
import string | |
from dataset.opencv_transforms.functional import to_tensor, center_crop | |
from pytorch_lightning import seed_everything | |
from sgm.util import append_dims | |
from sgm.util import autocast, instantiate_from_config | |
from vtdm.model import create_model, load_state_dict | |
from vtdm.util import tensor2vid, export_to_video | |
from einops import rearrange | |
import yaml | |
import numpy as np | |
import random | |
import torch | |
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt | |
from basicsr.data.transforms import paired_random_crop | |
from basicsr.models.sr_model import SRModel | |
from basicsr.utils import DiffJPEG, USMSharp | |
from basicsr.utils.img_process_util import filter2D | |
from basicsr.utils.registry import MODEL_REGISTRY | |
from torch.nn import functional as F | |
import torch.nn as nn | |
class DegradedImages(torch.nn.Module):
    """Simulate real-world low-quality video frames with the Real-ESRGAN
    two-stage degradation pipeline (blur -> random resize -> noise -> JPEG,
    applied twice, followed by a final sinc filter).

    Degradation hyper-parameters (resize/noise/JPEG ranges, target scale, ...)
    are read from ``configs/train_realesrnet_x4plus.yml``.
    """

    def __init__(self, freeze=True):
        super().__init__()
        # NOTE(review): `freeze` is accepted for interface compatibility but is
        # currently unused -- this module holds no trainable parameters.
        with open('configs/train_realesrnet_x4plus.yml', mode='r') as f:
            self.opt = yaml.load(f, Loader=yaml.FullLoader)
        # Differentiable-JPEG simulator. Created lazily on the first forward()
        # call (it is stateless, so one instance can be reused across calls;
        # the original code rebuilt it on every forward pass).
        self.jpeger = None

    def forward(self, images, videos, masks, kernel1s, kernel2s, sinc_kernels):
        """Degrade each video in the batch, then splice the clean condition
        image back in as frame 0 of every sample.

        Args:
            images:       (B, 3, H, W) clean condition images in [-1, 1].
            videos:       (B, 3, T, H, W) clean video frames in [-1, 1].
            masks:        (B, T, H, W) foreground masks; background pixels
                          (mask == 0) are painted white in the output.
            kernel1s:     (B, T, 21, 21) first-stage blur kernels.
            kernel2s:     (B, T, 21, 21) second-stage blur kernels.
            sinc_kernels: (B, T, 21, 21) final sinc (ringing) kernels.

        Returns:
            Tensor of shape (B*T, 3, H, W) in [-1, 1]: degraded frames, with
            each sample's frame 0 replaced by the clean ``images[i]``.
        """
        # Build the (stateless) JPEG compressor only once.
        if self.jpeger is None:
            self.jpeger = DiffJPEG(differentiable=False).cuda()

        B = images.size(0)
        ori_h, ori_w = videos.size()[3:5]
        # [-1, 1] -> [0, 1], the range the degradation ops expect.
        videos = videos / 2.0 + 0.5
        videos = rearrange(videos, 'b c t h w -> b t c h w')  # (B, T, 3, H, W)

        all_lqs = []
        for i in range(B):
            kernel1 = kernel1s[i]
            kernel2 = kernel2s[i]
            sinc_kernel = sinc_kernels[i]
            gt = videos[i]   # (T, 3, H, W)
            mask = masks[i]  # (T, H, W)

            # ----------------- The first degradation process ----------------- #
            # blur
            out = filter2D(gt, kernel1)
            # random resize (up / down / keep, per configured probabilities)
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(out, scale_factor=scale, mode=mode)
            # add noise: Gaussian or Poisson, optionally grayscale
            gray_noise_prob = self.opt['gray_noise_prob']
            if np.random.uniform() < self.opt['gaussian_noise_prob']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression (one random quality per frame)
            jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
            out = torch.clamp(out, 0, 1)  # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
            out = self.jpeger(out, quality=jpeg_p)

            # ----------------- The second degradation process ----------------- #
            # blur (applied only with probability `second_blur_prob`)
            if np.random.uniform() < self.opt['second_blur_prob']:
                out = filter2D(out, kernel2)
            # random resize
            updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0]
            if updown_type == 'up':
                scale = np.random.uniform(1, self.opt['resize_range2'][1])
            elif updown_type == 'down':
                scale = np.random.uniform(self.opt['resize_range2'][0], 1)
            else:
                scale = 1
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            out = F.interpolate(
                out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
            # add noise
            gray_noise_prob = self.opt['gray_noise_prob2']
            if np.random.uniform() < self.opt['gaussian_noise_prob2']:
                out = random_add_gaussian_noise_pt(
                    out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
            else:
                out = random_add_poisson_noise_pt(
                    out,
                    scale_range=self.opt['poisson_scale_range2'],
                    gray_prob=gray_noise_prob,
                    clip=True,
                    rounds=False)
            # JPEG compression + the final sinc filter.
            # [resize back + sinc filter] is treated as one operation; the two
            # orders [resize+sinc, JPEG] and [JPEG, resize+sinc] are chosen
            # with equal probability. Empirically, other combinations
            # (sinc + JPEG + resize) introduce twisted lines.
            if np.random.uniform() < 0.5:
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, sinc_kernel)
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
            else:
                # JPEG compression
                jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2'])
                out = torch.clamp(out, 0, 1)
                out = self.jpeger(out, quality=jpeg_p)
                # resize back + the final sinc filter
                mode = random.choice(['area', 'bilinear', 'bicubic'])
                out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode)
                out = filter2D(out, sinc_kernel)

            # quantize to 8-bit levels, then restore [0, 1]
            lqs = torch.clamp((out * 255.0).round(), 0, 255) / 255.
            # upsample the low-quality frames back to the original resolution
            mode = random.choice(['area', 'bilinear', 'bicubic'])
            lqs = F.interpolate(lqs, size=(ori_h, ori_w), mode=mode)  # (T, 3, H, W)
            lqs = rearrange(lqs, 't c h w -> t h w c')  # (T, H, W, 3)
            # paint masked-out background white
            # (was hard-coded `range(16)`; use the actual frame count so the
            # module also works for T != 16)
            for j in range(lqs.size(0)):
                lqs[j][mask[j] == 0] = 1.0
            all_lqs.append(lqs)

        # [0, 1] -> [-1, 1]
        all_lqs = [(f - 0.5) * 2.0 for f in all_lqs]
        all_lqs = torch.stack(all_lqs, 0)  # (B, T, H, W, 3)
        all_lqs = rearrange(all_lqs, 'b t h w c -> b t c h w')
        # frame 0 of every sample is replaced by the clean condition image
        for i in range(B):
            all_lqs[i][0] = images[i]
        all_lqs = rearrange(all_lqs, 'b t c h w -> (b t) c h w')
        return all_lqs