XiangZ commited on
Commit
09dc1a4
·
verified ·
1 Parent(s): 6131e4f

Upload 23 files

Browse files
README.md CHANGED
@@ -1,3 +1,78 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ tags:
3
+ - HiT-SR
4
+ - image super-resolution
5
+ - transformer
6
+ ---
7
+
8
+ <h1>
9
+ HiT-SR: Hierarchical Transformer <br> for Efficient Image Super-Resolution
10
+ </h1>
11
+
12
+ <h3><a href="https://github.com/XiangZ-0/HiT-SR">[Github]</a> | <a href="https://1drv.ms/b/c/de821e161e64ce08/EVsrOr1-PFFMsXxiRHEmKeoBSH6DPkTuN2GRmEYsl9bvDQ?e=f9wGUO">[Paper]</a> | <a href="https://1drv.ms/b/c/de821e161e64ce08/EYmRy-QOjPdFsMRT_ElKQqABYzoIIfDtkt9hofZ5YY_GjQ?e=2Iapqf">[Supp]</a> | <a href="https://www.youtube.com/watch?v=9rO0pjmmjZg">[Video]</a> | <a href="https://1drv.ms/f/c/de821e161e64ce08/EuE6xW-sN-hFgkIa6J-Y8gkB9b4vDQZQ01r1ZP1lmzM0vQ?e=aIRfCQ">[Visual Results]</a> </h3>
13
+ <div></div>
14
+
15
+ HiT-SR is a general strategy to improve transformer-based SR methods. We apply our HiT-SR approach to improve [SwinIR-Light](https://github.com/JingyunLiang/SwinIR), [SwinIR-NG](https://github.com/rami0205/NGramSwin) and [SRFormer-Light](https://github.com/HVision-NKU/SRFormer), corresponding to our HiT-SIR, HiT-SNG, and HiT-SRF. Compared with the original structure, our improved models achieve better SR performance while reducing computational burdens.
16
+
17
+ ## 🚀 Models
18
+ For each HiT-SR model, we provide 2x, 3x, 4x upscaling versions:
19
+ | Repo Name | | Model | | Upscale |
20
+ |-------------------|---|---------|---|---------|
21
+ | `XiangZ/hit-sir-2x` | | HiT-SIR | | 2x |
22
+ | `XiangZ/hit-sir-3x` | | HiT-SIR | | 3x |
23
+ | `XiangZ/hit-sir-4x` | | HiT-SIR | | 4x |
24
+ | `XiangZ/hit-sng-2x` | | HiT-SNG | | 2x |
25
+ | `XiangZ/hit-sng-3x` | | HiT-SNG | | 3x |
26
+ | `XiangZ/hit-sng-4x` | | HiT-SNG | | 4x |
27
+ | `XiangZ/hit-srf-2x` | | HiT-SNG | | 2x |
28
+ | `XiangZ/hit-srf-3x` | | HiT-SRF | | 3x |
29
+ | `XiangZ/hit-srf-4x` | | HiT-SRF | | 4x |
30
+
31
+
32
+ ## 🛠️ Setup
33
+ Install the dependencies under the working directory (use hit-srf-4x as an example):
34
+ ```
35
+ git clone https://huggingface.co/XiangZ/hit-srf-4x
36
+ cd hit-srf-4x
37
+ pip install -r requirements.txt
38
+ ```
39
+
40
+ ## 🚀 Usage
41
+
42
+ To test the model:
43
+ ```
44
+ from hit_sir_arch import HiT_SIR
45
+ from hit_sng_arch import HiT_SNG
46
+ from hit_srf_arch import HiT_SRF
47
+ import cv2
48
+
49
+ # use GPU (True) or CPU (False)
50
+ cuda_flag = True
51
+
52
+ # initialize model (change model and upscale according to your setting)
53
+ model = HiT_SRF(upscale=4)
54
+
55
+ # load model (change repo_name according to your setting)
56
+ repo_name = "XiangZ/hit-srf-4x"
57
+ model = model.from_pretrained(repo_name)
58
+ if cuda_flag:
59
+ model.cuda()
60
+
61
+ ## test and save results
62
+ image_path = "path-to-input-image"
63
+ sr_results = model.infer_image(image_path, cuda=cuda_flag)
64
+ cv2.imwrite("path-to-output-location", sr_results)
65
+ ```
66
+
67
+ ## 📎 Citation
68
+
69
+ If you find the code helpful in your research or work, please consider citing the following paper.
70
+
71
+ ```
72
+ @inproceedings{zhang2024hitsr,
73
+ title={HiT-SR: Hierarchical Transformer for Efficient Image Super-Resolution},
74
+ author={Zhang, Xiang and Zhang, Yulun and Yu, Fisher},
75
+ booktitle={ECCV},
76
+ year={2024}
77
+ }
78
+ ```
hit_sir_arch.py ADDED
@@ -0,0 +1,900 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ import numpy as np
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+ from utils import FileClient, imfrombytes, img2tensor, tensor2img
11
+
12
+ class DFE(nn.Module):
13
+ """ Dual Feature Extraction
14
+ Args:
15
+ in_features (int): Number of input channels.
16
+ out_features (int): Number of output channels.
17
+ """
18
+ def __init__(self, in_features, out_features):
19
+ super().__init__()
20
+
21
+ self.out_features = out_features
22
+
23
+ self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
24
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
25
+ nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
27
+ nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
28
+
29
+ self.linear = nn.Conv2d(in_features, out_features,1,1,0)
30
+
31
+ def forward(self, x, x_size):
32
+
33
+ B, L, C = x.shape
34
+ H, W = x_size
35
+ x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
36
+ x = self.conv(x) * self.linear(x)
37
+ x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
38
+
39
+ return x
40
+
41
+ class Mlp(nn.Module):
42
+ """ MLP-based Feed-Forward Network
43
+ Args:
44
+ in_features (int): Number of input channels.
45
+ hidden_features (int | None): Number of hidden channels. Default: None
46
+ out_features (int | None): Number of output channels. Default: None
47
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
48
+ drop (float): Dropout rate. Default: 0.0
49
+ """
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ def window_partition(x, window_size):
69
+ """
70
+ Args:
71
+ x: (B, H, W, C)
72
+ window_size (tuple): window size
73
+
74
+ Returns:
75
+ windows: (num_windows*B, window_size, window_size, C)
76
+ """
77
+ B, H, W, C = x.shape
78
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
79
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
80
+ return windows
81
+
82
+
83
+ def window_reverse(windows, window_size, H, W):
84
+ """
85
+ Args:
86
+ windows: (num_windows*B, window_size, window_size, C)
87
+ window_size (tuple): Window size
88
+ H (int): Height of image
89
+ W (int): Width of image
90
+
91
+ Returns:
92
+ x: (B, H, W, C)
93
+ """
94
+ B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
95
+ x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
96
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
97
+ return x
98
+
99
+ class DynamicPosBias(nn.Module):
100
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
101
+ """ Dynamic Relative Position Bias.
102
+ Args:
103
+ dim (int): Number of input channels.
104
+ num_heads (int): Number of heads for spatial self-correlation.
105
+ residual (bool): If True, use residual strage to connect conv.
106
+ """
107
+ def __init__(self, dim, num_heads, residual):
108
+ super().__init__()
109
+ self.residual = residual
110
+ self.num_heads = num_heads
111
+ self.pos_dim = dim // 4
112
+ self.pos_proj = nn.Linear(2, self.pos_dim)
113
+ self.pos1 = nn.Sequential(
114
+ nn.LayerNorm(self.pos_dim),
115
+ nn.ReLU(inplace=True),
116
+ nn.Linear(self.pos_dim, self.pos_dim),
117
+ )
118
+ self.pos2 = nn.Sequential(
119
+ nn.LayerNorm(self.pos_dim),
120
+ nn.ReLU(inplace=True),
121
+ nn.Linear(self.pos_dim, self.pos_dim)
122
+ )
123
+ self.pos3 = nn.Sequential(
124
+ nn.LayerNorm(self.pos_dim),
125
+ nn.ReLU(inplace=True),
126
+ nn.Linear(self.pos_dim, self.num_heads)
127
+ )
128
+ def forward(self, biases):
129
+ if self.residual:
130
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
131
+ pos = pos + self.pos1(pos)
132
+ pos = pos + self.pos2(pos)
133
+ pos = self.pos3(pos)
134
+ else:
135
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
136
+ return pos
137
+
138
+ class SCC(nn.Module):
139
+ """ Spatial-Channel Correlation.
140
+ Args:
141
+ dim (int): Number of input channels.
142
+ base_win_size (tuple[int]): The height and width of the base window.
143
+ window_size (tuple[int]): The height and width of the window.
144
+ num_heads (int): Number of heads for spatial self-correlation.
145
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
146
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
147
+ """
148
+
149
+ def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
150
+
151
+ super().__init__()
152
+ # parameters
153
+ self.dim = dim
154
+ self.window_size = window_size
155
+ self.num_heads = num_heads
156
+
157
+ # feature projection
158
+ self.qv = DFE(dim, dim)
159
+ self.proj = nn.Linear(dim, dim)
160
+
161
+ # dropout
162
+ self.value_drop = nn.Dropout(value_drop)
163
+ self.proj_drop = nn.Dropout(proj_drop)
164
+
165
+ # base window size
166
+ min_h = min(self.window_size[0], base_win_size[0])
167
+ min_w = min(self.window_size[1], base_win_size[1])
168
+ self.base_win_size = (min_h, min_w)
169
+
170
+ # normalization factor and spatial linear layer for S-SC
171
+ head_dim = dim // (2*num_heads)
172
+ self.scale = head_dim
173
+ self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
174
+
175
+ # define a parameter table of relative position bias
176
+ self.H_sp, self.W_sp = self.window_size
177
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
178
+
179
+ def spatial_linear_projection(self, x):
180
+ B, num_h, L, C = x.shape
181
+ H, W = self.window_size
182
+ map_H, map_W = self.base_win_size
183
+
184
+ x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
185
+ x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
186
+ return x
187
+
188
+ def spatial_self_correlation(self, q, v):
189
+
190
+ B, num_head, L, C = q.shape
191
+
192
+ # spatial projection
193
+ v = self.spatial_linear_projection(v)
194
+
195
+ # compute correlation map
196
+ corr_map = (q @ v.transpose(-2,-1)) / self.scale
197
+
198
+ # add relative position bias
199
+ # generate mother-set
200
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
201
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
202
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
203
+ rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
204
+ pos = self.pos(rpe_biases)
205
+
206
+ # select position bias
207
+ coords_h = torch.arange(self.H_sp, device=v.device)
208
+ coords_w = torch.arange(self.W_sp, device=v.device)
209
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
210
+ coords_flatten = torch.flatten(coords, 1)
211
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
212
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
213
+ relative_coords[:, :, 0] += self.H_sp - 1
214
+ relative_coords[:, :, 1] += self.W_sp - 1
215
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
216
+ relative_position_index = relative_coords.sum(-1)
217
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
218
+ self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
219
+ relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
220
+ self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
221
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
222
+ corr_map = corr_map + relative_position_bias.unsqueeze(0)
223
+
224
+ # transformation
225
+ v_drop = self.value_drop(v)
226
+ x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
227
+
228
+ return x
229
+
230
+ def channel_self_correlation(self, q, v):
231
+
232
+ B, num_head, L, C = q.shape
233
+
234
+ # apply single head strategy
235
+ q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
236
+ v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
237
+
238
+ # compute correlation map
239
+ corr_map = (q.transpose(-2,-1) @ v) / L
240
+
241
+ # transformation
242
+ v_drop = self.value_drop(v)
243
+ x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
244
+
245
+ return x
246
+
247
+ def forward(self, x):
248
+ """
249
+ Args:
250
+ x: input features with shape of (B, H, W, C)
251
+ """
252
+ xB,xH,xW,xC = x.shape
253
+ qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
254
+
255
+ # window partition
256
+ qv = window_partition(qv, self.window_size)
257
+ qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
258
+
259
+ # qv splitting
260
+ B, L, C = qv.shape
261
+ qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
262
+ q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
263
+
264
+ # spatial self-correlation (S-SC)
265
+ x_spatial = self.spatial_self_correlation(q, v)
266
+ x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
267
+ x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
268
+
269
+ # channel self-correlation (C-SC)
270
+ x_channel = self.channel_self_correlation(q, v)
271
+ x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
272
+ x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
273
+
274
+ # spatial-channel information fusion
275
+ x = torch.cat([x_spatial, x_channel], -1)
276
+ x = self.proj_drop(self.proj(x))
277
+
278
+ return x
279
+
280
+ def extra_repr(self) -> str:
281
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
282
+
283
+
284
+ class HierarchicalTransformerBlock(nn.Module):
285
+ """ Hierarchical Transformer Block.
286
+ Args:
287
+ dim (int): Number of input channels.
288
+ input_resolution (tuple[int]): Input resulotion.
289
+ num_heads (int): Number of heads for spatial self-correlation.
290
+ base_win_size (tuple[int]): The height and width of the base window.
291
+ window_size (tuple[int]): The height and width of the window.
292
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
293
+ drop (float, optional): Dropout rate. Default: 0.0
294
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
295
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
296
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
297
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
298
+ """
299
+
300
+ def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
301
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
302
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
303
+ super().__init__()
304
+ self.dim = dim
305
+ self.input_resolution = input_resolution
306
+ self.num_heads = num_heads
307
+ self.window_size = window_size
308
+ self.mlp_ratio = mlp_ratio
309
+
310
+ # check window size
311
+ if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
312
+ assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
313
+ assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
314
+
315
+
316
+ self.norm1 = norm_layer(dim)
317
+ self.correlation = SCC(
318
+ dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
319
+ value_drop=value_drop, proj_drop=drop)
320
+
321
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
322
+ self.norm2 = norm_layer(dim)
323
+ mlp_hidden_dim = int(dim * mlp_ratio)
324
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
325
+
326
+ def check_image_size(self, x, win_size):
327
+ x = x.permute(0,3,1,2).contiguous()
328
+ _, _, h, w = x.size()
329
+ mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
330
+ mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
331
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
332
+ x = x.permute(0,2,3,1).contiguous()
333
+ return x
334
+
335
+ def forward(self, x, x_size, win_size):
336
+ H, W = x_size
337
+ B, L, C = x.shape
338
+
339
+ shortcut = x
340
+ x = x.view(B, H, W, C)
341
+
342
+ # padding
343
+ x = self.check_image_size(x, win_size)
344
+ _, H_pad, W_pad, _ = x.shape # shape after padding
345
+
346
+ x = self.correlation(x)
347
+
348
+ # unpad
349
+ x = x[:, :H, :W, :].contiguous()
350
+
351
+ # norm
352
+ x = x.view(B, H * W, C)
353
+ x = self.norm1(x)
354
+
355
+ # FFN
356
+ x = shortcut + self.drop_path(x)
357
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
358
+
359
+ return x
360
+
361
+ def extra_repr(self) -> str:
362
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
363
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
364
+
365
+
366
+ class PatchMerging(nn.Module):
367
+ """ Patch Merging Layer.
368
+ Args:
369
+ input_resolution (tuple[int]): Resolution of input feature.
370
+ dim (int): Number of input channels.
371
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
372
+ """
373
+
374
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
375
+ super().__init__()
376
+ self.input_resolution = input_resolution
377
+ self.dim = dim
378
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
379
+ self.norm = norm_layer(4 * dim)
380
+
381
+ def forward(self, x):
382
+ """
383
+ x: B, H*W, C
384
+ """
385
+ H, W = self.input_resolution
386
+ B, L, C = x.shape
387
+ assert L == H * W, "input feature has wrong size"
388
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
389
+
390
+ x = x.view(B, H, W, C)
391
+
392
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
393
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
394
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
395
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
396
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
397
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
398
+
399
+ x = self.norm(x)
400
+ x = self.reduction(x)
401
+
402
+ return x
403
+
404
+ def extra_repr(self) -> str:
405
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
406
+
407
+
408
+ class BasicLayer(nn.Module):
409
+ """ A basic Hierarchical Transformer layer for one stage.
410
+
411
+ Args:
412
+ dim (int): Number of input channels.
413
+ input_resolution (tuple[int]): Input resolution.
414
+ depth (int): Number of blocks.
415
+ num_heads (int): Number of heads for spatial self-correlation.
416
+ base_win_size (tuple[int]): The height and width of the base window.
417
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
418
+ drop (float, optional): Dropout rate. Default: 0.0
419
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
420
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
421
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
422
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
423
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
424
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
425
+ """
426
+
427
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
428
+ mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
429
+ downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
430
+
431
+ super().__init__()
432
+ self.dim = dim
433
+ self.input_resolution = input_resolution
434
+ self.depth = depth
435
+ self.use_checkpoint = use_checkpoint
436
+
437
+ self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
438
+ self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
439
+
440
+ # build blocks
441
+ self.blocks = nn.ModuleList([
442
+ HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
443
+ num_heads=num_heads,
444
+ base_win_size=base_win_size,
445
+ window_size=(self.win_hs[i], self.win_ws[i]),
446
+ mlp_ratio=mlp_ratio,
447
+ drop=drop, value_drop=value_drop,
448
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
449
+ norm_layer=norm_layer)
450
+ for i in range(depth)])
451
+
452
+ # patch merging layer
453
+ if downsample is not None:
454
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
455
+ else:
456
+ self.downsample = None
457
+
458
+ def forward(self, x, x_size):
459
+
460
+ i = 0
461
+ for blk in self.blocks:
462
+ if self.use_checkpoint:
463
+ x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
464
+ else:
465
+ x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
466
+ i = i + 1
467
+
468
+ if self.downsample is not None:
469
+ x = self.downsample(x)
470
+ return x
471
+
472
+ def extra_repr(self) -> str:
473
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
474
+
475
+
476
+ class RHTB(nn.Module):
477
+ """Residual Hierarchical Transformer Block (RHTB).
478
+ Args:
479
+ dim (int): Number of input channels.
480
+ input_resolution (tuple[int]): Input resolution.
481
+ depth (int): Number of blocks.
482
+ num_heads (int): Number of heads for spatial self-correlation.
483
+ base_win_size (tuple[int]): The height and width of the base window.
484
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
485
+ drop (float, optional): Dropout rate. Default: 0.0
486
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
487
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
488
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
489
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
490
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
491
+ img_size: Input image size.
492
+ patch_size: Patch size.
493
+ resi_connection: The convolutional block before residual connection.
494
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
495
+ """
496
+
497
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
498
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
499
+ downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
500
+ resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
501
+ super(RHTB, self).__init__()
502
+
503
+ self.dim = dim
504
+ self.input_resolution = input_resolution
505
+
506
+ self.residual_group = BasicLayer(dim=dim,
507
+ input_resolution=input_resolution,
508
+ depth=depth,
509
+ num_heads=num_heads,
510
+ base_win_size=base_win_size,
511
+ mlp_ratio=mlp_ratio,
512
+ drop=drop, value_drop=value_drop,
513
+ drop_path=drop_path,
514
+ norm_layer=norm_layer,
515
+ downsample=downsample,
516
+ use_checkpoint=use_checkpoint,
517
+ hier_win_ratios=hier_win_ratios)
518
+
519
+ if resi_connection == '1conv':
520
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
521
+ elif resi_connection == '3conv':
522
+ # to save parameters and memory
523
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
524
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
525
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
526
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
527
+
528
+ self.patch_embed = PatchEmbed(
529
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
530
+ norm_layer=None)
531
+
532
+ self.patch_unembed = PatchUnEmbed(
533
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
534
+ norm_layer=None)
535
+
536
+ def forward(self, x, x_size):
537
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
538
+
539
+
540
+ class PatchEmbed(nn.Module):
541
+ r""" Image to Patch Embedding
542
+
543
+ Args:
544
+ img_size (int): Image size. Default: 224.
545
+ patch_size (int): Patch token size. Default: 4.
546
+ in_chans (int): Number of input image channels. Default: 3.
547
+ embed_dim (int): Number of linear projection output channels. Default: 96.
548
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
549
+ """
550
+
551
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
552
+ super().__init__()
553
+ img_size = to_2tuple(img_size)
554
+ patch_size = to_2tuple(patch_size)
555
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
556
+ self.img_size = img_size
557
+ self.patch_size = patch_size
558
+ self.patches_resolution = patches_resolution
559
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
560
+
561
+ self.in_chans = in_chans
562
+ self.embed_dim = embed_dim
563
+
564
+ if norm_layer is not None:
565
+ self.norm = norm_layer(embed_dim)
566
+ else:
567
+ self.norm = None
568
+
569
+ def forward(self, x):
570
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
571
+ if self.norm is not None:
572
+ x = self.norm(x)
573
+ return x
574
+
575
+
576
+ class PatchUnEmbed(nn.Module):
577
+ r""" Image to Patch Unembedding
578
+
579
+ Args:
580
+ img_size (int): Image size. Default: 224.
581
+ patch_size (int): Patch token size. Default: 4.
582
+ in_chans (int): Number of input image channels. Default: 3.
583
+ embed_dim (int): Number of linear projection output channels. Default: 96.
584
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
585
+ """
586
+
587
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
588
+ super().__init__()
589
+ img_size = to_2tuple(img_size)
590
+ patch_size = to_2tuple(patch_size)
591
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
592
+ self.img_size = img_size
593
+ self.patch_size = patch_size
594
+ self.patches_resolution = patches_resolution
595
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
596
+
597
+ self.in_chans = in_chans
598
+ self.embed_dim = embed_dim
599
+
600
+ def forward(self, x, x_size):
601
+ B, HW, C = x.shape
602
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
603
+ return x
604
+
605
+
606
+ class Upsample(nn.Sequential):
607
+ """Upsample module.
608
+
609
+ Args:
610
+ scale (int): Scale factor. Supported scales: 2^n and 3.
611
+ num_feat (int): Channel number of intermediate features.
612
+ """
613
+
614
+ def __init__(self, scale, num_feat):
615
+ m = []
616
+ if (scale & (scale - 1)) == 0: # scale = 2^n
617
+ for _ in range(int(math.log(scale, 2))):
618
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
619
+ m.append(nn.PixelShuffle(2))
620
+ elif scale == 3:
621
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
622
+ m.append(nn.PixelShuffle(3))
623
+ else:
624
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
625
+ super(Upsample, self).__init__(*m)
626
+
627
+
628
+ class UpsampleOneStep(nn.Sequential):
629
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
630
+ Used in lightweight SR to save parameters.
631
+
632
+ Args:
633
+ scale (int): Scale factor. Supported scales: 2^n and 3.
634
+ num_feat (int): Channel number of intermediate features.
635
+
636
+ """
637
+
638
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
639
+ self.num_feat = num_feat
640
+ self.input_resolution = input_resolution
641
+ m = []
642
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
643
+ m.append(nn.PixelShuffle(scale))
644
+ super(UpsampleOneStep, self).__init__(*m)
645
+
646
+
647
+ class HiT_SIR(nn.Module, PyTorchModelHubMixin):
648
+ """ HiT-SIR network.
649
+
650
+ Args:
651
+ img_size (int | tuple(int)): Input image size. Default 64
652
+ patch_size (int | tuple(int)): Patch size. Default: 1
653
+ in_chans (int): Number of input image channels. Default: 3
654
+ embed_dim (int): Patch embedding dimension. Default: 96
655
+ depths (tuple(int)): Depth of each Transformer block.
656
+ num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
657
+ base_win_size (tuple[int]): The height and width of the base window.
658
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
659
+ drop_rate (float): Dropout rate. Default: 0
660
+ value_drop_rate (float): Dropout ratio of value. Default: 0.0
661
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
662
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
663
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
664
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
665
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
666
+ upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
667
+ img_range (float): Image range. 1. or 255.
668
+ upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
669
+ resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
670
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
671
+ """
672
+
673
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
674
+ embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
675
+ base_win_size=[8,8], mlp_ratio=2.,
676
+ drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
677
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
678
+ use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
679
+ hier_win_ratios=[0.5,1,2,4,6,8],
680
+ **kwargs):
681
+ super(HiT_SIR, self).__init__()
682
+ num_in_ch = in_chans
683
+ num_out_ch = in_chans
684
+ num_feat = 64
685
+ self.img_range = img_range
686
+ if in_chans == 3:
687
+ rgb_mean = (0.4488, 0.4371, 0.4040)
688
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
689
+ else:
690
+ self.mean = torch.zeros(1, 1, 1, 1)
691
+ self.upscale = upscale
692
+ self.upsampler = upsampler
693
+ self.base_win_size = base_win_size
694
+
695
+ #####################################################################################################
696
+ ################################### 1, shallow feature extraction ###################################
697
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
698
+
699
+ #####################################################################################################
700
+ ################################### 2, deep feature extraction ######################################
701
+ self.num_layers = len(depths)
702
+ self.embed_dim = embed_dim
703
+ self.ape = ape
704
+ self.patch_norm = patch_norm
705
+ self.num_features = embed_dim
706
+ self.mlp_ratio = mlp_ratio
707
+
708
+ # split image into non-overlapping patches
709
+ self.patch_embed = PatchEmbed(
710
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
711
+ norm_layer=norm_layer if self.patch_norm else None)
712
+ num_patches = self.patch_embed.num_patches
713
+ patches_resolution = self.patch_embed.patches_resolution
714
+ self.patches_resolution = patches_resolution
715
+
716
+ # merge non-overlapping patches into image
717
+ self.patch_unembed = PatchUnEmbed(
718
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
719
+ norm_layer=norm_layer if self.patch_norm else None)
720
+
721
+ # absolute position embedding
722
+ if self.ape:
723
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
724
+ trunc_normal_(self.absolute_pos_embed, std=.02)
725
+
726
+ self.pos_drop = nn.Dropout(p=drop_rate)
727
+
728
+ # stochastic depth
729
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
730
+
731
+ # build Residual Hierarchical Transformer blocks (RHTB)
732
+ self.layers = nn.ModuleList()
733
+ for i_layer in range(self.num_layers):
734
+ layer = RHTB(dim=embed_dim,
735
+ input_resolution=(patches_resolution[0],
736
+ patches_resolution[1]),
737
+ depth=depths[i_layer],
738
+ num_heads=num_heads[i_layer],
739
+ base_win_size=base_win_size,
740
+ mlp_ratio=self.mlp_ratio,
741
+ drop=drop_rate, value_drop=value_drop_rate,
742
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
743
+ norm_layer=norm_layer,
744
+ downsample=None,
745
+ use_checkpoint=use_checkpoint,
746
+ img_size=img_size,
747
+ patch_size=patch_size,
748
+ resi_connection=resi_connection,
749
+ hier_win_ratios=hier_win_ratios
750
+ )
751
+ self.layers.append(layer)
752
+ self.norm = norm_layer(self.num_features)
753
+
754
+ # build the last conv layer in deep feature extraction
755
+ if resi_connection == '1conv':
756
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
757
+ elif resi_connection == '3conv':
758
+ # to save parameters and memory
759
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
760
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
761
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
762
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
763
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
764
+
765
+ #####################################################################################################
766
+ ################################ 3, high quality image reconstruction ################################
767
+ if self.upsampler == 'pixelshuffle':
768
+ # for classical SR
769
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
770
+ nn.LeakyReLU(inplace=True))
771
+ self.upsample = Upsample(upscale, num_feat)
772
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
773
+ elif self.upsampler == 'pixelshuffledirect':
774
+ # for lightweight SR (to save parameters)
775
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
776
+ (patches_resolution[0], patches_resolution[1]))
777
+ elif self.upsampler == 'nearest+conv':
778
+ # for real-world SR (less artifacts)
779
+ assert self.upscale == 4, 'only support x4 now.'
780
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
781
+ nn.LeakyReLU(inplace=True))
782
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
783
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
784
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
785
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
786
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
787
+ else:
788
+ # for image denoising and JPEG compression artifact reduction
789
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
790
+
791
+ self.apply(self._init_weights)
792
+
793
+ def _init_weights(self, m):
794
+ if isinstance(m, nn.Linear):
795
+ trunc_normal_(m.weight, std=.02)
796
+ if isinstance(m, nn.Linear) and m.bias is not None:
797
+ nn.init.constant_(m.bias, 0)
798
+ elif isinstance(m, nn.LayerNorm):
799
+ nn.init.constant_(m.bias, 0)
800
+ nn.init.constant_(m.weight, 1.0)
801
+
802
+ @torch.jit.ignore
803
+ def no_weight_decay(self):
804
+ return {'absolute_pos_embed'}
805
+
806
+ @torch.jit.ignore
807
+ def no_weight_decay_keywords(self):
808
+ return {'relative_position_bias_table'}
809
+
810
+
811
+ def forward_features(self, x):
812
+ x_size = (x.shape[2], x.shape[3])
813
+ x = self.patch_embed(x)
814
+ if self.ape:
815
+ x = x + self.absolute_pos_embed
816
+ x = self.pos_drop(x)
817
+
818
+ for layer in self.layers:
819
+ x = layer(x, x_size)
820
+
821
+ x = self.norm(x) # B L C
822
+ x = self.patch_unembed(x, x_size)
823
+
824
+ return x
825
+
826
+ def infer_image(self, image_path, cuda=True):
827
+
828
+ io_backend_opt = {'type':'disk'}
829
+ self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
830
+
831
+ # load lq image
832
+ lq_path = image_path
833
+ img_bytes = self.file_client.get(lq_path, 'lq')
834
+ img_lq = imfrombytes(img_bytes, float32=True)
835
+
836
+ # BGR to RGB, HWC to CHW, numpy to tensor
837
+ x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...]
838
+
839
+ if cuda:
840
+ x= x.cuda()
841
+
842
+ out = self(x)
843
+
844
+ if cuda:
845
+ out = out.cpu()
846
+
847
+ out = tensor2img(out)
848
+
849
+ return out
850
+
851
+ def forward(self, x):
852
+ H, W = x.shape[2:]
853
+
854
+ self.mean = self.mean.type_as(x)
855
+ x = (x - self.mean) * self.img_range
856
+
857
+ if self.upsampler == 'pixelshuffle':
858
+ # for classical SR
859
+ x = self.conv_first(x)
860
+ x = self.conv_after_body(self.forward_features(x)) + x
861
+ x = self.conv_before_upsample(x)
862
+ x = self.conv_last(self.upsample(x))
863
+ elif self.upsampler == 'pixelshuffledirect':
864
+ # for lightweight SR
865
+ x = self.conv_first(x)
866
+ x = self.conv_after_body(self.forward_features(x)) + x
867
+ x = self.upsample(x)
868
+ elif self.upsampler == 'nearest+conv':
869
+ # for real-world SR
870
+ x = self.conv_first(x)
871
+ x = self.conv_after_body(self.forward_features(x)) + x
872
+ x = self.conv_before_upsample(x)
873
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
874
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
875
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
876
+ else:
877
+ # for image denoising and JPEG compression artifact reduction
878
+ x_first = self.conv_first(x)
879
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
880
+ x = x + self.conv_last(res)
881
+
882
+ x = x / self.img_range + self.mean
883
+
884
+ return x[:, :, :H*self.upscale, :W*self.upscale]
885
+
886
+
887
+ if __name__ == '__main__':
888
+ upscale = 4
889
+ base_win_size = [8, 8]
890
+ height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
891
+ width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
892
+
893
+ ## HiT-SIR
894
+ model = HiT_SIR(upscale=4, img_size=(height, width),
895
+ base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
896
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
897
+
898
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
899
+ print("params: ", params_num)
900
+
hit_sng_arch.py ADDED
@@ -0,0 +1,1132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_, _assert
7
+ from torchvision.transforms import functional as TF
8
+ from timm.models.fx_features import register_notrace_function
9
+
10
+ import numpy as np
11
+ from einops import rearrange
12
+ from huggingface_hub import PyTorchModelHubMixin
13
+ from utils import FileClient, imfrombytes, img2tensor, tensor2img
14
+
15
+ class DFE(nn.Module):
16
+ """ Dual Feature Extraction
17
+ Args:
18
+ in_features (int): Number of input channels.
19
+ out_features (int): Number of output channels.
20
+ """
21
+ def __init__(self, in_features, out_features):
22
+ super().__init__()
23
+
24
+ self.out_features = out_features
25
+
26
+ self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
27
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
28
+ nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
29
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
30
+ nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
31
+
32
+ self.linear = nn.Conv2d(in_features, out_features,1,1,0)
33
+
34
+ def forward(self, x, x_size):
35
+
36
+ B, L, C = x.shape
37
+ H, W = x_size
38
+ x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
39
+ x = self.conv(x) * self.linear(x)
40
+ x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
41
+
42
+ return x
43
+
44
+ class Mlp(nn.Module):
45
+ """ MLP-based Feed-Forward Network
46
+ Args:
47
+ in_features (int): Number of input channels.
48
+ hidden_features (int | None): Number of hidden channels. Default: None
49
+ out_features (int | None): Number of output channels. Default: None
50
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
51
+ drop (float): Dropout rate. Default: 0.0
52
+ """
53
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
54
+ super().__init__()
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ self.fc1 = nn.Linear(in_features, hidden_features)
58
+ self.act = act_layer()
59
+ self.fc2 = nn.Linear(hidden_features, out_features)
60
+ self.drop = nn.Dropout(drop)
61
+
62
+ def forward(self, x):
63
+ x = self.fc1(x)
64
+ x = self.act(x)
65
+ x = self.drop(x)
66
+ x = self.fc2(x)
67
+ x = self.drop(x)
68
+ return x
69
+
70
+
71
+ def window_partition(x, window_size):
72
+ """
73
+ Args:
74
+ x: (B, H, W, C)
75
+ window_size (int): window size
76
+
77
+ Returns:
78
+ windows: (num_windows*B, window_size, window_size, C)
79
+ """
80
+ B, H, W, C = x.shape
81
+ wh, ww = H//window_size, W//window_size
82
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
83
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
84
+ return windows, (wh, ww)
85
+
86
+ @register_notrace_function # reason: int argument is a Proxy
87
+ def window_unpartition(windows, num_windows):
88
+ """
89
+ Args:
90
+ windows: [B*wh*ww, WH, WW, D]
91
+ num_windows (tuple[int]): The height and width of the window.
92
+ Returns:
93
+ x: [B, ph, pw, D]
94
+ """
95
+ x = rearrange(windows, '(p h w) wh ww c -> p (h wh) (w ww) c', h=num_windows[0], w=num_windows[1])
96
+ return x.contiguous()
97
+
98
+ def window_reverse(windows, window_size, H, W):
99
+ """
100
+ Args:
101
+ windows: (num_windows*B, window_size, window_size, C)
102
+ window_size (tuple): Window size
103
+ H (int): Height of image
104
+ W (int): Width of image
105
+
106
+ Returns:
107
+ x: (B, H, W, C)
108
+ """
109
+ B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
110
+ x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
111
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
112
+ return x
113
+
114
+ class DynamicPosBias(nn.Module):
115
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
116
+ """ Dynamic Relative Position Bias.
117
+ Args:
118
+ dim (int): Number of input channels.
119
+ num_heads (int): Number of heads for spatial self-correlation.
120
+ residual (bool): If True, use residual strage to connect conv.
121
+ """
122
+ def __init__(self, dim, num_heads, residual):
123
+ super().__init__()
124
+ self.residual = residual
125
+ self.num_heads = num_heads
126
+ self.pos_dim = dim // 4
127
+ self.pos_proj = nn.Linear(2, self.pos_dim)
128
+ self.pos1 = nn.Sequential(
129
+ nn.LayerNorm(self.pos_dim),
130
+ nn.ReLU(inplace=True),
131
+ nn.Linear(self.pos_dim, self.pos_dim),
132
+ )
133
+ self.pos2 = nn.Sequential(
134
+ nn.LayerNorm(self.pos_dim),
135
+ nn.ReLU(inplace=True),
136
+ nn.Linear(self.pos_dim, self.pos_dim)
137
+ )
138
+ self.pos3 = nn.Sequential(
139
+ nn.LayerNorm(self.pos_dim),
140
+ nn.ReLU(inplace=True),
141
+ nn.Linear(self.pos_dim, self.num_heads)
142
+ )
143
+ def forward(self, biases):
144
+ if self.residual:
145
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
146
+ pos = pos + self.pos1(pos)
147
+ pos = pos + self.pos2(pos)
148
+ pos = self.pos3(pos)
149
+ else:
150
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
151
+ return pos
152
+
153
+ class SCC(nn.Module):
154
+ """ Spatial-Channel Correlation.
155
+ Args:
156
+ dim (int): Number of input channels.
157
+ base_win_size (tuple[int]): The height and width of the base window.
158
+ window_size (tuple[int]): The height and width of the window.
159
+ num_heads (int): Number of heads for spatial self-correlation.
160
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
161
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
162
+ """
163
+
164
+ def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
165
+
166
+ super().__init__()
167
+ # parameters
168
+ self.dim = dim
169
+ self.window_size = window_size
170
+ self.num_heads = num_heads
171
+
172
+ # feature projection
173
+ head_dim = dim // (2*num_heads)
174
+ if dim % (2*num_heads) > 0:
175
+ head_dim = head_dim + 1
176
+ self.attn_dim = head_dim * 2 * num_heads
177
+ self.qv = DFE(dim, self.attn_dim)
178
+ self.proj = nn.Linear(self.attn_dim, dim)
179
+
180
+ # dropout
181
+ self.value_drop = nn.Dropout(value_drop)
182
+ self.proj_drop = nn.Dropout(proj_drop)
183
+
184
+ # base window size
185
+ min_h = min(self.window_size[0], base_win_size[0])
186
+ min_w = min(self.window_size[1], base_win_size[1])
187
+ self.base_win_size = (min_h, min_w)
188
+
189
+ # normalization factor and spatial linear layer for S-SC
190
+ self.scale = head_dim
191
+ self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
192
+
193
+ # NGram window partition without shifting
194
+ self.ngram_window_partition = NGramWindowPartition(dim, window_size, 2, num_heads, shift_size=0)
195
+
196
+ # define a parameter table of relative position bias
197
+ self.H_sp, self.W_sp = self.window_size
198
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
199
+
200
+ def spatial_linear_projection(self, x):
201
+ B, num_h, L, C = x.shape
202
+ H, W = self.window_size
203
+ map_H, map_W = self.base_win_size
204
+
205
+ x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
206
+ x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
207
+ return x
208
+
209
+ def spatial_self_correlation(self, q, v):
210
+
211
+ B, num_head, L, C = q.shape
212
+
213
+ # spatial projection
214
+ v = self.spatial_linear_projection(v)
215
+
216
+ # compute correlation map
217
+ corr_map = (q @ v.transpose(-2,-1)) / self.scale
218
+
219
+ # add relative position bias
220
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
221
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
222
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
223
+ rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
224
+ pos = self.pos(rpe_biases)
225
+
226
+ # select position bias
227
+ coords_h = torch.arange(self.H_sp, device=v.device)
228
+ coords_w = torch.arange(self.W_sp, device=v.device)
229
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
230
+ coords_flatten = torch.flatten(coords, 1)
231
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
232
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
233
+ relative_coords[:, :, 0] += self.H_sp - 1
234
+ relative_coords[:, :, 1] += self.W_sp - 1
235
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
236
+ relative_position_index = relative_coords.sum(-1)
237
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
238
+ self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
239
+ relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
240
+ self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
241
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
242
+ corr_map = corr_map + relative_position_bias.unsqueeze(0)
243
+
244
+ # transformation
245
+ v_drop = self.value_drop(v)
246
+ x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
247
+
248
+ return x
249
+
250
+ def channel_self_correlation(self, q, v):
251
+
252
+ B, num_head, L, C = q.shape
253
+
254
+ # apply single head strategy
255
+ q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
256
+ v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
257
+
258
+ # compute correlation map
259
+ corr_map = (q.transpose(-2,-1) @ v) / L
260
+
261
+ # transformation
262
+ v_drop = self.value_drop(v)
263
+ x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
264
+
265
+ return x
266
+
267
+ def forward(self, x):
268
+ """
269
+ Args:
270
+ x: input features with shape of (B, H, W, C)
271
+ """
272
+ xB,xH,xW,xC = x.shape
273
+ qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
274
+
275
+ # window partition
276
+ qv = self.ngram_window_partition(qv)
277
+ qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
278
+
279
+ # qv splitting
280
+ B, L, C = qv.shape
281
+ qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
282
+ q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
283
+
284
+ # spatial self-correlation (S-SC)
285
+ x_spatial = self.spatial_self_correlation(q, v)
286
+ x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
287
+ x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
288
+
289
+ # channel self-correlation (C-SC)
290
+ x_channel = self.channel_self_correlation(q, v)
291
+ x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
292
+ x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
293
+
294
+ # spatial-channel information fusion
295
+ x = torch.cat([x_spatial, x_channel], -1)
296
+ x = self.proj_drop(self.proj(x))
297
+
298
+ return x
299
+
300
+ def extra_repr(self) -> str:
301
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
302
+
303
+ class NGramWindowAttention(nn.Module):
304
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias for NGram attention.
305
+ It supports both of shifted and non-shifted window.
306
+
307
+ Args:
308
+ dim (int): Number of input channels.
309
+ window_size (tuple[int]): The height and width of the window.
310
+ num_heads (int): Number of attention heads.
311
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
312
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
313
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
314
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
315
+ """
316
+
317
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
318
+
319
+ super().__init__()
320
+ self.dim = dim
321
+ self.window_size = window_size # Wh, Ww
322
+ self.num_heads = num_heads
323
+ head_dim = dim // num_heads
324
+ self.scale = qk_scale or head_dim ** -0.5
325
+
326
+ # define a parameter table of relative position bias
327
+ self.relative_position_bias_table = nn.Parameter(
328
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
329
+
330
+ # get pair-wise relative position index for each token inside the window
331
+ coords_h = torch.arange(self.window_size[0])
332
+ coords_w = torch.arange(self.window_size[1])
333
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
334
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
335
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
336
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
337
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
338
+ relative_coords[:, :, 1] += self.window_size[1] - 1
339
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
340
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
341
+ self.register_buffer("relative_position_index", relative_position_index)
342
+
343
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
344
+ self.attn_drop = nn.Dropout(attn_drop)
345
+ self.proj = nn.Linear(dim, dim)
346
+
347
+ self.proj_drop = nn.Dropout(proj_drop)
348
+
349
+ trunc_normal_(self.relative_position_bias_table, std=.02)
350
+ self.softmax = nn.Softmax(dim=-1)
351
+
352
+ def forward(self, x, mask=None):
353
+ """
354
+ Args:
355
+ x: input features with shape of (num_windows*B, N, C)
356
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
357
+ """
358
+ B_, N, C = x.shape
359
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
360
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
361
+
362
+ q = q * self.scale
363
+ attn = (q @ k.transpose(-2, -1))
364
+
365
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
366
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
367
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
368
+ attn = attn + relative_position_bias.unsqueeze(0)
369
+
370
+ if mask is not None:
371
+ nW = mask.shape[0]
372
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
373
+ attn = attn.view(-1, self.num_heads, N, N)
374
+ attn = self.softmax(attn)
375
+ else:
376
+ attn = self.softmax(attn)
377
+
378
+ attn = self.attn_drop(attn)
379
+
380
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
381
+ x = self.proj(x)
382
+ x = self.proj_drop(x)
383
+ return x
384
+
385
+ def extra_repr(self) -> str:
386
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
387
+
388
+ class NGramContext(nn.Module):
389
+ '''
390
+ Args:
391
+ dim (int): Number of input channels.
392
+ window_size (int or tuple[int]): The height and width of the window.
393
+ ngram (int): How much windows(or patches) to see.
394
+ ngram_num_heads (int):
395
+ padding_mode (str, optional): How to pad. Default: seq_refl_win_pad
396
+ Options: ['seq_refl_win_pad', 'zero_pad']
397
+ '''
398
+ def __init__(self, dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad'):
399
+ super(NGramContext, self).__init__()
400
+ _assert(padding_mode in ['seq_refl_win_pad', 'zero_pad'], "padding mode should be 'seq_refl_win_pad' or 'zero_pad'!!")
401
+
402
+ self.dim = dim
403
+ self.window_size = to_2tuple(window_size)
404
+ self.ngram = ngram
405
+ self.padding_mode = padding_mode
406
+
407
+ # to alleviate parameter expansion with window sizes
408
+ self.unigram_embed = nn.Conv2d(2, 1,
409
+ kernel_size=(self.window_size[0], self.window_size[1]),
410
+ stride=self.window_size, padding=0, groups=1)
411
+
412
+ self.ngram_attn = NGramWindowAttention(dim=dim//2, num_heads=ngram_num_heads, window_size=(ngram, ngram))
413
+ self.avg_pool = nn.AvgPool2d(ngram)
414
+ self.merge = nn.Conv2d(dim, dim, 1, 1, 0)
415
+
416
+ def seq_refl_win_pad(self, x, back=False):
417
+ if self.ngram == 1: return x
418
+ x = TF.pad(x, (0,0,self.ngram-1,self.ngram-1)) if not back else TF.pad(x, (self.ngram-1,self.ngram-1,0,0))
419
+ if self.padding_mode == 'zero_pad':
420
+ return x
421
+ if not back:
422
+ (start_h, start_w), (end_h, end_w) = to_2tuple(-2*self.ngram+1), to_2tuple(-self.ngram)
423
+ # pad lower
424
+ x[:,:,-(self.ngram-1):,:] = x[:,:,start_h:end_h,:]
425
+ # pad right
426
+ x[:,:,:,-(self.ngram-1):] = x[:,:,:,start_w:end_w]
427
+ else:
428
+ (start_h, start_w), (end_h, end_w) = to_2tuple(self.ngram), to_2tuple(2*self.ngram-1)
429
+ # pad upper
430
+ x[:,:,:self.ngram-1,:] = x[:,:,start_h:end_h,:]
431
+ # pad left
432
+ x[:,:,:,:self.ngram-1] = x[:,:,:,start_w:end_w]
433
+
434
+ return x
435
+
436
+ def sliding_window_attention(self, unigram):
437
+ slide = unigram.unfold(3, self.ngram, 1).unfold(2, self.ngram, 1)
438
+ slide = rearrange(slide, 'b c h w ww hh -> b (h hh) (w ww) c') # [B, 2(wh+ngram-2), 2(ww+ngram-2), D/2]
439
+ slide, num_windows = window_partition(slide, self.ngram) # [B*wh*ww, ngram, ngram, D/2], (wh, ww)
440
+ slide = slide.view(-1, self.ngram*self.ngram, self.dim//2) # [B*wh*ww, ngram*ngram, D/2]
441
+
442
+ context = self.ngram_attn(slide).view(-1, self.ngram, self.ngram, self.dim//2) # [B*wh*ww, ngram, ngram, D/2]
443
+
444
+ context = window_unpartition(context, num_windows) # [B, wh*ngram, ww*ngram, D/2]
445
+ context = rearrange(context, 'b h w d -> b d h w') # [B, D/2, wh*ngram, ww*ngram]
446
+ context = self.avg_pool(context) # [B, D/2, wh, ww]
447
+ return context
448
+
449
+ def forward(self, x):
450
+ B, ph, pw, D = x.size()
451
+ x = rearrange(x, 'b ph pw d -> b d ph pw') # [B, D, ph, pw]
452
+ x = x.contiguous().view(B*(D//2),2,ph,pw)
453
+ unigram = self.unigram_embed(x).view(B, D//2, ph//self.window_size[0], pw//self.window_size[1])
454
+
455
+ unigram_forward_pad = self.seq_refl_win_pad(unigram, False) # [B, D/2, wh+ngram-1, ww+ngram-1]
456
+ unigram_backward_pad = self.seq_refl_win_pad(unigram, True) # [B, D/2, wh+ngram-1, ww+ngram-1]
457
+
458
+ context_forward = self.sliding_window_attention(unigram_forward_pad) # [B, D/2, wh, ww]
459
+ context_backward = self.sliding_window_attention(unigram_backward_pad) # [B, D/2, wh, ww]
460
+
461
+ context_bidirect = torch.cat([context_forward, context_backward], dim=1) # [B, D, wh, ww]
462
+ context_bidirect = self.merge(context_bidirect) # [B, D, wh, ww]
463
+ context_bidirect = rearrange(context_bidirect, 'b d h w -> b h w d') # [B, wh, ww, D]
464
+
465
+ return context_bidirect.unsqueeze(-2).unsqueeze(-2).contiguous() # [B, wh, ww, 1, 1, D]
466
+
467
+ class NGramWindowPartition(nn.Module):
468
+ """
469
+ Args:
470
+ dim (int): Number of input channels.
471
+ window_size (int): The height and width of the window.
472
+ ngram (int): How much windows to see as context.
473
+ ngram_num_heads (int):
474
+ shift_size (int, optional): Shift size for SW-MSA. Default: 0
475
+ """
476
+ def __init__(self, dim, window_size, ngram, ngram_num_heads, shift_size=0):
477
+ super(NGramWindowPartition, self).__init__()
478
+ self.window_size = window_size[0]
479
+ self.ngram = ngram
480
+ self.shift_size = shift_size
481
+
482
+ self.ngram_context = NGramContext(dim, window_size, ngram, ngram_num_heads, padding_mode='seq_refl_win_pad')
483
+
484
+ def forward(self, x):
485
+ B, ph, pw, D = x.size()
486
+ wh, ww = ph//self.window_size, pw//self.window_size # number of windows (height, width)
487
+ _assert(0 not in [wh, ww], "feature map size should be larger than window size!")
488
+
489
+ context = self.ngram_context(x) # [B, wh, ww, 1, 1, D]
490
+
491
+ windows = rearrange(x, 'b (h wh) (w ww) c -> b h w wh ww c',
492
+ wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. semi window partitioning
493
+ windows+=context # [B, wh, ww, WH, WW, D]. inject context
494
+
495
+ # Cyclic Shift
496
+ if self.shift_size>0:
497
+ x = rearrange(windows, 'b h w wh ww c -> b (h wh) (w ww) c').contiguous() # [B, ph, pw, D]. re-patchfying
498
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) # [B, ph, pw, D]. cyclic shift
499
+ windows = rearrange(shifted_x, 'b (h wh) (w ww) c -> b h w wh ww c',
500
+ wh=self.window_size, ww=self.window_size).contiguous() # [B, wh, ww, WH, WW, D]. re-semi window partitioning
501
+ windows = rearrange(windows, 'b h w wh ww c -> (b h w) wh ww c').contiguous() # [B*wh*ww, WH, WW, D]. window partitioning
502
+
503
+ return windows
504
+
505
+
506
+ class HierarchicalTransformerBlock(nn.Module):
507
+ """ Hierarchical Transformer Block.
508
+ Args:
509
+ dim (int): Number of input channels.
510
+ input_resolution (tuple[int]): Input resulotion.
511
+ num_heads (int): Number of heads for spatial self-correlation.
512
+ base_win_size (tuple[int]): The height and width of the base window.
513
+ window_size (tuple[int]): The height and width of the window.
514
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
515
+ drop (float, optional): Dropout rate. Default: 0.0
516
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
517
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
518
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
519
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
520
+ """
521
+
522
+ def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
523
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
524
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
525
+ super().__init__()
526
+ self.dim = dim
527
+ self.input_resolution = input_resolution
528
+ self.num_heads = num_heads
529
+ self.window_size = window_size
530
+ self.mlp_ratio = mlp_ratio
531
+
532
+ # check window size
533
+ if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
534
+ assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
535
+ assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
536
+
537
+
538
+ self.norm1 = norm_layer(dim)
539
+ self.correlation = SCC(
540
+ dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
541
+ value_drop=value_drop, proj_drop=drop)
542
+
543
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
544
+ self.norm2 = norm_layer(dim)
545
+ mlp_hidden_dim = int(dim * mlp_ratio)
546
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
547
+
548
+ def check_image_size(self, x, win_size):
549
+ x = x.permute(0,3,1,2).contiguous()
550
+ _, _, h, w = x.size()
551
+ mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
552
+ mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
553
+
554
+ if mod_pad_h >= h or mod_pad_w >= w:
555
+ pad_h, pad_w = h-1, w-1
556
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
557
+ else:
558
+ pad_h, pad_w = 0, 0
559
+
560
+ mod_pad_h = mod_pad_h - pad_h
561
+ mod_pad_w = mod_pad_w - pad_w
562
+
563
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
564
+ x = x.permute(0,2,3,1).contiguous()
565
+ return x
566
+
567
+ def forward(self, x, x_size, win_size):
568
+ H, W = x_size
569
+ B, L, C = x.shape
570
+
571
+ shortcut = x
572
+ x = x.view(B, H, W, C)
573
+
574
+ # padding
575
+ x = self.check_image_size(x, (win_size[0]*2, win_size[1]*2))
576
+ _, H_pad, W_pad, _ = x.shape # shape after padding
577
+
578
+ x = self.correlation(x)
579
+
580
+ # unpad
581
+ x = x[:, :H, :W, :].contiguous()
582
+
583
+ # norm
584
+ x = x.view(B, H * W, C)
585
+ x = self.norm1(x)
586
+
587
+ # FFN
588
+ x = shortcut + self.drop_path(x)
589
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
590
+
591
+ return x
592
+
593
+ def extra_repr(self) -> str:
594
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
595
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
596
+
597
+
598
+ class PatchMerging(nn.Module):
599
+ """ Patch Merging Layer.
600
+ Args:
601
+ input_resolution (tuple[int]): Resolution of input feature.
602
+ dim (int): Number of input channels.
603
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
604
+ """
605
+
606
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
607
+ super().__init__()
608
+ self.input_resolution = input_resolution
609
+ self.dim = dim
610
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
611
+ self.norm = norm_layer(4 * dim)
612
+
613
+ def forward(self, x):
614
+ """
615
+ x: B, H*W, C
616
+ """
617
+ H, W = self.input_resolution
618
+ B, L, C = x.shape
619
+ assert L == H * W, "input feature has wrong size"
620
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
621
+
622
+ x = x.view(B, H, W, C)
623
+
624
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
625
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
626
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
627
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
628
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
629
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
630
+
631
+ x = self.norm(x)
632
+ x = self.reduction(x)
633
+
634
+ return x
635
+
636
+ def extra_repr(self) -> str:
637
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
638
+
639
+
640
+ class BasicLayer(nn.Module):
641
+ """ A basic Hierarchical Transformer layer for one stage.
642
+
643
+ Args:
644
+ dim (int): Number of input channels.
645
+ input_resolution (tuple[int]): Input resolution.
646
+ depth (int): Number of blocks.
647
+ num_heads (int): Number of heads for spatial self-correlation.
648
+ base_win_size (tuple[int]): The height and width of the base window.
649
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
650
+ drop (float, optional): Dropout rate. Default: 0.0
651
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
652
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
653
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
654
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
655
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
656
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
657
+ """
658
+
659
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
660
+ mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
661
+ downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
662
+
663
+ super().__init__()
664
+ self.dim = dim
665
+ self.input_resolution = input_resolution
666
+ self.depth = depth
667
+ self.use_checkpoint = use_checkpoint
668
+
669
+ self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
670
+ self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
671
+
672
+ # build blocks
673
+ self.blocks = nn.ModuleList([
674
+ HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
675
+ num_heads=num_heads,
676
+ base_win_size=base_win_size,
677
+ window_size=(self.win_hs[i], self.win_ws[i]),
678
+ mlp_ratio=mlp_ratio,
679
+ drop=drop, value_drop=value_drop,
680
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
681
+ norm_layer=norm_layer)
682
+ for i in range(depth)])
683
+
684
+ # patch merging layer
685
+ if downsample is not None:
686
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
687
+ else:
688
+ self.downsample = None
689
+
690
+ def forward(self, x, x_size):
691
+
692
+ i = 0
693
+ for blk in self.blocks:
694
+ if self.use_checkpoint:
695
+ x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
696
+ else:
697
+ x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
698
+ i = i + 1
699
+
700
+ if self.downsample is not None:
701
+ x = self.downsample(x)
702
+ return x
703
+
704
+ def extra_repr(self) -> str:
705
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
706
+
707
+
708
+ class RHTB(nn.Module):
709
+ """Residual Hierarchical Transformer Block (RHTB).
710
+ Args:
711
+ dim (int): Number of input channels.
712
+ input_resolution (tuple[int]): Input resolution.
713
+ depth (int): Number of blocks.
714
+ num_heads (int): Number of heads for spatial self-correlation.
715
+ base_win_size (tuple[int]): The height and width of the base window.
716
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
717
+ drop (float, optional): Dropout rate. Default: 0.0
718
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
719
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
720
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
721
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
722
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
723
+ img_size: Input image size.
724
+ patch_size: Patch size.
725
+ resi_connection: The convolutional block before residual connection.
726
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
727
+ """
728
+
729
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
730
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
731
+ downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
732
+ resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
733
+ super(RHTB, self).__init__()
734
+
735
+ self.dim = dim
736
+ self.input_resolution = input_resolution
737
+
738
+ self.residual_group = BasicLayer(dim=dim,
739
+ input_resolution=input_resolution,
740
+ depth=depth,
741
+ num_heads=num_heads,
742
+ base_win_size=base_win_size,
743
+ mlp_ratio=mlp_ratio,
744
+ drop=drop, value_drop=value_drop,
745
+ drop_path=drop_path,
746
+ norm_layer=norm_layer,
747
+ downsample=downsample,
748
+ use_checkpoint=use_checkpoint,
749
+ hier_win_ratios=hier_win_ratios)
750
+
751
+ if resi_connection == '1conv':
752
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
753
+ elif resi_connection == '3conv':
754
+ # to save parameters and memory
755
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
756
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
757
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
758
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
759
+
760
+ self.patch_embed = PatchEmbed(
761
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
762
+ norm_layer=None)
763
+
764
+ self.patch_unembed = PatchUnEmbed(
765
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
766
+ norm_layer=None)
767
+
768
+ def forward(self, x, x_size):
769
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
770
+
771
+
772
+ class PatchEmbed(nn.Module):
773
+ r""" Image to Patch Embedding
774
+
775
+ Args:
776
+ img_size (int): Image size. Default: 224.
777
+ patch_size (int): Patch token size. Default: 4.
778
+ in_chans (int): Number of input image channels. Default: 3.
779
+ embed_dim (int): Number of linear projection output channels. Default: 96.
780
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
781
+ """
782
+
783
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
784
+ super().__init__()
785
+ img_size = to_2tuple(img_size)
786
+ patch_size = to_2tuple(patch_size)
787
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
788
+ self.img_size = img_size
789
+ self.patch_size = patch_size
790
+ self.patches_resolution = patches_resolution
791
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
792
+
793
+ self.in_chans = in_chans
794
+ self.embed_dim = embed_dim
795
+
796
+ if norm_layer is not None:
797
+ self.norm = norm_layer(embed_dim)
798
+ else:
799
+ self.norm = None
800
+
801
+ def forward(self, x):
802
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
803
+ if self.norm is not None:
804
+ x = self.norm(x)
805
+ return x
806
+
807
+
808
+ class PatchUnEmbed(nn.Module):
809
+ r""" Image to Patch Unembedding
810
+
811
+ Args:
812
+ img_size (int): Image size. Default: 224.
813
+ patch_size (int): Patch token size. Default: 4.
814
+ in_chans (int): Number of input image channels. Default: 3.
815
+ embed_dim (int): Number of linear projection output channels. Default: 96.
816
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
817
+ """
818
+
819
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
820
+ super().__init__()
821
+ img_size = to_2tuple(img_size)
822
+ patch_size = to_2tuple(patch_size)
823
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
824
+ self.img_size = img_size
825
+ self.patch_size = patch_size
826
+ self.patches_resolution = patches_resolution
827
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
828
+
829
+ self.in_chans = in_chans
830
+ self.embed_dim = embed_dim
831
+
832
+ def forward(self, x, x_size):
833
+ B, HW, C = x.shape
834
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
835
+ return x
836
+
837
+
838
+ class Upsample(nn.Sequential):
839
+ """Upsample module.
840
+
841
+ Args:
842
+ scale (int): Scale factor. Supported scales: 2^n and 3.
843
+ num_feat (int): Channel number of intermediate features.
844
+ """
845
+
846
+ def __init__(self, scale, num_feat):
847
+ m = []
848
+ if (scale & (scale - 1)) == 0: # scale = 2^n
849
+ for _ in range(int(math.log(scale, 2))):
850
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
851
+ m.append(nn.PixelShuffle(2))
852
+ elif scale == 3:
853
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
854
+ m.append(nn.PixelShuffle(3))
855
+ else:
856
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
857
+ super(Upsample, self).__init__(*m)
858
+
859
+
860
+ class UpsampleOneStep(nn.Sequential):
861
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
862
+ Used in lightweight SR to save parameters.
863
+
864
+ Args:
865
+ scale (int): Scale factor. Supported scales: 2^n and 3.
866
+ num_feat (int): Channel number of intermediate features.
867
+
868
+ """
869
+
870
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
871
+ self.num_feat = num_feat
872
+ self.input_resolution = input_resolution
873
+ m = []
874
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
875
+ m.append(nn.PixelShuffle(scale))
876
+ super(UpsampleOneStep, self).__init__(*m)
877
+
878
+
879
+ class HiT_SNG(nn.Module, PyTorchModelHubMixin):
880
+ """ HiT-SNG network.
881
+
882
+ Args:
883
+ img_size (int | tuple(int)): Input image size. Default 64
884
+ patch_size (int | tuple(int)): Patch size. Default: 1
885
+ in_chans (int): Number of input image channels. Default: 3
886
+ embed_dim (int): Patch embedding dimension. Default: 96
887
+ depths (tuple(int)): Depth of each Transformer block.
888
+ num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
889
+ base_win_size (tuple[int]): The height and width of the base window.
890
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
891
+ drop_rate (float): Dropout rate. Default: 0
892
+ value_drop_rate (float): Dropout ratio of value. Default: 0.0
893
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
894
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
895
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
896
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
897
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
898
+ upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
899
+ img_range (float): Image range. 1. or 255.
900
+ upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
901
+ resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
902
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
903
+ """
904
+
905
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
906
+ embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
907
+ base_win_size=[8,8], mlp_ratio=2.,
908
+ drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
909
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
910
+ use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
911
+ hier_win_ratios=[0.5,1,2,4,6,8],
912
+ **kwargs):
913
+ super(HiT_SNG, self).__init__()
914
+ num_in_ch = in_chans
915
+ num_out_ch = in_chans
916
+ num_feat = 64
917
+ self.img_range = img_range
918
+ if in_chans == 3:
919
+ rgb_mean = (0.4488, 0.4371, 0.4040)
920
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
921
+ else:
922
+ self.mean = torch.zeros(1, 1, 1, 1)
923
+ self.upscale = upscale
924
+ self.upsampler = upsampler
925
+ self.base_win_size = base_win_size
926
+
927
+ #####################################################################################################
928
+ ################################### 1, shallow feature extraction ###################################
929
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
930
+
931
+ #####################################################################################################
932
+ ################################### 2, deep feature extraction ######################################
933
+ self.num_layers = len(depths)
934
+ self.embed_dim = embed_dim
935
+ self.ape = ape
936
+ self.patch_norm = patch_norm
937
+ self.num_features = embed_dim
938
+ self.mlp_ratio = mlp_ratio
939
+
940
+ # split image into non-overlapping patches
941
+ self.patch_embed = PatchEmbed(
942
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
943
+ norm_layer=norm_layer if self.patch_norm else None)
944
+ num_patches = self.patch_embed.num_patches
945
+ patches_resolution = self.patch_embed.patches_resolution
946
+ self.patches_resolution = patches_resolution
947
+
948
+ # merge non-overlapping patches into image
949
+ self.patch_unembed = PatchUnEmbed(
950
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
951
+ norm_layer=norm_layer if self.patch_norm else None)
952
+
953
+ # absolute position embedding
954
+ if self.ape:
955
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
956
+ trunc_normal_(self.absolute_pos_embed, std=.02)
957
+
958
+ self.pos_drop = nn.Dropout(p=drop_rate)
959
+
960
+ # stochastic depth
961
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
962
+
963
+ # build Residual Hierarchical Transformer blocks (RHTB)
964
+ self.layers = nn.ModuleList()
965
+ for i_layer in range(self.num_layers):
966
+ layer = RHTB(dim=embed_dim,
967
+ input_resolution=(patches_resolution[0],
968
+ patches_resolution[1]),
969
+ depth=depths[i_layer],
970
+ num_heads=num_heads[i_layer],
971
+ base_win_size=base_win_size,
972
+ mlp_ratio=self.mlp_ratio,
973
+ drop=drop_rate, value_drop=value_drop_rate,
974
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
975
+ norm_layer=norm_layer,
976
+ downsample=None,
977
+ use_checkpoint=use_checkpoint,
978
+ img_size=img_size,
979
+ patch_size=patch_size,
980
+ resi_connection=resi_connection,
981
+ hier_win_ratios=hier_win_ratios
982
+ )
983
+ self.layers.append(layer)
984
+ self.norm = norm_layer(self.num_features)
985
+
986
+ # build the last conv layer in deep feature extraction
987
+ if resi_connection == '1conv':
988
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
989
+ elif resi_connection == '3conv':
990
+ # to save parameters and memory
991
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
992
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
993
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
994
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
995
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
996
+
997
+ #####################################################################################################
998
+ ################################ 3, high quality image reconstruction ################################
999
+ if self.upsampler == 'pixelshuffle':
1000
+ # for classical SR
1001
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
1002
+ nn.LeakyReLU(inplace=True))
1003
+ self.upsample = Upsample(upscale, num_feat)
1004
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1005
+ elif self.upsampler == 'pixelshuffledirect':
1006
+ # for lightweight SR (to save parameters)
1007
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
1008
+ (patches_resolution[0], patches_resolution[1]))
1009
+ elif self.upsampler == 'nearest+conv':
1010
+ # for real-world SR (less artifacts)
1011
+ assert self.upscale == 4, 'only support x4 now.'
1012
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
1013
+ nn.LeakyReLU(inplace=True))
1014
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1015
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1016
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1017
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1018
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1019
+ else:
1020
+ # for image denoising and JPEG compression artifact reduction
1021
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1022
+
1023
+ self.apply(self._init_weights)
1024
+
1025
+ def _init_weights(self, m):
1026
+ if isinstance(m, nn.Linear):
1027
+ trunc_normal_(m.weight, std=.02)
1028
+ if isinstance(m, nn.Linear) and m.bias is not None:
1029
+ nn.init.constant_(m.bias, 0)
1030
+ elif isinstance(m, nn.LayerNorm):
1031
+ nn.init.constant_(m.bias, 0)
1032
+ nn.init.constant_(m.weight, 1.0)
1033
+
1034
+ @torch.jit.ignore
1035
+ def no_weight_decay(self):
1036
+ return {'absolute_pos_embed'}
1037
+
1038
+ @torch.jit.ignore
1039
+ def no_weight_decay_keywords(self):
1040
+ return {'relative_position_bias_table'}
1041
+
1042
+
1043
+ def forward_features(self, x):
1044
+ x_size = (x.shape[2], x.shape[3])
1045
+ x = self.patch_embed(x)
1046
+ if self.ape:
1047
+ x = x + self.absolute_pos_embed
1048
+ x = self.pos_drop(x)
1049
+
1050
+ for layer in self.layers:
1051
+ x = layer(x, x_size)
1052
+
1053
+ x = self.norm(x) # B L C
1054
+ x = self.patch_unembed(x, x_size)
1055
+
1056
+ return x
1057
+
1058
+ def infer_image(self, image_path, cuda=True):
1059
+
1060
+ io_backend_opt = {'type':'disk'}
1061
+ self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
1062
+
1063
+ # load lq image
1064
+ lq_path = image_path
1065
+ img_bytes = self.file_client.get(lq_path, 'lq')
1066
+ img_lq = imfrombytes(img_bytes, float32=True)
1067
+
1068
+ # BGR to RGB, HWC to CHW, numpy to tensor
1069
+ x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...]
1070
+
1071
+ if cuda:
1072
+ x= x.cuda()
1073
+
1074
+ out = self(x)
1075
+
1076
+ if cuda:
1077
+ out = out.cpu()
1078
+
1079
+ out = tensor2img(out)
1080
+
1081
+ return out
1082
+
1083
+ def forward(self, x):
1084
+ H, W = x.shape[2:]
1085
+
1086
+ self.mean = self.mean.type_as(x)
1087
+ x = (x - self.mean) * self.img_range
1088
+
1089
+ if self.upsampler == 'pixelshuffle':
1090
+ # for classical SR
1091
+ x = self.conv_first(x)
1092
+ x = self.conv_after_body(self.forward_features(x)) + x
1093
+ x = self.conv_before_upsample(x)
1094
+ x = self.conv_last(self.upsample(x))
1095
+ elif self.upsampler == 'pixelshuffledirect':
1096
+ # for lightweight SR
1097
+ x = self.conv_first(x)
1098
+ x = self.conv_after_body(self.forward_features(x)) + x
1099
+ x = self.upsample(x)
1100
+ elif self.upsampler == 'nearest+conv':
1101
+ # for real-world SR
1102
+ x = self.conv_first(x)
1103
+ x = self.conv_after_body(self.forward_features(x)) + x
1104
+ x = self.conv_before_upsample(x)
1105
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1106
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1107
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1108
+ else:
1109
+ # for image denoising and JPEG compression artifact reduction
1110
+ x_first = self.conv_first(x)
1111
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1112
+ x = x + self.conv_last(res)
1113
+
1114
+ x = x / self.img_range + self.mean
1115
+
1116
+ return x[:, :, :H*self.upscale, :W*self.upscale]
1117
+
1118
+
1119
+ if __name__ == '__main__':
1120
+ upscale = 4
1121
+ base_win_size = [8, 8]
1122
+ height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
1123
+ width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
1124
+
1125
+ ## HiT-SIR
1126
+ model = HiT_SNG(upscale=4, img_size=(height, width),
1127
+ base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
1128
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
1129
+
1130
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
1131
+ print("params: ", params_num)
1132
+
hit_srf_arch.py ADDED
@@ -0,0 +1,947 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ import numpy as np
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+ from utils import FileClient, imfrombytes, img2tensor, tensor2img
11
+
12
+ class DFE(nn.Module):
13
+ """ Dual Feature Extraction
14
+ Args:
15
+ in_features (int): Number of input channels.
16
+ out_features (int): Number of output channels.
17
+ """
18
+ def __init__(self, in_features, out_features):
19
+ super().__init__()
20
+
21
+ self.out_features = out_features
22
+
23
+ self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
24
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
25
+ nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
27
+ nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
28
+
29
+ self.linear = nn.Conv2d(in_features, out_features,1,1,0)
30
+
31
+ def forward(self, x, x_size):
32
+
33
+ B, L, C = x.shape
34
+ H, W = x_size
35
+ x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
36
+ x = self.conv(x) * self.linear(x)
37
+ x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
38
+
39
+ return x
40
+
41
+ class Mlp(nn.Module):
42
+ """ MLP-based Feed-Forward Network
43
+ Args:
44
+ in_features (int): Number of input channels.
45
+ hidden_features (int | None): Number of hidden channels. Default: None
46
+ out_features (int | None): Number of output channels. Default: None
47
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
48
+ drop (float): Dropout rate. Default: 0.0
49
+ """
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class dwconv(nn.Module):
69
+ def __init__(self,hidden_features):
70
+ super(dwconv, self).__init__()
71
+ self.depthwise_conv = nn.Sequential(
72
+ nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2, dilation=1,
73
+ groups=hidden_features), nn.GELU())
74
+ self.hidden_features = hidden_features
75
+ def forward(self,x,x_size):
76
+ x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous() # b Ph*Pw c
77
+ x = self.depthwise_conv(x)
78
+ x = x.flatten(2).transpose(1, 2).contiguous()
79
+ return x
80
+
81
+ class ConvFFN(nn.Module):
82
+
83
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+ self.fc1 = nn.Linear(in_features, hidden_features)
88
+ self.act = act_layer()
89
+ self.dwconv = dwconv(hidden_features=hidden_features)
90
+ self.fc2 = nn.Linear(hidden_features, out_features)
91
+ self.drop = nn.Dropout(drop)
92
+
93
+
94
+ def forward(self, x,x_size):
95
+ x = self.fc1(x)
96
+ x = self.act(x)
97
+ x = x + self.dwconv(x,x_size)
98
+ x = self.drop(x)
99
+ x = self.fc2(x)
100
+ x = self.drop(x)
101
+ return x
102
+
103
+ def window_partition(x, window_size):
104
+ """
105
+ Args:
106
+ x: (B, H, W, C)
107
+ window_size (tuple): window size
108
+
109
+ Returns:
110
+ windows: (num_windows*B, window_size, window_size, C)
111
+ """
112
+ B, H, W, C = x.shape
113
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
114
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
115
+ return windows
116
+
117
+
118
+ def window_reverse(windows, window_size, H, W):
119
+ """
120
+ Args:
121
+ windows: (num_windows*B, window_size, window_size, C)
122
+ window_size (tuple): Window size
123
+ H (int): Height of image
124
+ W (int): Width of image
125
+
126
+ Returns:
127
+ x: (B, H, W, C)
128
+ """
129
+ B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
130
+ x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
131
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
132
+ return x
133
+
134
+ class DynamicPosBias(nn.Module):
135
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
136
+ """ Dynamic Relative Position Bias.
137
+ Args:
138
+ dim (int): Number of input channels.
139
+ num_heads (int): Number of heads for spatial self-correlation.
140
+ residual (bool): If True, use residual strage to connect conv.
141
+ """
142
+ def __init__(self, dim, num_heads, residual):
143
+ super().__init__()
144
+ self.residual = residual
145
+ self.num_heads = num_heads
146
+ self.pos_dim = dim // 4
147
+ self.pos_proj = nn.Linear(2, self.pos_dim)
148
+ self.pos1 = nn.Sequential(
149
+ nn.LayerNorm(self.pos_dim),
150
+ nn.ReLU(inplace=True),
151
+ nn.Linear(self.pos_dim, self.pos_dim),
152
+ )
153
+ self.pos2 = nn.Sequential(
154
+ nn.LayerNorm(self.pos_dim),
155
+ nn.ReLU(inplace=True),
156
+ nn.Linear(self.pos_dim, self.pos_dim)
157
+ )
158
+ self.pos3 = nn.Sequential(
159
+ nn.LayerNorm(self.pos_dim),
160
+ nn.ReLU(inplace=True),
161
+ nn.Linear(self.pos_dim, self.num_heads)
162
+ )
163
+ def forward(self, biases):
164
+ if self.residual:
165
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
166
+ pos = pos + self.pos1(pos)
167
+ pos = pos + self.pos2(pos)
168
+ pos = self.pos3(pos)
169
+ else:
170
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
171
+ return pos
172
+
173
+ class SCC(nn.Module):
174
+ """ Spatial-Channel Correlation.
175
+ Args:
176
+ dim (int): Number of input channels.
177
+ base_win_size (tuple[int]): The height and width of the base window.
178
+ window_size (tuple[int]): The height and width of the window.
179
+ num_heads (int): Number of heads for spatial self-correlation.
180
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
181
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
182
+ """
183
+
184
+ def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
185
+
186
+ super().__init__()
187
+ # parameters
188
+ self.dim = dim
189
+ self.window_size = window_size
190
+ self.num_heads = num_heads
191
+
192
+ # feature projection
193
+ self.qv = DFE(dim, dim)
194
+ self.proj = nn.Linear(dim, dim)
195
+
196
+ # dropout
197
+ self.value_drop = nn.Dropout(value_drop)
198
+ self.proj_drop = nn.Dropout(proj_drop)
199
+
200
+ # base window size
201
+ min_h = min(self.window_size[0], base_win_size[0])
202
+ min_w = min(self.window_size[1], base_win_size[1])
203
+ self.base_win_size = (min_h, min_w)
204
+
205
+ # normalization factor and spatial linear layer for S-SC
206
+ head_dim = dim // (2*num_heads)
207
+ self.scale = head_dim
208
+ self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
209
+
210
+ # define a parameter table of relative position bias
211
+ self.H_sp, self.W_sp = self.window_size
212
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
213
+
214
+ def spatial_linear_projection(self, x):
215
+ B, num_h, L, C = x.shape
216
+ H, W = self.window_size
217
+ map_H, map_W = self.base_win_size
218
+
219
+ x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
220
+ x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
221
+ return x
222
+
223
+ def spatial_self_correlation(self, q, v):
224
+
225
+ B, num_head, L, C = q.shape
226
+
227
+ # spatial projection
228
+ v = self.spatial_linear_projection(v)
229
+
230
+ # compute correlation map
231
+ corr_map = (q @ v.transpose(-2,-1)) / self.scale
232
+
233
+ # add relative position bias
234
+ # generate mother-set
235
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
236
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
237
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
238
+ rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
239
+ pos = self.pos(rpe_biases)
240
+
241
+ # select position bias
242
+ coords_h = torch.arange(self.H_sp, device=v.device)
243
+ coords_w = torch.arange(self.W_sp, device=v.device)
244
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
245
+ coords_flatten = torch.flatten(coords, 1)
246
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
247
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
248
+ relative_coords[:, :, 0] += self.H_sp - 1
249
+ relative_coords[:, :, 1] += self.W_sp - 1
250
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
251
+ relative_position_index = relative_coords.sum(-1)
252
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
253
+ self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
254
+ relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
255
+ self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
256
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
257
+ corr_map = corr_map + relative_position_bias.unsqueeze(0)
258
+
259
+ # transformation
260
+ v_drop = self.value_drop(v)
261
+ x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
262
+
263
+ return x
264
+
265
+ def channel_self_correlation(self, q, v):
266
+
267
+ B, num_head, L, C = q.shape
268
+
269
+ # apply single head strategy
270
+ q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
271
+ v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
272
+
273
+ # compute correlation map
274
+ corr_map = (q.transpose(-2,-1) @ v) / L
275
+
276
+ # transformation
277
+ v_drop = self.value_drop(v)
278
+ x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
279
+
280
+ return x
281
+
282
+ def forward(self, x):
283
+ """
284
+ Args:
285
+ x: input features with shape of (B, H, W, C)
286
+ """
287
+ xB,xH,xW,xC = x.shape
288
+ qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
289
+
290
+ # window partition
291
+ qv = window_partition(qv, self.window_size)
292
+ qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
293
+
294
+ # qv splitting
295
+ B, L, C = qv.shape
296
+ qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
297
+ q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
298
+
299
+ # spatial self-correlation (S-SC)
300
+ x_spatial = self.spatial_self_correlation(q, v)
301
+ x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
302
+ x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
303
+
304
+ # channel self-correlation (C-SC)
305
+ x_channel = self.channel_self_correlation(q, v)
306
+ x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
307
+ x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
308
+
309
+ # spatial-channel information fusion
310
+ x = torch.cat([x_spatial, x_channel], -1)
311
+ x = self.proj_drop(self.proj(x))
312
+
313
+ return x
314
+
315
+ def extra_repr(self) -> str:
316
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
317
+
318
+
319
+ class HierarchicalTransformerBlock(nn.Module):
320
+ """ Hierarchical Transformer Block.
321
+ Args:
322
+ dim (int): Number of input channels.
323
+ input_resolution (tuple[int]): Input resulotion.
324
+ num_heads (int): Number of heads for spatial self-correlation.
325
+ base_win_size (tuple[int]): The height and width of the base window.
326
+ window_size (tuple[int]): The height and width of the window.
327
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
328
+ drop (float, optional): Dropout rate. Default: 0.0
329
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
330
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
331
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
332
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
+ """
334
+
335
+ def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
336
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
337
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
338
+ super().__init__()
339
+ self.dim = dim
340
+ self.input_resolution = input_resolution
341
+ self.num_heads = num_heads
342
+ self.window_size = window_size
343
+ self.mlp_ratio = mlp_ratio
344
+
345
+ # check window size
346
+ if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
347
+ assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
348
+ assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
349
+
350
+
351
+ self.norm1 = norm_layer(dim)
352
+ self.correlation = SCC(
353
+ dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
354
+ value_drop=value_drop, proj_drop=drop)
355
+
356
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
357
+ self.norm2 = norm_layer(dim)
358
+ mlp_hidden_dim = int(dim * mlp_ratio)
359
+ self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
360
+ # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
361
+
362
+ def check_image_size(self, x, win_size):
363
+ x = x.permute(0,3,1,2).contiguous()
364
+ _, _, h, w = x.size()
365
+ mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
366
+ mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
367
+
368
+ if mod_pad_h >= h or mod_pad_w >= w:
369
+ pad_h, pad_w = h-1, w-1
370
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
371
+ else:
372
+ pad_h, pad_w = 0, 0
373
+
374
+ mod_pad_h = mod_pad_h - pad_h
375
+ mod_pad_w = mod_pad_w - pad_w
376
+
377
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
378
+ x = x.permute(0,2,3,1).contiguous()
379
+ return x
380
+
381
+ def forward(self, x, x_size, win_size):
382
+ H, W = x_size
383
+ B, L, C = x.shape
384
+
385
+ shortcut = x
386
+ x = x.view(B, H, W, C)
387
+
388
+ # padding
389
+ x = self.check_image_size(x, win_size)
390
+ _, H_pad, W_pad, _ = x.shape # shape after padding
391
+
392
+ x = self.correlation(x)
393
+
394
+ # unpad
395
+ x = x[:, :H, :W, :].contiguous()
396
+
397
+ # norm
398
+ x = x.view(B, H * W, C)
399
+ x = self.norm1(x)
400
+
401
+ # FFN
402
+ x = shortcut + self.drop_path(x)
403
+ x = x + self.drop_path(self.norm2(self.mlp(x, x_size)))
404
+
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
409
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
410
+
411
+
412
+ class PatchMerging(nn.Module):
413
+ """ Patch Merging Layer.
414
+ Args:
415
+ input_resolution (tuple[int]): Resolution of input feature.
416
+ dim (int): Number of input channels.
417
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
418
+ """
419
+
420
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.input_resolution = input_resolution
423
+ self.dim = dim
424
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
425
+ self.norm = norm_layer(4 * dim)
426
+
427
+ def forward(self, x):
428
+ """
429
+ x: B, H*W, C
430
+ """
431
+ H, W = self.input_resolution
432
+ B, L, C = x.shape
433
+ assert L == H * W, "input feature has wrong size"
434
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
435
+
436
+ x = x.view(B, H, W, C)
437
+
438
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
439
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
440
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
441
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
442
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
443
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
444
+
445
+ x = self.norm(x)
446
+ x = self.reduction(x)
447
+
448
+ return x
449
+
450
+ def extra_repr(self) -> str:
451
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
452
+
453
+
454
+ class BasicLayer(nn.Module):
455
+ """ A basic Hierarchical Transformer layer for one stage.
456
+
457
+ Args:
458
+ dim (int): Number of input channels.
459
+ input_resolution (tuple[int]): Input resolution.
460
+ depth (int): Number of blocks.
461
+ num_heads (int): Number of heads for spatial self-correlation.
462
+ base_win_size (tuple[int]): The height and width of the base window.
463
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
464
+ drop (float, optional): Dropout rate. Default: 0.0
465
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
466
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
467
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
469
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
471
+ """
472
+
473
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
474
+ mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
475
+ downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
476
+
477
+ super().__init__()
478
+ self.dim = dim
479
+ self.input_resolution = input_resolution
480
+ self.depth = depth
481
+ self.use_checkpoint = use_checkpoint
482
+
483
+ self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
484
+ self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
485
+
486
+ # build blocks
487
+ self.blocks = nn.ModuleList([
488
+ HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
489
+ num_heads=num_heads,
490
+ base_win_size=base_win_size,
491
+ window_size=(self.win_hs[i], self.win_ws[i]),
492
+ mlp_ratio=mlp_ratio,
493
+ drop=drop, value_drop=value_drop,
494
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
495
+ norm_layer=norm_layer)
496
+ for i in range(depth)])
497
+
498
+ # patch merging layer
499
+ if downsample is not None:
500
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
501
+ else:
502
+ self.downsample = None
503
+
504
+ def forward(self, x, x_size):
505
+
506
+ i = 0
507
+ for blk in self.blocks:
508
+ if self.use_checkpoint:
509
+ x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
510
+ else:
511
+ x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
512
+ i = i + 1
513
+
514
+ if self.downsample is not None:
515
+ x = self.downsample(x)
516
+ return x
517
+
518
+ def extra_repr(self) -> str:
519
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
520
+
521
+
522
+ class RHTB(nn.Module):
523
+ """Residual Hierarchical Transformer Block (RHTB).
524
+ Args:
525
+ dim (int): Number of input channels.
526
+ input_resolution (tuple[int]): Input resolution.
527
+ depth (int): Number of blocks.
528
+ num_heads (int): Number of heads for spatial self-correlation.
529
+ base_win_size (tuple[int]): The height and width of the base window.
530
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
531
+ drop (float, optional): Dropout rate. Default: 0.0
532
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
533
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
534
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
535
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
536
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
537
+ img_size: Input image size.
538
+ patch_size: Patch size.
539
+ resi_connection: The convolutional block before residual connection.
540
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
541
+ """
542
+
543
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
544
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
545
+ downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
546
+ resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
547
+ super(RHTB, self).__init__()
548
+
549
+ self.dim = dim
550
+ self.input_resolution = input_resolution
551
+
552
+ self.residual_group = BasicLayer(dim=dim,
553
+ input_resolution=input_resolution,
554
+ depth=depth,
555
+ num_heads=num_heads,
556
+ base_win_size=base_win_size,
557
+ mlp_ratio=mlp_ratio,
558
+ drop=drop, value_drop=value_drop,
559
+ drop_path=drop_path,
560
+ norm_layer=norm_layer,
561
+ downsample=downsample,
562
+ use_checkpoint=use_checkpoint,
563
+ hier_win_ratios=hier_win_ratios)
564
+
565
+ if resi_connection == '1conv':
566
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
567
+ elif resi_connection == '3conv':
568
+ # to save parameters and memory
569
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
570
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
571
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
572
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
573
+
574
+ self.patch_embed = PatchEmbed(
575
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
576
+ norm_layer=None)
577
+
578
+ self.patch_unembed = PatchUnEmbed(
579
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
580
+ norm_layer=None)
581
+
582
+ def forward(self, x, x_size):
583
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
584
+
585
+
586
+ class PatchEmbed(nn.Module):
587
+ r""" Image to Patch Embedding
588
+
589
+ Args:
590
+ img_size (int): Image size. Default: 224.
591
+ patch_size (int): Patch token size. Default: 4.
592
+ in_chans (int): Number of input image channels. Default: 3.
593
+ embed_dim (int): Number of linear projection output channels. Default: 96.
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
595
+ """
596
+
597
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
598
+ super().__init__()
599
+ img_size = to_2tuple(img_size)
600
+ patch_size = to_2tuple(patch_size)
601
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
602
+ self.img_size = img_size
603
+ self.patch_size = patch_size
604
+ self.patches_resolution = patches_resolution
605
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
606
+
607
+ self.in_chans = in_chans
608
+ self.embed_dim = embed_dim
609
+
610
+ if norm_layer is not None:
611
+ self.norm = norm_layer(embed_dim)
612
+ else:
613
+ self.norm = None
614
+
615
+ def forward(self, x):
616
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
617
+ if self.norm is not None:
618
+ x = self.norm(x)
619
+ return x
620
+
621
+
622
+ class PatchUnEmbed(nn.Module):
623
+ r""" Image to Patch Unembedding
624
+
625
+ Args:
626
+ img_size (int): Image size. Default: 224.
627
+ patch_size (int): Patch token size. Default: 4.
628
+ in_chans (int): Number of input image channels. Default: 3.
629
+ embed_dim (int): Number of linear projection output channels. Default: 96.
630
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
631
+ """
632
+
633
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
634
+ super().__init__()
635
+ img_size = to_2tuple(img_size)
636
+ patch_size = to_2tuple(patch_size)
637
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
638
+ self.img_size = img_size
639
+ self.patch_size = patch_size
640
+ self.patches_resolution = patches_resolution
641
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
642
+
643
+ self.in_chans = in_chans
644
+ self.embed_dim = embed_dim
645
+
646
+ def forward(self, x, x_size):
647
+ B, HW, C = x.shape
648
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
649
+ return x
650
+
651
+
652
+ class Upsample(nn.Sequential):
653
+ """Upsample module.
654
+
655
+ Args:
656
+ scale (int): Scale factor. Supported scales: 2^n and 3.
657
+ num_feat (int): Channel number of intermediate features.
658
+ """
659
+
660
+ def __init__(self, scale, num_feat):
661
+ m = []
662
+ if (scale & (scale - 1)) == 0: # scale = 2^n
663
+ for _ in range(int(math.log(scale, 2))):
664
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
665
+ m.append(nn.PixelShuffle(2))
666
+ elif scale == 3:
667
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
668
+ m.append(nn.PixelShuffle(3))
669
+ else:
670
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
671
+ super(Upsample, self).__init__(*m)
672
+
673
+
674
+ class UpsampleOneStep(nn.Sequential):
675
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
676
+ Used in lightweight SR to save parameters.
677
+
678
+ Args:
679
+ scale (int): Scale factor. Supported scales: 2^n and 3.
680
+ num_feat (int): Channel number of intermediate features.
681
+
682
+ """
683
+
684
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
685
+ self.num_feat = num_feat
686
+ self.input_resolution = input_resolution
687
+ m = []
688
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
689
+ m.append(nn.PixelShuffle(scale))
690
+ super(UpsampleOneStep, self).__init__(*m)
691
+
692
+
693
+ class HiT_SRF(nn.Module, PyTorchModelHubMixin):
694
+ """ HiT-SRF network.
695
+
696
+ Args:
697
+ img_size (int | tuple(int)): Input image size. Default 64
698
+ patch_size (int | tuple(int)): Patch size. Default: 1
699
+ in_chans (int): Number of input image channels. Default: 3
700
+ embed_dim (int): Patch embedding dimension. Default: 96
701
+ depths (tuple(int)): Depth of each Transformer block.
702
+ num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
703
+ base_win_size (tuple[int]): The height and width of the base window.
704
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
705
+ drop_rate (float): Dropout rate. Default: 0
706
+ value_drop_rate (float): Dropout ratio of value. Default: 0.0
707
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
708
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
709
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
710
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
711
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
712
+ upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
713
+ img_range (float): Image range. 1. or 255.
714
+ upsampler (str): The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
715
+ resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
716
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
717
+ """
718
+
719
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
720
+ embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
721
+ base_win_size=[8,8], mlp_ratio=2.,
722
+ drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
723
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
724
+ use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
725
+ hier_win_ratios=[0.5,1,2,4,6,8],
726
+ **kwargs):
727
+ super(HiT_SRF, self).__init__()
728
+ num_in_ch = in_chans
729
+ num_out_ch = in_chans
730
+ num_feat = 64
731
+ self.img_range = img_range
732
+ if in_chans == 3:
733
+ rgb_mean = (0.4488, 0.4371, 0.4040)
734
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
735
+ else:
736
+ self.mean = torch.zeros(1, 1, 1, 1)
737
+ self.upscale = upscale
738
+ self.upsampler = upsampler
739
+ self.base_win_size = base_win_size
740
+
741
+ #####################################################################################################
742
+ ################################### 1, shallow feature extraction ###################################
743
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
744
+
745
+ #####################################################################################################
746
+ ################################### 2, deep feature extraction ######################################
747
+ self.num_layers = len(depths)
748
+ self.embed_dim = embed_dim
749
+ self.ape = ape
750
+ self.patch_norm = patch_norm
751
+ self.num_features = embed_dim
752
+ self.mlp_ratio = mlp_ratio
753
+
754
+ # split image into non-overlapping patches
755
+ self.patch_embed = PatchEmbed(
756
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
757
+ norm_layer=norm_layer if self.patch_norm else None)
758
+ num_patches = self.patch_embed.num_patches
759
+ patches_resolution = self.patch_embed.patches_resolution
760
+ self.patches_resolution = patches_resolution
761
+
762
+ # merge non-overlapping patches into image
763
+ self.patch_unembed = PatchUnEmbed(
764
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
765
+ norm_layer=norm_layer if self.patch_norm else None)
766
+
767
+ # absolute position embedding
768
+ if self.ape:
769
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
770
+ trunc_normal_(self.absolute_pos_embed, std=.02)
771
+
772
+ self.pos_drop = nn.Dropout(p=drop_rate)
773
+
774
+ # stochastic depth
775
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
776
+
777
+ # build Residual Hierarchical Transformer blocks (RHTB)
778
+ self.layers = nn.ModuleList()
779
+ for i_layer in range(self.num_layers):
780
+ layer = RHTB(dim=embed_dim,
781
+ input_resolution=(patches_resolution[0],
782
+ patches_resolution[1]),
783
+ depth=depths[i_layer],
784
+ num_heads=num_heads[i_layer],
785
+ base_win_size=base_win_size,
786
+ mlp_ratio=self.mlp_ratio,
787
+ drop=drop_rate, value_drop=value_drop_rate,
788
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
789
+ norm_layer=norm_layer,
790
+ downsample=None,
791
+ use_checkpoint=use_checkpoint,
792
+ img_size=img_size,
793
+ patch_size=patch_size,
794
+ resi_connection=resi_connection,
795
+ hier_win_ratios=hier_win_ratios
796
+ )
797
+ self.layers.append(layer)
798
+ self.norm = norm_layer(self.num_features)
799
+
800
+ # build the last conv layer in deep feature extraction
801
+ if resi_connection == '1conv':
802
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
803
+ elif resi_connection == '3conv':
804
+ # to save parameters and memory
805
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
806
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
807
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
808
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
809
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
810
+
811
+ #####################################################################################################
812
+ ################################ 3, high quality image reconstruction ################################
813
+ if self.upsampler == 'pixelshuffle':
814
+ # for classical SR
815
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
816
+ nn.LeakyReLU(inplace=True))
817
+ self.upsample = Upsample(upscale, num_feat)
818
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
819
+ elif self.upsampler == 'pixelshuffledirect':
820
+ # for lightweight SR (to save parameters)
821
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
822
+ (patches_resolution[0], patches_resolution[1]))
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR (less artifacts)
825
+ assert self.upscale == 4, 'only support x4 now.'
826
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
827
+ nn.LeakyReLU(inplace=True))
828
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
829
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
830
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
831
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
832
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
833
+ else:
834
+ # for image denoising and JPEG compression artifact reduction
835
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
836
+
837
+ self.apply(self._init_weights)
838
+
839
+ def _init_weights(self, m):
840
+ if isinstance(m, nn.Linear):
841
+ trunc_normal_(m.weight, std=.02)
842
+ if isinstance(m, nn.Linear) and m.bias is not None:
843
+ nn.init.constant_(m.bias, 0)
844
+ elif isinstance(m, nn.LayerNorm):
845
+ nn.init.constant_(m.bias, 0)
846
+ nn.init.constant_(m.weight, 1.0)
847
+
848
+ @torch.jit.ignore
849
+ def no_weight_decay(self):
850
+ return {'absolute_pos_embed'}
851
+
852
+ @torch.jit.ignore
853
+ def no_weight_decay_keywords(self):
854
+ return {'relative_position_bias_table'}
855
+
856
+
857
+ def forward_features(self, x):
858
+ x_size = (x.shape[2], x.shape[3])
859
+ x = self.patch_embed(x)
860
+ if self.ape:
861
+ x = x + self.absolute_pos_embed
862
+ x = self.pos_drop(x)
863
+
864
+ for layer in self.layers:
865
+ x = layer(x, x_size)
866
+
867
+ x = self.norm(x) # B L C
868
+ x = self.patch_unembed(x, x_size)
869
+
870
+ return x
871
+
872
+ def infer_image(self, image_path, cuda=True):
873
+
874
+ io_backend_opt = {'type':'disk'}
875
+ self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
876
+
877
+ # load lq image
878
+ lq_path = image_path
879
+ img_bytes = self.file_client.get(lq_path, 'lq')
880
+ img_lq = imfrombytes(img_bytes, float32=True)
881
+
882
+ # BGR to RGB, HWC to CHW, numpy to tensor
883
+ x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...]
884
+
885
+ if cuda:
886
+ x= x.cuda()
887
+
888
+ out = self(x)
889
+
890
+ if cuda:
891
+ out = out.cpu()
892
+
893
+ out = tensor2img(out)
894
+
895
+ return out
896
+
897
+ def forward(self, x):
898
+ H, W = x.shape[2:]
899
+
900
+ self.mean = self.mean.type_as(x)
901
+ x = (x - self.mean) * self.img_range
902
+
903
+ if self.upsampler == 'pixelshuffle':
904
+ # for classical SR
905
+ x = self.conv_first(x)
906
+ x = self.conv_after_body(self.forward_features(x)) + x
907
+ x = self.conv_before_upsample(x)
908
+ x = self.conv_last(self.upsample(x))
909
+ elif self.upsampler == 'pixelshuffledirect':
910
+ # for lightweight SR
911
+ x = self.conv_first(x)
912
+ x = self.conv_after_body(self.forward_features(x)) + x
913
+ x = self.upsample(x)
914
+ elif self.upsampler == 'nearest+conv':
915
+ # for real-world SR
916
+ x = self.conv_first(x)
917
+ x = self.conv_after_body(self.forward_features(x)) + x
918
+ x = self.conv_before_upsample(x)
919
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
920
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
921
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
922
+ else:
923
+ # for image denoising and JPEG compression artifact reduction
924
+ x_first = self.conv_first(x)
925
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
926
+ x = x + self.conv_last(res)
927
+
928
+ x = x / self.img_range + self.mean
929
+
930
+ return x[:, :, :H*self.upscale, :W*self.upscale]
931
+
932
+
933
+ if __name__ == '__main__':
934
+ upscale = 4
935
+ base_win_size = [8, 8]
936
+ height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
937
+ width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
938
+
939
+ ## HiT-SIR
940
+ model = HiT_SRF(upscale=4, img_size=(height, width),
941
+ base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
942
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
943
+
944
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
945
+ print("params: ", params_num)
946
+
947
+
requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ addict
2
+ future
3
+ lmdb
4
+ numpy>=1.17
5
+ opencv-python
6
+ Pillow
7
+ pyyaml
8
+ requests
9
+ scikit-image
10
+ scipy
11
+ tb-nightly
12
+ tqdm
13
+ yapf
14
+ timm
15
+ einops
16
+ h5py
17
+ six
18
+ huggingface_hub
utils/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .file_client import FileClient
2
+ from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
3
+ from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
4
+ from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
5
+
6
+ __all__ = [
7
+ # file_client.py
8
+ 'FileClient',
9
+ # img_util.py
10
+ 'img2tensor',
11
+ 'tensor2img',
12
+ 'imfrombytes',
13
+ 'imwrite',
14
+ 'crop_border',
15
+ # logger.py
16
+ 'MessageLogger',
17
+ 'AvgTimer',
18
+ 'init_tb_logger',
19
+ 'init_wandb_logger',
20
+ 'get_root_logger',
21
+ 'get_env_info',
22
+ # misc.py
23
+ 'set_random_seed',
24
+ 'get_time_str',
25
+ 'mkdir_and_rename',
26
+ 'make_exp_dirs',
27
+ 'scandir',
28
+ 'check_resume',
29
+ 'sizeof_fmt',
30
+ ]
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (854 Bytes). View file
 
utils/__pycache__/dist_util.cpython-38.pyc ADDED
Binary file (2.6 kB). View file
 
utils/__pycache__/file_client.cpython-38.pyc ADDED
Binary file (6.5 kB). View file
 
utils/__pycache__/img_util.cpython-38.pyc ADDED
Binary file (6.12 kB). View file
 
utils/__pycache__/logger.cpython-38.pyc ADDED
Binary file (6.94 kB). View file
 
utils/__pycache__/matlab_functions.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
utils/__pycache__/misc.cpython-38.pyc ADDED
Binary file (4.37 kB). View file
 
utils/__pycache__/options.cpython-38.pyc ADDED
Binary file (5.11 kB). View file
 
utils/__pycache__/registry.cpython-38.pyc ADDED
Binary file (2.61 kB). View file
 
utils/dist_util.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2
+ import functools
3
+ import os
4
+ import subprocess
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.multiprocessing as mp
8
+
9
+
10
+ def init_dist(launcher, backend='nccl', **kwargs):
11
+ if mp.get_start_method(allow_none=True) is None:
12
+ mp.set_start_method('spawn')
13
+ if launcher == 'pytorch':
14
+ _init_dist_pytorch(backend, **kwargs)
15
+ elif launcher == 'slurm':
16
+ _init_dist_slurm(backend, **kwargs)
17
+ else:
18
+ raise ValueError(f'Invalid launcher type: {launcher}')
19
+
20
+
21
+ def _init_dist_pytorch(backend, **kwargs):
22
+ rank = int(os.environ['RANK'])
23
+ num_gpus = torch.cuda.device_count()
24
+ torch.cuda.set_device(rank % num_gpus)
25
+ dist.init_process_group(backend=backend, **kwargs)
26
+
27
+
28
+ def _init_dist_slurm(backend, port=None):
29
+ """Initialize slurm distributed training environment.
30
+
31
+ If argument ``port`` is not specified, then the master port will be system
32
+ environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33
+ environment variable, then a default port ``29500`` will be used.
34
+
35
+ Args:
36
+ backend (str): Backend of torch.distributed.
37
+ port (int, optional): Master port. Defaults to None.
38
+ """
39
+ proc_id = int(os.environ['SLURM_PROCID'])
40
+ ntasks = int(os.environ['SLURM_NTASKS'])
41
+ node_list = os.environ['SLURM_NODELIST']
42
+ num_gpus = torch.cuda.device_count()
43
+ torch.cuda.set_device(proc_id % num_gpus)
44
+ addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45
+ # specify master port
46
+ if port is not None:
47
+ os.environ['MASTER_PORT'] = str(port)
48
+ elif 'MASTER_PORT' in os.environ:
49
+ pass # use MASTER_PORT in the environment variable
50
+ else:
51
+ # 29500 is torch.distributed default port
52
+ os.environ['MASTER_PORT'] = '29500'
53
+ os.environ['MASTER_ADDR'] = addr
54
+ os.environ['WORLD_SIZE'] = str(ntasks)
55
+ os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56
+ os.environ['RANK'] = str(proc_id)
57
+ dist.init_process_group(backend=backend)
58
+
59
+
60
+ def get_dist_info():
61
+ if dist.is_available():
62
+ initialized = dist.is_initialized()
63
+ else:
64
+ initialized = False
65
+ if initialized:
66
+ rank = dist.get_rank()
67
+ world_size = dist.get_world_size()
68
+ else:
69
+ rank = 0
70
+ world_size = 1
71
+ return rank, world_size
72
+
73
+
74
+ def master_only(func):
75
+
76
+ @functools.wraps(func)
77
+ def wrapper(*args, **kwargs):
78
+ rank, _ = get_dist_info()
79
+ if rank == 0:
80
+ return func(*args, **kwargs)
81
+
82
+ return wrapper
utils/file_client.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2
+ from abc import ABCMeta, abstractmethod
3
+
4
+
5
+ class BaseStorageBackend(metaclass=ABCMeta):
6
+ """Abstract class of storage backends.
7
+
8
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
9
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10
+ as texts.
11
+ """
12
+
13
+ @abstractmethod
14
+ def get(self, filepath):
15
+ pass
16
+
17
+ @abstractmethod
18
+ def get_text(self, filepath):
19
+ pass
20
+
21
+
22
+ class MemcachedBackend(BaseStorageBackend):
23
+ """Memcached storage backend.
24
+
25
+ Attributes:
26
+ server_list_cfg (str): Config file for memcached server list.
27
+ client_cfg (str): Config file for memcached client.
28
+ sys_path (str | None): Additional path to be appended to `sys.path`.
29
+ Default: None.
30
+ """
31
+
32
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33
+ if sys_path is not None:
34
+ import sys
35
+ sys.path.append(sys_path)
36
+ try:
37
+ import mc
38
+ except ImportError:
39
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
40
+
41
+ self.server_list_cfg = server_list_cfg
42
+ self.client_cfg = client_cfg
43
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44
+ # mc.pyvector servers as a point which points to a memory cache
45
+ self._mc_buffer = mc.pyvector()
46
+
47
+ def get(self, filepath):
48
+ filepath = str(filepath)
49
+ import mc
50
+ self._client.Get(filepath, self._mc_buffer)
51
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
52
+ return value_buf
53
+
54
+ def get_text(self, filepath):
55
+ raise NotImplementedError
56
+
57
+
58
+ class HardDiskBackend(BaseStorageBackend):
59
+ """Raw hard disks storage backend."""
60
+
61
+ def get(self, filepath):
62
+ filepath = str(filepath)
63
+ with open(filepath, 'rb') as f:
64
+ value_buf = f.read()
65
+ return value_buf
66
+
67
+ def get_text(self, filepath):
68
+ filepath = str(filepath)
69
+ with open(filepath, 'r') as f:
70
+ value_buf = f.read()
71
+ return value_buf
72
+
73
+
74
+ class LmdbBackend(BaseStorageBackend):
75
+ """Lmdb storage backend.
76
+
77
+ Args:
78
+ db_paths (str | list[str]): Lmdb database paths.
79
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80
+ readonly (bool, optional): Lmdb environment parameter. If True,
81
+ disallow any write operations. Default: True.
82
+ lock (bool, optional): Lmdb environment parameter. If False, when
83
+ concurrent access occurs, do not lock the database. Default: False.
84
+ readahead (bool, optional): Lmdb environment parameter. If False,
85
+ disable the OS filesystem readahead mechanism, which may improve
86
+ random read performance when a database is larger than RAM.
87
+ Default: False.
88
+
89
+ Attributes:
90
+ db_paths (list): Lmdb database path.
91
+ _client (list): A list of several lmdb envs.
92
+ """
93
+
94
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95
+ try:
96
+ import lmdb
97
+ except ImportError:
98
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
99
+
100
+ if isinstance(client_keys, str):
101
+ client_keys = [client_keys]
102
+
103
+ if isinstance(db_paths, list):
104
+ self.db_paths = [str(v) for v in db_paths]
105
+ elif isinstance(db_paths, str):
106
+ self.db_paths = [str(db_paths)]
107
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
109
+
110
+ self._client = {}
111
+ for client, path in zip(client_keys, self.db_paths):
112
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113
+
114
+ def get(self, filepath, client_key):
115
+ """Get values according to the filepath from one lmdb named client_key.
116
+
117
+ Args:
118
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119
+ client_key (str): Used for distinguishing different lmdb envs.
120
+ """
121
+ filepath = str(filepath)
122
+ assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
123
+ client = self._client[client_key]
124
+ with client.begin(write=False) as txn:
125
+ value_buf = txn.get(filepath.encode('ascii'))
126
+ return value_buf
127
+
128
+ def get_text(self, filepath):
129
+ raise NotImplementedError
130
+
131
+
132
+ class FileClient(object):
133
+ """A general file client to access files in different backend.
134
+
135
+ The client loads a file or text in a specified backend from its path
136
+ and return it as a binary file. it can also register other backend
137
+ accessor with a given name and backend class.
138
+
139
+ Attributes:
140
+ backend (str): The storage backend type. Options are "disk",
141
+ "memcached" and "lmdb".
142
+ client (:obj:`BaseStorageBackend`): The backend object.
143
+ """
144
+
145
+ _backends = {
146
+ 'disk': HardDiskBackend,
147
+ 'memcached': MemcachedBackend,
148
+ 'lmdb': LmdbBackend,
149
+ }
150
+
151
+ def __init__(self, backend='disk', **kwargs):
152
+ if backend not in self._backends:
153
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154
+ f' are {list(self._backends.keys())}')
155
+ self.backend = backend
156
+ self.client = self._backends[backend](**kwargs)
157
+
158
+ def get(self, filepath, client_key='default'):
159
+ # client_key is used only for lmdb, where different fileclients have
160
+ # different lmdb environments.
161
+ if self.backend == 'lmdb':
162
+ return self.client.get(filepath, client_key)
163
+ else:
164
+ return self.client.get(filepath)
165
+
166
+ def get_text(self, filepath):
167
+ return self.client.get_text(filepath)
utils/img_util.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import os
5
+ import torch
6
+ from torchvision.utils import make_grid
7
+
8
+
9
+ def img2tensor(imgs, bgr2rgb=True, float32=True):
10
+ """Numpy array to tensor.
11
+
12
+ Args:
13
+ imgs (list[ndarray] | ndarray): Input images.
14
+ bgr2rgb (bool): Whether to change bgr to rgb.
15
+ float32 (bool): Whether to change to float32.
16
+
17
+ Returns:
18
+ list[tensor] | tensor: Tensor images. If returned results only have
19
+ one element, just return tensor.
20
+ """
21
+
22
+ def _totensor(img, bgr2rgb, float32):
23
+ if img.shape[2] == 3 and bgr2rgb:
24
+ if img.dtype == 'float64':
25
+ img = img.astype('float32')
26
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27
+ img = torch.from_numpy(img.transpose(2, 0, 1))
28
+ if float32:
29
+ img = img.float()
30
+ return img
31
+
32
+ if isinstance(imgs, list):
33
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
34
+ else:
35
+ return _totensor(imgs, bgr2rgb, float32)
36
+
37
+
38
+ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39
+ """Convert torch Tensors into image numpy arrays.
40
+
41
+ After clamping to [min, max], values will be normalized to [0, 1].
42
+
43
+ Args:
44
+ tensor (Tensor or list[Tensor]): Accept shapes:
45
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46
+ 2) 3D Tensor of shape (3/1 x H x W);
47
+ 3) 2D Tensor of shape (H x W).
48
+ Tensor channel should be in RGB order.
49
+ rgb2bgr (bool): Whether to change rgb to bgr.
50
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
51
+ to uint8 type with range [0, 255]; otherwise, float type with
52
+ range [0, 1]. Default: ``np.uint8``.
53
+ min_max (tuple[int]): min and max values for clamp.
54
+
55
+ Returns:
56
+ (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57
+ shape (H x W). The channel order is BGR.
58
+ """
59
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61
+
62
+ if torch.is_tensor(tensor):
63
+ tensor = [tensor]
64
+ result = []
65
+ for _tensor in tensor:
66
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68
+
69
+ n_dim = _tensor.dim()
70
+ if n_dim == 4:
71
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72
+ img_np = img_np.transpose(1, 2, 0)
73
+ if rgb2bgr:
74
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75
+ elif n_dim == 3:
76
+ img_np = _tensor.numpy()
77
+ img_np = img_np.transpose(1, 2, 0)
78
+ if img_np.shape[2] == 1: # gray image
79
+ img_np = np.squeeze(img_np, axis=2)
80
+ else:
81
+ if rgb2bgr:
82
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83
+ elif n_dim == 2:
84
+ img_np = _tensor.numpy()
85
+ else:
86
+ raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87
+ if out_type == np.uint8:
88
+ # Unlike MATLAB, numpy.unit8() WILL NOT round by default.
89
+ img_np = (img_np * 255.0).round()
90
+ img_np = img_np.astype(out_type)
91
+ result.append(img_np)
92
+ if len(result) == 1:
93
+ result = result[0]
94
+ return result
95
+
96
+
97
+ def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98
+ """This implementation is slightly faster than tensor2img.
99
+ It now only supports torch tensor with shape (1, c, h, w).
100
+
101
+ Args:
102
+ tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103
+ rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104
+ min_max (tuple[int]): min and max values for clamp.
105
+ """
106
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108
+ output = output.type(torch.uint8).cpu().numpy()
109
+ if rgb2bgr:
110
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111
+ return output
112
+
113
+
114
+ def imfrombytes(content, flag='color', float32=False):
115
+ """Read an image from bytes.
116
+
117
+ Args:
118
+ content (bytes): Image bytes got from files or other streams.
119
+ flag (str): Flags specifying the color type of a loaded image,
120
+ candidates are `color`, `grayscale` and `unchanged`.
121
+ float32 (bool): Whether to change to float32., If True, will also norm
122
+ to [0, 1]. Default: False.
123
+
124
+ Returns:
125
+ ndarray: Loaded image array.
126
+ """
127
+ img_np = np.frombuffer(content, np.uint8)
128
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129
+ img = cv2.imdecode(img_np, imread_flags[flag])
130
+ if float32:
131
+ img = img.astype(np.float32) / 255.
132
+ return img
133
+
134
+
135
+ def imwrite(img, file_path, params=None, auto_mkdir=True):
136
+ """Write image to file.
137
+
138
+ Args:
139
+ img (ndarray): Image array to be written.
140
+ file_path (str): Image file path.
141
+ params (None or list): Same as opencv's :func:`imwrite` interface.
142
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143
+ whether to create it automatically.
144
+
145
+ Returns:
146
+ bool: Successful or not.
147
+ """
148
+ if auto_mkdir:
149
+ dir_name = os.path.abspath(os.path.dirname(file_path))
150
+ os.makedirs(dir_name, exist_ok=True)
151
+ ok = cv2.imwrite(file_path, img, params)
152
+ if not ok:
153
+ raise IOError('Failed in writing images.')
154
+
155
+
156
+ def crop_border(imgs, crop_border):
157
+ """Crop borders of images.
158
+
159
+ Args:
160
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161
+ crop_border (int): Crop border for each end of height and weight.
162
+
163
+ Returns:
164
+ list[ndarray]: Cropped images.
165
+ """
166
+ if crop_border == 0:
167
+ return imgs
168
+ else:
169
+ if isinstance(imgs, list):
170
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171
+ else:
172
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
utils/logger.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import time
4
+
5
+ from .dist_util import get_dist_info, master_only
6
+
7
+ initialized_logger = {}
8
+
9
+
10
+ class AvgTimer():
11
+
12
+ def __init__(self, window=200):
13
+ self.window = window # average window
14
+ self.current_time = 0
15
+ self.total_time = 0
16
+ self.count = 0
17
+ self.avg_time = 0
18
+ self.start()
19
+
20
+ def start(self):
21
+ self.start_time = self.tic = time.time()
22
+
23
+ def record(self):
24
+ self.count += 1
25
+ self.toc = time.time()
26
+ self.current_time = self.toc - self.tic
27
+ self.total_time += self.current_time
28
+ # calculate average time
29
+ self.avg_time = self.total_time / self.count
30
+
31
+ # reset
32
+ if self.count > self.window:
33
+ self.count = 0
34
+ self.total_time = 0
35
+
36
+ self.tic = time.time()
37
+
38
+ def get_current_time(self):
39
+ return self.current_time
40
+
41
+ def get_avg_time(self):
42
+ return self.avg_time
43
+
44
+
45
+ class MessageLogger():
46
+ """Message logger for printing.
47
+
48
+ Args:
49
+ opt (dict): Config. It contains the following keys:
50
+ name (str): Exp name.
51
+ logger (dict): Contains 'print_freq' (str) for logger interval.
52
+ train (dict): Contains 'total_iter' (int) for total iters.
53
+ use_tb_logger (bool): Use tensorboard logger.
54
+ start_iter (int): Start iter. Default: 1.
55
+ tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
56
+ """
57
+
58
+ def __init__(self, opt, start_iter=1, tb_logger=None):
59
+ self.exp_name = opt['name']
60
+ self.interval = opt['logger']['print_freq']
61
+ self.start_iter = start_iter
62
+ self.max_iters = opt['train']['total_iter']
63
+ self.use_tb_logger = opt['logger']['use_tb_logger']
64
+ self.tb_logger = tb_logger
65
+ self.start_time = time.time()
66
+ self.logger = get_root_logger()
67
+
68
+ def reset_start_time(self):
69
+ self.start_time = time.time()
70
+
71
+ @master_only
72
+ def __call__(self, log_vars):
73
+ """Format logging message.
74
+
75
+ Args:
76
+ log_vars (dict): It contains the following keys:
77
+ epoch (int): Epoch number.
78
+ iter (int): Current iter.
79
+ lrs (list): List for learning rates.
80
+
81
+ time (float): Iter time.
82
+ data_time (float): Data time for each iter.
83
+ """
84
+ # epoch, iter, learning rates
85
+ epoch = log_vars.pop('epoch')
86
+ current_iter = log_vars.pop('iter')
87
+ lrs = log_vars.pop('lrs')
88
+
89
+ message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
90
+ for v in lrs:
91
+ message += f'{v:.3e},'
92
+ message += ')] '
93
+
94
+ # time and estimated time
95
+ if 'time' in log_vars.keys():
96
+ iter_time = log_vars.pop('time')
97
+ data_time = log_vars.pop('data_time')
98
+
99
+ total_time = time.time() - self.start_time
100
+ time_sec_avg = total_time / (current_iter - self.start_iter + 1)
101
+ eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
102
+ eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
103
+ message += f'[eta: {eta_str}, '
104
+ message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
105
+
106
+ # other items, especially losses
107
+ for k, v in log_vars.items():
108
+ message += f'{k}: {v:.4e} '
109
+ # tensorboard logger
110
+ if self.use_tb_logger and 'debug' not in self.exp_name:
111
+ if k.startswith('l_'):
112
+ self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
113
+ else:
114
+ self.tb_logger.add_scalar(k, v, current_iter)
115
+ self.logger.info(message)
116
+
117
+
118
+ @master_only
119
+ def init_tb_logger(log_dir):
120
+ from torch.utils.tensorboard import SummaryWriter
121
+ tb_logger = SummaryWriter(log_dir=log_dir)
122
+ return tb_logger
123
+
124
+
125
+ @master_only
126
+ def init_wandb_logger(opt):
127
+ """We now only use wandb to sync tensorboard log."""
128
+ import wandb
129
+ logger = get_root_logger()
130
+
131
+ project = opt['logger']['wandb']['project']
132
+ resume_id = opt['logger']['wandb'].get('resume_id')
133
+ if resume_id:
134
+ wandb_id = resume_id
135
+ resume = 'allow'
136
+ logger.warning(f'Resume wandb logger with id={wandb_id}.')
137
+ else:
138
+ wandb_id = wandb.util.generate_id()
139
+ resume = 'never'
140
+
141
+ wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
142
+
143
+ logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
144
+
145
+
146
+ def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
147
+ """Get the root logger.
148
+
149
+ The logger will be initialized if it has not been initialized. By default a
150
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
151
+ also be added.
152
+
153
+ Args:
154
+ logger_name (str): root logger name. Default: 'basicsr'.
155
+ log_file (str | None): The log filename. If specified, a FileHandler
156
+ will be added to the root logger.
157
+ log_level (int): The root logger level. Note that only the process of
158
+ rank 0 is affected, while other processes will set the level to
159
+ "Error" and be silent most of the time.
160
+
161
+ Returns:
162
+ logging.Logger: The root logger.
163
+ """
164
+ logger = logging.getLogger(logger_name)
165
+ # if the logger has been initialized, just return it
166
+ if logger_name in initialized_logger:
167
+ return logger
168
+
169
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
170
+ stream_handler = logging.StreamHandler()
171
+ stream_handler.setFormatter(logging.Formatter(format_str))
172
+ logger.addHandler(stream_handler)
173
+ logger.propagate = False
174
+ rank, _ = get_dist_info()
175
+ if rank != 0:
176
+ logger.setLevel('ERROR')
177
+ elif log_file is not None:
178
+ logger.setLevel(log_level)
179
+ # add file handler
180
+ file_handler = logging.FileHandler(log_file, 'w')
181
+ file_handler.setFormatter(logging.Formatter(format_str))
182
+ file_handler.setLevel(log_level)
183
+ logger.addHandler(file_handler)
184
+ initialized_logger[logger_name] = True
185
+ return logger
186
+
187
+
188
+ def get_env_info():
189
+ """Get environment information.
190
+
191
+ Currently, only log the software version.
192
+ """
193
+ import torch
194
+ import torchvision
195
+
196
+ from basicsr.version import __version__
197
+ msg = r"""
198
+ ____ _ _____ ____
199
+ / __ ) ____ _ _____ (_)_____/ ___/ / __ \
200
+ / __ |/ __ `// ___// // ___/\__ \ / /_/ /
201
+ / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
202
+ /_____/ \__,_//____//_/ \___//____//_/ |_|
203
+ ______ __ __ __ __
204
+ / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
205
+ / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
206
+ / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
207
+ \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
208
+ """
209
+ msg += ('\nVersion Information: '
210
+ f'\n\tBasicSR: {__version__}'
211
+ f'\n\tPyTorch: {torch.__version__}'
212
+ f'\n\tTorchVision: {torchvision.__version__}')
213
+ return msg
utils/matlab_functions.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+
5
+
6
+ def cubic(x):
7
+ """cubic function used for calculate_weights_indices."""
8
+ absx = torch.abs(x)
9
+ absx2 = absx**2
10
+ absx3 = absx**3
11
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * (
12
+ (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
13
+ (absx <= 2)).type_as(absx))
14
+
15
+
16
+ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
17
+ """Calculate weights and indices, used for imresize function.
18
+
19
+ Args:
20
+ in_length (int): Input length.
21
+ out_length (int): Output length.
22
+ scale (float): Scale factor.
23
+ kernel_width (int): Kernel width.
24
+ antialisaing (bool): Whether to apply anti-aliasing when downsampling.
25
+ """
26
+
27
+ if (scale < 1) and antialiasing:
28
+ # Use a modified kernel (larger kernel width) to simultaneously
29
+ # interpolate and antialias
30
+ kernel_width = kernel_width / scale
31
+
32
+ # Output-space coordinates
33
+ x = torch.linspace(1, out_length, out_length)
34
+
35
+ # Input-space coordinates. Calculate the inverse mapping such that 0.5
36
+ # in output space maps to 0.5 in input space, and 0.5 + scale in output
37
+ # space maps to 1.5 in input space.
38
+ u = x / scale + 0.5 * (1 - 1 / scale)
39
+
40
+ # What is the left-most pixel that can be involved in the computation?
41
+ left = torch.floor(u - kernel_width / 2)
42
+
43
+ # What is the maximum number of pixels that can be involved in the
44
+ # computation? Note: it's OK to use an extra pixel here; if the
45
+ # corresponding weights are all zero, it will be eliminated at the end
46
+ # of this function.
47
+ p = math.ceil(kernel_width) + 2
48
+
49
+ # The indices of the input pixels involved in computing the k-th output
50
+ # pixel are in row k of the indices matrix.
51
+ indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
52
+ out_length, p)
53
+
54
+ # The weights used to compute the k-th output pixel are in row k of the
55
+ # weights matrix.
56
+ distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices
57
+
58
+ # apply cubic kernel
59
+ if (scale < 1) and antialiasing:
60
+ weights = scale * cubic(distance_to_center * scale)
61
+ else:
62
+ weights = cubic(distance_to_center)
63
+
64
+ # Normalize the weights matrix so that each row sums to 1.
65
+ weights_sum = torch.sum(weights, 1).view(out_length, 1)
66
+ weights = weights / weights_sum.expand(out_length, p)
67
+
68
+ # If a column in weights is all zero, get rid of it. only consider the
69
+ # first and last column.
70
+ weights_zero_tmp = torch.sum((weights == 0), 0)
71
+ if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
72
+ indices = indices.narrow(1, 1, p - 2)
73
+ weights = weights.narrow(1, 1, p - 2)
74
+ if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
75
+ indices = indices.narrow(1, 0, p - 2)
76
+ weights = weights.narrow(1, 0, p - 2)
77
+ weights = weights.contiguous()
78
+ indices = indices.contiguous()
79
+ sym_len_s = -indices.min() + 1
80
+ sym_len_e = indices.max() - in_length
81
+ indices = indices + sym_len_s - 1
82
+ return weights, indices, int(sym_len_s), int(sym_len_e)
83
+
84
+
85
+ @torch.no_grad()
86
+ def imresize(img, scale, antialiasing=True):
87
+ """imresize function same as MATLAB.
88
+
89
+ It now only supports bicubic.
90
+ The same scale applies for both height and width.
91
+
92
+ Args:
93
+ img (Tensor | Numpy array):
94
+ Tensor: Input image with shape (c, h, w), [0, 1] range.
95
+ Numpy: Input image with shape (h, w, c), [0, 1] range.
96
+ scale (float): Scale factor. The same scale applies for both height
97
+ and width.
98
+ antialisaing (bool): Whether to apply anti-aliasing when downsampling.
99
+ Default: True.
100
+
101
+ Returns:
102
+ Tensor: Output image with shape (c, h, w), [0, 1] range, w/o round.
103
+ """
104
+ squeeze_flag = False
105
+ if type(img).__module__ == np.__name__: # numpy type
106
+ numpy_type = True
107
+ if img.ndim == 2:
108
+ img = img[:, :, None]
109
+ squeeze_flag = True
110
+ img = torch.from_numpy(img.transpose(2, 0, 1)).float()
111
+ else:
112
+ numpy_type = False
113
+ if img.ndim == 2:
114
+ img = img.unsqueeze(0)
115
+ squeeze_flag = True
116
+
117
+ in_c, in_h, in_w = img.size()
118
+ out_h, out_w = math.ceil(in_h * scale), math.ceil(in_w * scale)
119
+ kernel_width = 4
120
+ kernel = 'cubic'
121
+
122
+ # get weights and indices
123
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
124
+ antialiasing)
125
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
126
+ antialiasing)
127
+ # process H dimension
128
+ # symmetric copying
129
+ img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
130
+ img_aug.narrow(1, sym_len_hs, in_h).copy_(img)
131
+
132
+ sym_patch = img[:, :sym_len_hs, :]
133
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
134
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
135
+ img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)
136
+
137
+ sym_patch = img[:, -sym_len_he:, :]
138
+ inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
139
+ sym_patch_inv = sym_patch.index_select(1, inv_idx)
140
+ img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)
141
+
142
+ out_1 = torch.FloatTensor(in_c, out_h, in_w)
143
+ kernel_width = weights_h.size(1)
144
+ for i in range(out_h):
145
+ idx = int(indices_h[i][0])
146
+ for j in range(in_c):
147
+ out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
148
+
149
+ # process W dimension
150
+ # symmetric copying
151
+ out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
152
+ out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)
153
+
154
+ sym_patch = out_1[:, :, :sym_len_ws]
155
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
156
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
157
+ out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)
158
+
159
+ sym_patch = out_1[:, :, -sym_len_we:]
160
+ inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
161
+ sym_patch_inv = sym_patch.index_select(2, inv_idx)
162
+ out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)
163
+
164
+ out_2 = torch.FloatTensor(in_c, out_h, out_w)
165
+ kernel_width = weights_w.size(1)
166
+ for i in range(out_w):
167
+ idx = int(indices_w[i][0])
168
+ for j in range(in_c):
169
+ out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
170
+
171
+ if squeeze_flag:
172
+ out_2 = out_2.squeeze(0)
173
+ if numpy_type:
174
+ out_2 = out_2.numpy()
175
+ if not squeeze_flag:
176
+ out_2 = out_2.transpose(1, 2, 0)
177
+
178
+ return out_2
179
+
180
+
181
+ def rgb2ycbcr(img, y_only=False):
182
+ """Convert a RGB image to YCbCr image.
183
+
184
+ This function produces the same results as Matlab's `rgb2ycbcr` function.
185
+ It implements the ITU-R BT.601 conversion for standard-definition
186
+ television. See more details in
187
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
188
+
189
+ It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`.
190
+ In OpenCV, it implements a JPEG conversion. See more details in
191
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
192
+
193
+ Args:
194
+ img (ndarray): The input image. It accepts:
195
+ 1. np.uint8 type with range [0, 255];
196
+ 2. np.float32 type with range [0, 1].
197
+ y_only (bool): Whether to only return Y channel. Default: False.
198
+
199
+ Returns:
200
+ ndarray: The converted YCbCr image. The output image has the same type
201
+ and range as input image.
202
+ """
203
+ img_type = img.dtype
204
+ img = _convert_input_type_range(img)
205
+ if y_only:
206
+ out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
207
+ else:
208
+ out_img = np.matmul(
209
+ img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
210
+ out_img = _convert_output_type_range(out_img, img_type)
211
+ return out_img
212
+
213
+
214
+ def bgr2ycbcr(img, y_only=False):
215
+ """Convert a BGR image to YCbCr image.
216
+
217
+ The bgr version of rgb2ycbcr.
218
+ It implements the ITU-R BT.601 conversion for standard-definition
219
+ television. See more details in
220
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
221
+
222
+ It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`.
223
+ In OpenCV, it implements a JPEG conversion. See more details in
224
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
225
+
226
+ Args:
227
+ img (ndarray): The input image. It accepts:
228
+ 1. np.uint8 type with range [0, 255];
229
+ 2. np.float32 type with range [0, 1].
230
+ y_only (bool): Whether to only return Y channel. Default: False.
231
+
232
+ Returns:
233
+ ndarray: The converted YCbCr image. The output image has the same type
234
+ and range as input image.
235
+ """
236
+ img_type = img.dtype
237
+ img = _convert_input_type_range(img)
238
+ if y_only:
239
+ out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
240
+ else:
241
+ out_img = np.matmul(
242
+ img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
243
+ out_img = _convert_output_type_range(out_img, img_type)
244
+ return out_img
245
+
246
+
247
+ def ycbcr2rgb(img):
248
+ """Convert a YCbCr image to RGB image.
249
+
250
+ This function produces the same results as Matlab's ycbcr2rgb function.
251
+ It implements the ITU-R BT.601 conversion for standard-definition
252
+ television. See more details in
253
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
254
+
255
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`.
256
+ In OpenCV, it implements a JPEG conversion. See more details in
257
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
258
+
259
+ Args:
260
+ img (ndarray): The input image. It accepts:
261
+ 1. np.uint8 type with range [0, 255];
262
+ 2. np.float32 type with range [0, 1].
263
+
264
+ Returns:
265
+ ndarray: The converted RGB image. The output image has the same type
266
+ and range as input image.
267
+ """
268
+ img_type = img.dtype
269
+ img = _convert_input_type_range(img) * 255
270
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
271
+ [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
272
+ out_img = _convert_output_type_range(out_img, img_type)
273
+ return out_img
274
+
275
+
276
+ def ycbcr2bgr(img):
277
+ """Convert a YCbCr image to BGR image.
278
+
279
+ The bgr version of ycbcr2rgb.
280
+ It implements the ITU-R BT.601 conversion for standard-definition
281
+ television. See more details in
282
+ https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion.
283
+
284
+ It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`.
285
+ In OpenCV, it implements a JPEG conversion. See more details in
286
+ https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion.
287
+
288
+ Args:
289
+ img (ndarray): The input image. It accepts:
290
+ 1. np.uint8 type with range [0, 255];
291
+ 2. np.float32 type with range [0, 1].
292
+
293
+ Returns:
294
+ ndarray: The converted BGR image. The output image has the same type
295
+ and range as input image.
296
+ """
297
+ img_type = img.dtype
298
+ img = _convert_input_type_range(img) * 255
299
+ out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
300
+ [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
301
+ out_img = _convert_output_type_range(out_img, img_type)
302
+ return out_img
303
+
304
+
305
+ def _convert_input_type_range(img):
306
+ """Convert the type and range of the input image.
307
+
308
+ It converts the input image to np.float32 type and range of [0, 1].
309
+ It is mainly used for pre-processing the input image in colorspace
310
+ conversion functions such as rgb2ycbcr and ycbcr2rgb.
311
+
312
+ Args:
313
+ img (ndarray): The input image. It accepts:
314
+ 1. np.uint8 type with range [0, 255];
315
+ 2. np.float32 type with range [0, 1].
316
+
317
+ Returns:
318
+ (ndarray): The converted image with type of np.float32 and range of
319
+ [0, 1].
320
+ """
321
+ img_type = img.dtype
322
+ img = img.astype(np.float32)
323
+ if img_type == np.float32:
324
+ pass
325
+ elif img_type == np.uint8:
326
+ img /= 255.
327
+ else:
328
+ raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
329
+ return img
330
+
331
+
332
+ def _convert_output_type_range(img, dst_type):
333
+ """Convert the type and range of the image according to dst_type.
334
+
335
+ It converts the image to desired type and range. If `dst_type` is np.uint8,
336
+ images will be converted to np.uint8 type with range [0, 255]. If
337
+ `dst_type` is np.float32, it converts the image to np.float32 type with
338
+ range [0, 1].
339
+ It is mainly used for post-processing images in colorspace conversion
340
+ functions such as rgb2ycbcr and ycbcr2rgb.
341
+
342
+ Args:
343
+ img (ndarray): The image to be converted with np.float32 type and
344
+ range [0, 255].
345
+ dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it
346
+ converts the image to np.uint8 type with range [0, 255]. If
347
+ dst_type is np.float32, it converts the image to np.float32 type
348
+ with range [0, 1].
349
+
350
+ Returns:
351
+ (ndarray): The converted image with desired type and range.
352
+ """
353
+ if dst_type not in (np.uint8, np.float32):
354
+ raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}')
355
+ if dst_type == np.uint8:
356
+ img = img.round()
357
+ else:
358
+ img /= 255.
359
+ return img.astype(dst_type)
utils/misc.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+ import random
4
+ import time
5
+ import torch
6
+ from os import path as osp
7
+
8
+ from .dist_util import master_only
9
+
10
+
11
+ def set_random_seed(seed):
12
+ """Set random seeds."""
13
+ random.seed(seed)
14
+ np.random.seed(seed)
15
+ torch.manual_seed(seed)
16
+ torch.cuda.manual_seed(seed)
17
+ torch.cuda.manual_seed_all(seed)
18
+
19
+
20
+ def get_time_str():
21
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
22
+
23
+
24
+ def mkdir_and_rename(path):
25
+ """mkdirs. If path exists, rename it with timestamp and create a new one.
26
+
27
+ Args:
28
+ path (str): Folder path.
29
+ """
30
+ if osp.exists(path):
31
+ new_name = path + '_archived_' + get_time_str()
32
+ print(f'Path already exists. Rename it to {new_name}', flush=True)
33
+ os.rename(path, new_name)
34
+ os.makedirs(path, exist_ok=True)
35
+
36
+
37
+ @master_only
38
+ def make_exp_dirs(opt):
39
+ """Make dirs for experiments."""
40
+ path_opt = opt['path'].copy()
41
+ if opt['is_train']:
42
+ mkdir_and_rename(path_opt.pop('experiments_root'))
43
+ else:
44
+ mkdir_and_rename(path_opt.pop('results_root'))
45
+ for key, path in path_opt.items():
46
+ if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key):
47
+ continue
48
+ else:
49
+ os.makedirs(path, exist_ok=True)
50
+
51
+
52
+ def scandir(dir_path, suffix=None, recursive=False, full_path=False):
53
+ """Scan a directory to find the interested files.
54
+
55
+ Args:
56
+ dir_path (str): Path of the directory.
57
+ suffix (str | tuple(str), optional): File suffix that we are
58
+ interested in. Default: None.
59
+ recursive (bool, optional): If set to True, recursively scan the
60
+ directory. Default: False.
61
+ full_path (bool, optional): If set to True, include the dir_path.
62
+ Default: False.
63
+
64
+ Returns:
65
+ A generator for all the interested files with relative paths.
66
+ """
67
+
68
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
69
+ raise TypeError('"suffix" must be a string or tuple of strings')
70
+
71
+ root = dir_path
72
+
73
+ def _scandir(dir_path, suffix, recursive):
74
+ for entry in os.scandir(dir_path):
75
+ if not entry.name.startswith('.') and entry.is_file():
76
+ if full_path:
77
+ return_path = entry.path
78
+ else:
79
+ return_path = osp.relpath(entry.path, root)
80
+
81
+ if suffix is None:
82
+ yield return_path
83
+ elif return_path.endswith(suffix):
84
+ yield return_path
85
+ else:
86
+ if recursive:
87
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
88
+ else:
89
+ continue
90
+
91
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
92
+
93
+
94
+ def check_resume(opt, resume_iter):
95
+ """Check resume states and pretrain_network paths.
96
+
97
+ Args:
98
+ opt (dict): Options.
99
+ resume_iter (int): Resume iteration.
100
+ """
101
+ if opt['path']['resume_state']:
102
+ # get all the networks
103
+ networks = [key for key in opt.keys() if key.startswith('network_')]
104
+ flag_pretrain = False
105
+ for network in networks:
106
+ if opt['path'].get(f'pretrain_{network}') is not None:
107
+ flag_pretrain = True
108
+ if flag_pretrain:
109
+ print('pretrain_network path will be ignored during resuming.')
110
+ # set pretrained model paths
111
+ for network in networks:
112
+ name = f'pretrain_{network}'
113
+ basename = network.replace('network_', '')
114
+ if opt['path'].get('ignore_resume_networks') is None or (network
115
+ not in opt['path']['ignore_resume_networks']):
116
+ opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
117
+ print(f"Set {name} to {opt['path'][name]}")
118
+
119
+ # change param_key to params in resume
120
+ param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
121
+ for param_key in param_keys:
122
+ if opt['path'][param_key] == 'params_ema':
123
+ opt['path'][param_key] = 'params'
124
+ print(f'Set {param_key} to params')
125
+
126
+
127
+ def sizeof_fmt(size, suffix='B'):
128
+ """Get human readable file size.
129
+
130
+ Args:
131
+ size (int): File size.
132
+ suffix (str): Suffix. Default: 'B'.
133
+
134
+ Return:
135
+ str: Formatted file siz.
136
+ """
137
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
138
+ if abs(size) < 1024.0:
139
+ return f'{size:3.1f} {unit}{suffix}'
140
+ size /= 1024.0
141
+ return f'{size:3.1f} Y{suffix}'
utils/options.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import random
3
+ import torch
4
+ import yaml
5
+ from collections import OrderedDict
6
+ from os import path as osp
7
+
8
+ from basicsr.utils import set_random_seed
9
+ from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
10
+
11
+
12
+ def ordered_yaml():
13
+ """Support OrderedDict for yaml.
14
+
15
+ Returns:
16
+ yaml Loader and Dumper.
17
+ """
18
+ try:
19
+ from yaml import CDumper as Dumper
20
+ from yaml import CLoader as Loader
21
+ except ImportError:
22
+ from yaml import Dumper, Loader
23
+
24
+ _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25
+
26
+ def dict_representer(dumper, data):
27
+ return dumper.represent_dict(data.items())
28
+
29
+ def dict_constructor(loader, node):
30
+ return OrderedDict(loader.construct_pairs(node))
31
+
32
+ Dumper.add_representer(OrderedDict, dict_representer)
33
+ Loader.add_constructor(_mapping_tag, dict_constructor)
34
+ return Loader, Dumper
35
+
36
+
37
+ def dict2str(opt, indent_level=1):
38
+ """dict to string for printing options.
39
+
40
+ Args:
41
+ opt (dict): Option dict.
42
+ indent_level (int): Indent level. Default: 1.
43
+
44
+ Return:
45
+ (str): Option string for printing.
46
+ """
47
+ msg = '\n'
48
+ for k, v in opt.items():
49
+ if isinstance(v, dict):
50
+ msg += ' ' * (indent_level * 2) + k + ':['
51
+ msg += dict2str(v, indent_level + 1)
52
+ msg += ' ' * (indent_level * 2) + ']\n'
53
+ else:
54
+ msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
55
+ return msg
56
+
57
+
58
+ def _postprocess_yml_value(value):
59
+ # None
60
+ if value == '~' or value.lower() == 'none':
61
+ return None
62
+ # bool
63
+ if value.lower() == 'true':
64
+ return True
65
+ elif value.lower() == 'false':
66
+ return False
67
+ # !!float number
68
+ if value.startswith('!!float'):
69
+ return float(value.replace('!!float', ''))
70
+ # number
71
+ if value.isdigit():
72
+ return int(value)
73
+ elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
74
+ return float(value)
75
+ # list
76
+ if value.startswith('['):
77
+ return eval(value)
78
+ # str
79
+ return value
80
+
81
+
82
+ def parse_options(root_path, is_train=True):
83
+ parser = argparse.ArgumentParser()
84
+ parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
85
+ parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
86
+ parser.add_argument('--auto_resume', action='store_true')
87
+ parser.add_argument('--debug', action='store_true')
88
+ parser.add_argument('--local_rank', type=int, default=0)
89
+ parser.add_argument(
90
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
91
+ args = parser.parse_args()
92
+
93
+ # parse yml to dict
94
+ with open(args.opt, mode='r') as f:
95
+ opt = yaml.load(f, Loader=ordered_yaml()[0])
96
+
97
+ # distributed settings
98
+ if args.launcher == 'none':
99
+ opt['dist'] = False
100
+ print('Disable distributed.', flush=True)
101
+ else:
102
+ opt['dist'] = True
103
+ if args.launcher == 'slurm' and 'dist_params' in opt:
104
+ init_dist(args.launcher, **opt['dist_params'])
105
+ else:
106
+ init_dist(args.launcher)
107
+ opt['rank'], opt['world_size'] = get_dist_info()
108
+
109
+ # random seed
110
+ seed = opt.get('manual_seed')
111
+ if seed is None:
112
+ seed = random.randint(1, 10000)
113
+ opt['manual_seed'] = seed
114
+ set_random_seed(seed + opt['rank'])
115
+
116
+ # force to update yml options
117
+ if args.force_yml is not None:
118
+ for entry in args.force_yml:
119
+ # now do not support creating new keys
120
+ keys, value = entry.split('=')
121
+ keys, value = keys.strip(), value.strip()
122
+ value = _postprocess_yml_value(value)
123
+ eval_str = 'opt'
124
+ for key in keys.split(':'):
125
+ eval_str += f'["{key}"]'
126
+ eval_str += '=value'
127
+ # using exec function
128
+ exec(eval_str)
129
+
130
+ opt['auto_resume'] = args.auto_resume
131
+ opt['is_train'] = is_train
132
+
133
+ # debug setting
134
+ if args.debug and not opt['name'].startswith('debug'):
135
+ opt['name'] = 'debug_' + opt['name']
136
+
137
+ if opt['num_gpu'] == 'auto':
138
+ opt['num_gpu'] = torch.cuda.device_count()
139
+
140
+ # datasets
141
+ for phase, dataset in opt['datasets'].items():
142
+ # for multiple datasets, e.g., val_1, val_2; test_1, test_2
143
+ phase = phase.split('_')[0]
144
+ dataset['phase'] = phase
145
+ if 'scale' in opt:
146
+ dataset['scale'] = opt['scale']
147
+ if dataset.get('dataroot_gt') is not None:
148
+ dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
149
+ if dataset.get('dataroot_lq') is not None:
150
+ dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
151
+
152
+ # paths
153
+ for key, val in opt['path'].items():
154
+ if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
155
+ opt['path'][key] = osp.expanduser(val)
156
+
157
+ if is_train:
158
+ experiments_root = osp.join(root_path, 'experiments', opt['name'])
159
+ opt['path']['experiments_root'] = experiments_root
160
+ opt['path']['models'] = osp.join(experiments_root, 'models')
161
+ opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
162
+ opt['path']['log'] = experiments_root
163
+ opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
164
+
165
+ # change some options for debug mode
166
+ if 'debug' in opt['name']:
167
+ if 'val' in opt:
168
+ opt['val']['val_freq'] = 8
169
+ opt['logger']['print_freq'] = 1
170
+ opt['logger']['save_checkpoint_freq'] = 8
171
+ else: # test
172
+ results_root = osp.join(root_path, 'results', opt['name'])
173
+ opt['path']['results_root'] = results_root
174
+ opt['path']['log'] = results_root
175
+ opt['path']['visualization'] = osp.join(results_root, 'visualization')
176
+
177
+ return opt, args
178
+
179
+
180
+ @master_only
181
+ def copy_opt_file(opt_file, experiments_root):
182
+ # copy the yml file to the experiment root
183
+ import sys
184
+ import time
185
+ from shutil import copyfile
186
+ cmd = ' '.join(sys.argv)
187
+ filename = osp.join(experiments_root, osp.basename(opt_file))
188
+ copyfile(opt_file, filename)
189
+
190
+ with open(filename, 'r+') as f:
191
+ lines = f.readlines()
192
+ lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
193
+ f.seek(0)
194
+ f.writelines(lines)
utils/registry.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2
+
3
+
4
+ class Registry():
5
+ """
6
+ The registry that provides name -> object mapping, to support third-party
7
+ users' custom modules.
8
+
9
+ To create a registry (e.g. a backbone registry):
10
+
11
+ .. code-block:: python
12
+
13
+ BACKBONE_REGISTRY = Registry('BACKBONE')
14
+
15
+ To register an object:
16
+
17
+ .. code-block:: python
18
+
19
+ @BACKBONE_REGISTRY.register()
20
+ class MyBackbone():
21
+ ...
22
+
23
+ Or:
24
+
25
+ .. code-block:: python
26
+
27
+ BACKBONE_REGISTRY.register(MyBackbone)
28
+ """
29
+
30
+ def __init__(self, name):
31
+ """
32
+ Args:
33
+ name (str): the name of this registry
34
+ """
35
+ self._name = name
36
+ self._obj_map = {}
37
+
38
+ def _do_register(self, name, obj):
39
+ assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
40
+ f"in '{self._name}' registry!")
41
+ self._obj_map[name] = obj
42
+
43
+ def register(self, obj=None):
44
+ """
45
+ Register the given object under the the name `obj.__name__`.
46
+ Can be used as either a decorator or not.
47
+ See docstring of this class for usage.
48
+ """
49
+ if obj is None:
50
+ # used as a decorator
51
+ def deco(func_or_class):
52
+ name = func_or_class.__name__
53
+ self._do_register(name, func_or_class)
54
+ return func_or_class
55
+
56
+ return deco
57
+
58
+ # used as a function call
59
+ name = obj.__name__
60
+ self._do_register(name, obj)
61
+
62
+ def get(self, name):
63
+ ret = self._obj_map.get(name)
64
+ if ret is None:
65
+ raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
66
+ return ret
67
+
68
+ def __contains__(self, name):
69
+ return name in self._obj_map
70
+
71
+ def __iter__(self):
72
+ return iter(self._obj_map.items())
73
+
74
+ def keys(self):
75
+ return self._obj_map.keys()
76
+
77
+
78
+ DATASET_REGISTRY = Registry('dataset')
79
+ ARCH_REGISTRY = Registry('arch')
80
+ MODEL_REGISTRY = Registry('model')
81
+ LOSS_REGISTRY = Registry('loss')
82
+ METRIC_REGISTRY = Registry('metric')