File size: 2,197 Bytes
8ec10cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch.nn as nn

from . import common

from lambda_networks import LambdaLayer


def build_model(args):
    """Factory entry point: create the ResNet model for this module.

    Args:
        args: Configuration object forwarded unchanged to ``ResNet``; all
            other constructor parameters keep their defaults.

    Returns:
        A newly constructed ``ResNet`` instance.
    """
    model = ResNet(args)
    return model


class ResNet(nn.Module):
    """Sequential conv / ResBlock network with two interleaved LambdaLayers.

    The body is a single ``nn.Sequential`` assembled in this order:

    1. ``common.default_conv`` from ``in_channels`` to ``n_feats``
    2. ``n_resblocks // 3`` x ``common.ResBlock``
    3. ``LambdaLayer`` (``r=23``, ``dim_u=1``)
    4. ``n_resblocks // 3`` x ``common.ResBlock``
    5. ``LambdaLayer`` (``r=7``, ``dim_u=4``)
    6. ``n_resblocks // 3`` x ``common.ResBlock``
    7. ``common.default_conv`` from ``n_feats`` to ``n_feats``
    8. ``common.default_conv`` from ``n_feats`` to ``out_channels``, kernel 1

    Args:
        args: Configuration object. ``args.rgb_range`` is always read;
            ``args.n_feats``, ``args.kernel_size`` and ``args.n_resblocks``
            are read only when the matching keyword below is ``None``.
        in_channels: Channel count passed to the first convolution.
        out_channels: Channel count produced by the last convolution.
        n_feats: Feature width; falls back to ``args.n_feats`` when ``None``.
        kernel_size: Kernel size; falls back to ``args.kernel_size`` when
            ``None``.
        n_resblocks: Residual block budget; falls back to
            ``args.n_resblocks`` when ``None``. Only ``n_resblocks // 3``
            blocks are built per stage, so any remainder is not used.
        mean_shift: When true, ``forward`` subtracts ``rgb_range / 2`` from
            the input and adds it back to the output.
    """

    def __init__(
        self,
        args,
        in_channels=3,
        out_channels=3,
        n_feats=None,
        kernel_size=None,
        n_resblocks=None,
        mean_shift=True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Explicit keyword values win; otherwise use the values from ``args``.
        if n_feats is None:
            n_feats = args.n_feats
        self.n_feats = n_feats

        if kernel_size is None:
            kernel_size = args.kernel_size
        self.kernel_size = kernel_size

        if n_resblocks is None:
            n_resblocks = args.n_resblocks
        self.n_resblocks = n_resblocks

        self.mean_shift = mean_shift
        self.rgb_range = args.rgb_range
        # Offset used by forward(): half of the configured RGB range.
        self.mean = self.rgb_range / 2

        width = self.n_feats
        ksize = self.kernel_size

        def residual_stage():
            # One third (floor) of the residual block budget per stage.
            return [
                common.ResBlock(width, ksize)
                for _ in range(self.n_resblocks // 3)
            ]

        # Layers are created strictly in network order so that the positions
        # inside the Sequential container stay the same.
        layers = [common.default_conv(self.in_channels, width, ksize)]
        layers.extend(residual_stage())
        # NOTE(review): r presumably is the LambdaLayer's local context
        # size — confirm against the lambda_networks documentation.
        layers.append(
            LambdaLayer(dim=width, dim_out=width, r=23, dim_k=16, heads=4, dim_u=1)
        )
        layers.extend(residual_stage())
        layers.append(
            LambdaLayer(dim=width, dim_out=width, r=7, dim_k=16, heads=4, dim_u=4)
        )
        layers.extend(residual_stage())
        layers.append(common.default_conv(width, width, ksize))
        layers.append(common.default_conv(width, self.out_channels, 1))

        self.body = nn.Sequential(*layers)

    def forward(self, input):
        """Run the body, optionally re-centering values around ``self.mean``.

        Args:
            input: Tensor fed to ``self.body``.

        Returns:
            ``self.body(input)`` when ``mean_shift`` is false; otherwise
            ``self.body(input - self.mean) + self.mean``.
        """
        if not self.mean_shift:
            return self.body(input)

        shifted = input - self.mean
        return self.body(shifted) + self.mean