from argparse import (
    ArgumentParser,
    Namespace,
)
from typing import Optional

import torch
from torch import nn
from torch.nn import functional as F

from utils.misc import optional_string

from .gaussian_smoothing import GaussianSmoothing

class DegradeArguments:
    @staticmethod
    def add_arguments(parser: ArgumentParser):
        parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
                            help="Type of spectral sensitivity. g: grayscale (panchromatic), "
                                 "b: blue-sensitive, gb: green+blue (orthochromatic)")
        parser.add_argument('--gaussian', type=float, default=0,
                            help="Estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")

    @staticmethod
    def to_string(args: Namespace) -> str:
        return (
            f"{args.spectral_sensitivity}"
            + optional_string(args.gaussian > 0, f"-G{args.gaussian}")
        )

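# Usage sketch (illustrative only; it assumes utils.misc.optional_string(cond, s) returns s when
# cond is truthy and "" otherwise):
#   parser = ArgumentParser()
#   DegradeArguments.add_arguments(parser)
#   args = parser.parse_args(["--spectral_sensitivity", "b", "--gaussian", "1.5"])
#   DegradeArguments.to_string(args)  # -> "b-G1.5" under the assumption above
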
class CameraResponse(nn.Module):
    """Learnable camera response curve: a gain/offset/gamma mapping applied in [0, 1]."""

    def __init__(self):
        super().__init__()
        self.register_parameter("gamma", nn.Parameter(torch.ones(1)))
        self.register_parameter("offset", nn.Parameter(torch.zeros(1)))
        self.register_parameter("gain", nn.Parameter(torch.ones(1)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Clamp slightly above -1 so the power is well-defined for non-integer gamma.
        x = torch.clamp(x, max=1, min=-1 + 1e-2)
        x = (1 + x) * 0.5  # [-1, 1] -> [0, 1]
        x = self.offset + self.gain * torch.pow(x, self.gamma)
        x = (x - 0.5) * 2  # [0, 1] -> [-1, 1]
        # x = torch.clamp(x, max=1, min=-1)
        return x

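# Worked example of the mapping above (follows directly from the code): with the initial
# parameters gamma=1, offset=0, gain=1, an input value x in [-1 + 1e-2, 1] maps to
# 2 * (0 + 1 * ((1 + x) / 2) ** 1) - 1 = x, i.e. the module starts out as (approximately)
# the identity and only bends the response curve as gamma/offset/gain are optimized.
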
class SpectralResponse(nn.Module):
    # TODO: use enum instead for color mode
    def __init__(self, spectral_sensitivity: str = 'b'):
        assert spectral_sensitivity in ("g", "b", "gb"), \
            f"spectral_sensitivity {spectral_sensitivity} is not implemented."
        super().__init__()
        self.spectral_sensitivity = spectral_sensitivity
        if self.spectral_sensitivity == "g":
            # Rec. 601 luma weights for the panchromatic (grayscale) response.
            self.register_buffer("to_gray", torch.tensor([0.299, 0.587, 0.114]).reshape(1, -1, 1, 1))

    def forward(self, rgb: torch.Tensor) -> torch.Tensor:
        if self.spectral_sensitivity == "b":
            x = rgb[:, -1:]
        elif self.spectral_sensitivity == "gb":
            x = (rgb[:, 1:2] + rgb[:, -1:]) * 0.5
        else:
            assert self.spectral_sensitivity == "g"
            x = (rgb * self.to_gray).sum(dim=1, keepdim=True)
        return x

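# Behavior summary (derived from the branches above): an input of shape (N, 3, H, W) yields an
# output of shape (N, 1, H, W); "b" keeps the blue channel, "gb" averages green and blue
# (orthochromatic film), and "g" takes the Rec. 601 weighted sum (panchromatic film).
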
class Downsample(nn.Module):
    """Antialiasing downsampling."""

    def __init__(self, input_size: int, output_size: int, channels: int):
        super().__init__()
        if input_size % output_size == 0:
            # Integer ratio: blur and subsample with a stride.
            self.stride = input_size // output_size
            self.grid = None
        else:
            # Non-integer ratio: blur at stride 1, then resample on a uniform grid.
            self.stride = 1
            step = input_size / output_size
            x = torch.arange(output_size) * step
            Y, X = torch.meshgrid(x, x)
            # grid_sample expects (x, y) ordering in the last dimension.
            grid = torch.stack((X, Y), dim=-1)
            # Normalize pixel coordinates to [-1, 1].
            grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
            grid = grid * 2 - 1
            self.register_buffer("grid", grid)
        sigma = 0.5 * input_size / output_size
        # print(f"{input_size} -> {output_size}: sigma={sigma}")
        self.blur = GaussianSmoothing(channels, int(2 * (sigma * 2) + 1 + 0.5), sigma)

    def forward(self, im: torch.Tensor):
        out = self.blur(im, stride=self.stride)
        if self.grid is not None:
            out = F.grid_sample(out, self.grid[None].expand(im.shape[0], -1, -1, -1))
        return out

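# Usage sketch (illustrative values only; it relies on GaussianSmoothing accepting a `stride`
# keyword in forward, as used above):
#   down_int = Downsample(1024, 256, channels=3)    # integer ratio -> strided blur, grid is None
#   down_frac = Downsample(1000, 256, channels=3)   # non-integer ratio -> stride-1 blur + grid_sample
#   out = down_frac(torch.rand(1, 3, 1000, 1000) * 2 - 1)
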
class Degrade(nn.Module):
    """
    Simulate the degradation of antique film.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        self.srf = SpectralResponse(args.spectral_sensitivity)
        self.crf = CameraResponse()
        self.gaussian = None
        if args.gaussian is not None and args.gaussian > 0:
            # Odd kernel size roughly covering +/- 2 sigma.
            self.gaussian = GaussianSmoothing(3, 2 * int(args.gaussian * 2 + 0.5) + 1, args.gaussian)

    def forward(self, img: torch.Tensor, downsample: Optional[nn.Module] = None):
        if self.gaussian is not None:
            img = self.gaussian(img)
        if downsample is not None:
            img = downsample(img)
        img = self.srf(img)
        img = self.crf(img)
        # Replicate the single degraded channel back to 3 channels for downstream consumers.
        return img.repeat((1, 3, 1, 1)) if img.shape[1] == 1 else img

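# Minimal smoke test; a sketch rather than part of the original pipeline. It assumes the module is
# run with `python -m` from the package root so the relative import of .gaussian_smoothing (and
# utils.misc) resolves, and it builds the arguments directly as a Namespace instead of an
# ArgumentParser.
if __name__ == "__main__":
    args = Namespace(spectral_sensitivity="gb", gaussian=0.75)
    degrade = Degrade(args)
    img = torch.rand(1, 3, 256, 256) * 2 - 1  # inputs are expected in [-1, 1]
    out = degrade(img, downsample=Downsample(256, 128, channels=3))
    print(DegradeArguments.to_string(args), tuple(out.shape))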