File size: 4,104 Bytes
8ec10cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch.nn as nn

import torch
from . import common

from lambda_networks import LambdaLayer


def build_model(args):
    """Factory entry point: construct this module's ``ResNet`` from ``args``."""
    model = ResNet(args)
    return model


class ConvGRU(nn.Module):
    """Convolutional GRU cell whose gates all use 3x3, padding-1 convolutions.

    Args:
        hidden_dim: number of channels in the hidden state ``h``.
        input_dim: number of channels in the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        in_ch = hidden_dim + input_dim
        # Update gate, reset gate and candidate state, in that order.
        self.convz = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(in_ch, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        """Return the next hidden state ``(1 - z) * h + z * q``.

        ``h`` and ``x`` are concatenated along dim 1 (channels).
        """
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(stacked))
        reset = torch.sigmoid(self.convr(stacked))
        # The candidate sees the reset-gated hidden state, not the raw one.
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        return (1 - update) * h + update * candidate

class SepConvGRU(nn.Module):
    """Separable convolutional GRU.

    One GRU update is applied with horizontal (1x5) convolutions, then a
    second update with vertical (5x1) convolutions, both using padding that
    preserves spatial size.

    Args:
        hidden_dim: number of channels in the hidden state ``h``.
        input_dim: number of channels in the input ``x``.
    """

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        in_ch = hidden_dim + input_dim
        horizontal = dict(kernel_size=(1, 5), padding=(0, 2))
        vertical = dict(kernel_size=(5, 1), padding=(2, 0))

        # Horizontal pass: update gate, reset gate, candidate.
        self.convz1 = nn.Conv2d(in_ch, hidden_dim, **horizontal)
        self.convr1 = nn.Conv2d(in_ch, hidden_dim, **horizontal)
        self.convq1 = nn.Conv2d(in_ch, hidden_dim, **horizontal)

        # Vertical pass: update gate, reset gate, candidate.
        self.convz2 = nn.Conv2d(in_ch, hidden_dim, **vertical)
        self.convr2 = nn.Conv2d(in_ch, hidden_dim, **vertical)
        self.convq2 = nn.Conv2d(in_ch, hidden_dim, **vertical)

    def forward(self, h, x):
        """Apply the horizontal then the vertical GRU update and return ``h``."""
        passes = (
            (self.convz1, self.convr1, self.convq1),  # horizontal
            (self.convz2, self.convr2, self.convq2),  # vertical
        )
        for conv_z, conv_r, conv_q in passes:
            stacked = torch.cat([h, x], dim=1)
            update = torch.sigmoid(conv_z(stacked))
            reset = torch.sigmoid(conv_r(stacked))
            candidate = torch.tanh(conv_q(torch.cat([reset * h, x], dim=1)))
            h = (1 - update) * h + update * candidate
        return h


class ResNet(nn.Module):
    """Recurrent ResNet with a separable-conv GRU feeding the body.

    Layout:
      * head: ``default_conv`` (3 -> n_feats) followed by 3 ResBlocks;
      * body: ``n_resblocks // 2`` ResBlocks, a LambdaLayer, then another
        ``n_resblocks // 2`` ResBlocks;
      * tail: 3 ResBlocks followed by ``default_conv`` (n_feats -> 3);
      * gru: ``SepConvGRU`` with hidden and input width ``n_feats``.

    ``forward`` runs ``args.n_scales`` recurrent iterations and returns one
    output per iteration.

    Args:
        args: configuration object; this class reads ``rgb_range``,
            ``n_feats``, ``kernel_size``, ``n_resblocks`` and ``n_scales``.
    """

    def __init__(self, args):
        super(ResNet, self).__init__()

        self.in_channels = 3
        self.out_channels = 3

        self.rgb_range = args.rgb_range
        # Inputs are shifted down by half the intensity range and outputs are
        # shifted back up by the same amount.
        self.mean = self.rgb_range / 2

        self.n_feats = args.n_feats
        self.kernel_size = args.kernel_size
        self.n_resblocks = args.n_resblocks

        # Number of recurrent iterations; forward emits one output per step.
        self.recurrence = args.n_scales

        # NOTE: submodules are created in the same order as before so seeded
        # initialisation is unchanged, and attribute names are kept so that
        # existing state_dict checkpoints still load.
        m_head = [common.default_conv(self.in_channels, self.n_feats, self.kernel_size)]
        m_head.extend(common.ResBlock(self.n_feats, self.kernel_size) for _ in range(3))

        half = self.n_resblocks // 2
        modules = [common.ResBlock(self.n_feats, self.kernel_size) for _ in range(half)]
        modules.append(
            LambdaLayer(
                dim=self.n_feats, dim_out=self.n_feats, r=23, dim_k=16, heads=4, dim_u=4
            )
        )
        modules.extend(common.ResBlock(self.n_feats, self.kernel_size) for _ in range(half))

        m_tail = [common.ResBlock(self.n_feats, self.kernel_size) for _ in range(3)]
        m_tail.append(
            common.default_conv(self.n_feats, self.out_channels, self.kernel_size)
        )

        self.head = nn.Sequential(*m_head)
        self.body = nn.Sequential(*modules)
        self.tail = nn.Sequential(*m_tail)
        self.gru = SepConvGRU(hidden_dim=self.n_feats, input_dim=self.n_feats)

    def forward(self, input):
        """Run ``self.recurrence`` recurrent refinement steps.

        Args:
            input: indexable whose first element is the image tensor fed to
                the head (presumably ``(N, 3, H, W)`` — TODO confirm against
                the data loader).

        Returns:
            list of ``self.recurrence`` tensors, one per iteration, each the
            tail output shifted back up by ``self.mean``.
        """
        feat = self.head(input[0] - self.mean)
        # Separate tensor for the recurrent state, initialised from the head.
        hidden = feat.clone()
        outputs = []
        for _ in range(self.recurrence):
            state = self.gru(hidden, feat)
            # Residual connection around the body.
            state = self.body(state) + state
            hidden = state
            # Evaluate the tail once per step. The previous version computed
            # it twice and discarded the first result.
            outputs.append(self.tail(state) + self.mean)
        return outputs