File size: 4,300 Bytes
47c46ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from argparse import (
    ArgumentParser,
    Namespace,
)

import torch
from torch import nn
from torch.nn import functional as F

from utils.misc import optional_string

from .gaussian_smoothing import GaussianSmoothing


class DegradeArguments:
    """Command-line options controlling the simulated film degradation (``Degrade``)."""

    @staticmethod
    def add_arguments(parser: ArgumentParser):
        """Register the degradation options on *parser*.

        ``--spectral_sensitivity`` selects the channel mixing mode consumed by
        ``SpectralResponse``; ``--gaussian`` is the blur width handed to
        ``Degrade`` (a value of 0 disables the blur there).
        """
        parser.add_argument('--spectral_sensitivity', choices=["g", "b", "gb"], default="g",
            help="Type of spectral sensitivity. g: grayscale (panchromatic), b: blue-sensitive, gb: green+blue (orthochromatic)")
        parser.add_argument('--gaussian', type=float, default=0,
            help="estimated blur radius in pixels of the input photo if it is scaled to 1024x1024")

    @staticmethod
    def to_string(args: Namespace) -> str:
        """Return a short tag identifying the degradation options in *args*.

        The tag starts with the spectral sensitivity; a ``-G<gaussian>`` suffix
        is requested from ``optional_string`` only when a blur is enabled.
        ``args.gaussian`` may be ``None`` (``Degrade`` accepts that and skips
        the blur), so ``None`` is treated like 0 instead of raising TypeError.
        """
        gaussian = args.gaussian
        blur_enabled = gaussian is not None and gaussian > 0
        return (
            f"{args.spectral_sensitivity}"
            + optional_string(blur_enabled, f"-G{gaussian}")
        )


class CameraResponse(nn.Module):
    """Learnable global tone curve for images expressed in [-1, 1].

    The input is clamped to [-0.99, 1], rescaled to [0.005, 1], mapped through
    ``offset + gain * v ** gamma`` and rescaled back with ``2 * v - 1``
    (no final clamp). All three scalars start at the identity curve.
    """

    def __init__(self):
        super().__init__()
        # Assigning nn.Parameter attributes registers them (order: gamma, offset, gain).
        self.gamma = nn.Parameter(torch.ones(1))
        self.offset = nn.Parameter(torch.zeros(1))
        self.gain = nn.Parameter(torch.ones(1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The lower bound keeps the base of the power strictly positive (>= 0.005).
        clipped = x.clamp(min=-1 + 1e-2, max=1)
        unit = (clipped + 1) * 0.5
        curved = self.offset + self.gain * unit.pow(self.gamma)
        return (curved - 0.5) * 2


class SpectralResponse(nn.Module):
    """Collapse a colour batch to one channel according to a film sensitivity.

    Modes:
      * ``"g"``  - grayscale/panchromatic: weighted channel sum with the
        BT.601 luma weights (0.299, 0.587, 0.114).
      * ``"b"``  - blue-sensitive: the last channel only.
      * ``"gb"`` - orthochromatic: mean of channels 1 and 2 (green, blue).

    Channels are indexed along dim 1 (presumably NCHW RGB input - confirm with
    callers); the output keeps dim 1 with size 1.
    """
    # TODO: use enum instead for color mode
    def __init__(self, spectral_sensitivity: str = 'b'):
        assert spectral_sensitivity in ("g", "b", "gb"), f"spectral_sensitivity {spectral_sensitivity} is not implemented."
        super().__init__()
        self.spectral_sensitivity = spectral_sensitivity
        if spectral_sensitivity != "g":
            return
        # Only grayscale mode needs weights; as a buffer they follow .to(device).
        luma = torch.tensor([0.299, 0.587, 0.114])
        self.register_buffer("to_gray", luma.view(1, 3, 1, 1))

    def forward(self, rgb: torch.Tensor) -> torch.Tensor:
        mode = self.spectral_sensitivity
        if mode == "b":
            return rgb[:, -1:]
        if mode == "gb":
            green, blue = rgb[:, 1:2], rgb[:, -1:]
            return (green + blue) * 0.5
        assert mode == "g"
        return (rgb * self.to_gray).sum(dim=1, keepdim=True)


class Downsample(nn.Module):
    """Antialiasing downsampling: Gaussian blur followed by subsampling.

    If ``input_size`` is an integer multiple of ``output_size`` the blur is run
    with that stride. Otherwise the blur runs at stride 1 and the result is
    bilinearly resampled with ``F.grid_sample`` at input-pixel positions
    ``k * input_size / output_size`` for k = 0 .. output_size - 1.
    The blur sigma is ``0.5 * input_size / output_size``.
    """
    def __init__(self, input_size: int, output_size: int, channels: int):
        super().__init__()
        if input_size % output_size == 0:
            self.stride = input_size // output_size
            self.grid = None
        else:
            self.stride = 1
            step = input_size / output_size
            coords = torch.arange(output_size) * step
            n = output_size
            # Explicit "ij" meshgrid: ys varies along rows, xs along columns.
            # (Avoids torch.meshgrid's missing-`indexing` warning on newer torch.)
            ys = coords.view(-1, 1).expand(n, n)
            xs = coords.view(1, -1).expand(n, n)
            grid = torch.stack((xs, ys), dim=-1)
            # Normalise pixel positions with the align_corners=True convention:
            # pixel 0 -> -1, pixel (input_size - 1) -> +1.
            grid /= torch.Tensor((input_size - 1, input_size - 1)).view(1, 1, -1)
            grid = grid * 2 - 1
            self.register_buffer("grid", grid)
        sigma = 0.5 * input_size / output_size
        # Always-odd kernel size (same formula as Degrade); the previous
        # int(4 * sigma + 1.5) could be even, e.g. 6 for sigma = 1.25.
        # Presumably GaussianSmoothing needs an odd size for centred padding - verify.
        self.blur = GaussianSmoothing(channels, 2 * int(sigma * 2 + 0.5) + 1, sigma)

    def forward(self, im: torch.Tensor):
        out = self.blur(im, stride=self.stride)
        if self.grid is not None:
            # The grid was normalised for align_corners=True. With the default
            # align_corners=False every sample would be shifted and the first
            # row/column would blend half-and-half with the zero padding.
            out = F.grid_sample(
                out,
                self.grid[None].expand(im.shape[0], -1, -1, -1),
                align_corners=True,
            )
        return out



class Degrade(nn.Module):
    """
    Simulate the degradation of antique film.

    ``forward`` applies, in order: an optional Gaussian blur (enabled when
    ``args.gaussian`` is a positive number), an optional caller-supplied
    ``downsample`` module, ``SpectralResponse`` (collapses channels to one),
    ``CameraResponse`` (learnable tone curve), and finally replicates a
    single-channel result to 3 channels.
    """
    def __init__(self, args: Namespace):
        super().__init__()
        # Submodules are registered in this order: srf, crf, gaussian.
        self.srf = SpectralResponse(args.spectral_sensitivity)
        self.crf = CameraResponse()
        sigma = args.gaussian
        blur_enabled = sigma is not None and sigma > 0
        # Odd kernel size 2 * int(2 * sigma + 0.5) + 1 over 3 input channels.
        self.gaussian = (
            GaussianSmoothing(3, 2 * int(sigma * 2 + 0.5) + 1, sigma)
            if blur_enabled
            else None
        )

    def forward(self, img: torch.Tensor, downsample: nn.Module = None):
        # Spatial stages first, blur before resize; absent stages are skipped.
        for stage in (self.gaussian, downsample):
            if stage is not None:
                img = stage(img)
        mono = self.crf(self.srf(img))
        # Note that I changed it back to 3 channels
        if mono.shape[1] != 1:
            return mono
        return mono.repeat((1, 3, 1, 1))