XiangZ committed on
Commit 6da1fbb · verified · 1 Parent(s): 1fe87b7

Update hit_srf_arch.py

Files changed (1)
  1. hit_srf_arch.py +945 -947
hit_srf_arch.py CHANGED
@@ -1,947 +1,945 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
- import torch.utils.checkpoint as checkpoint
6
- from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
-
8
- import numpy as np
9
- from huggingface_hub import PyTorchModelHubMixin
10
- from utils import FileClient, imfrombytes, img2tensor, tensor2img
11
-
12
- class DFE(nn.Module):
13
- """ Dual Feature Extraction
14
- Args:
15
- in_features (int): Number of input channels.
16
- out_features (int): Number of output channels.
17
- """
18
- def __init__(self, in_features, out_features):
19
- super().__init__()
20
-
21
- self.out_features = out_features
22
-
23
- self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
24
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
25
- nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
26
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
27
- nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
28
-
29
- self.linear = nn.Conv2d(in_features, out_features,1,1,0)
30
-
31
- def forward(self, x, x_size):
32
-
33
- B, L, C = x.shape
34
- H, W = x_size
35
- x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
36
- x = self.conv(x) * self.linear(x)
37
- x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
38
-
39
- return x
40
-
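# Usage sketch (illustrative sizes, assumed rather than taken from a training
# config): DFE consumes token features of shape (B, H*W, C) plus the spatial
# size and returns gated features of shape (B, H*W, out_features).
def _dfe_shape_demo():
    dfe = DFE(in_features=60, out_features=60)
    tokens = torch.randn(2, 16 * 16, 60)      # B=2, H=W=16, C=60
    out = dfe(tokens, (16, 16))                # conv branch gated by the linear branch
    print(out.shape)                           # expected: torch.Size([2, 256, 60])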
41
- class Mlp(nn.Module):
42
- """ MLP-based Feed-Forward Network
43
- Args:
44
- in_features (int): Number of input channels.
45
- hidden_features (int | None): Number of hidden channels. Default: None
46
- out_features (int | None): Number of output channels. Default: None
47
- act_layer (nn.Module): Activation layer. Default: nn.GELU
48
- drop (float): Dropout rate. Default: 0.0
49
- """
50
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
- super().__init__()
52
- out_features = out_features or in_features
53
- hidden_features = hidden_features or in_features
54
- self.fc1 = nn.Linear(in_features, hidden_features)
55
- self.act = act_layer()
56
- self.fc2 = nn.Linear(hidden_features, out_features)
57
- self.drop = nn.Dropout(drop)
58
-
59
- def forward(self, x):
60
- x = self.fc1(x)
61
- x = self.act(x)
62
- x = self.drop(x)
63
- x = self.fc2(x)
64
- x = self.drop(x)
65
- return x
66
-
67
-
68
- class dwconv(nn.Module):
69
- def __init__(self,hidden_features):
70
- super(dwconv, self).__init__()
71
- self.depthwise_conv = nn.Sequential(
72
- nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2, dilation=1,
73
- groups=hidden_features), nn.GELU())
74
- self.hidden_features = hidden_features
75
- def forward(self,x,x_size):
76
- x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous() # b c Ph Pw
77
- x = self.depthwise_conv(x)
78
- x = x.flatten(2).transpose(1, 2).contiguous()
79
- return x
80
-
81
- class ConvFFN(nn.Module):
82
-
83
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
84
- super().__init__()
85
- out_features = out_features or in_features
86
- hidden_features = hidden_features or in_features
87
- self.fc1 = nn.Linear(in_features, hidden_features)
88
- self.act = act_layer()
89
- self.dwconv = dwconv(hidden_features=hidden_features)
90
- self.fc2 = nn.Linear(hidden_features, out_features)
91
- self.drop = nn.Dropout(drop)
92
-
93
-
94
- def forward(self, x,x_size):
95
- x = self.fc1(x)
96
- x = self.act(x)
97
- x = x + self.dwconv(x,x_size)
98
- x = self.drop(x)
99
- x = self.fc2(x)
100
- x = self.drop(x)
101
- return x
102
-
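# Usage sketch (illustrative sizes, assumed): ConvFFN keeps the token layout
# (B, H*W, C); x_size is only needed to reshape internally for the 5x5
# depth-wise convolution added on top of the plain MLP path.
def _convffn_shape_demo():
    ffn = ConvFFN(in_features=60, hidden_features=120)
    tokens = torch.randn(2, 16 * 16, 60)
    out = ffn(tokens, (16, 16))
    print(out.shape)                           # expected: torch.Size([2, 256, 60])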
103
- def window_partition(x, window_size):
104
- """
105
- Args:
106
- x: (B, H, W, C)
107
- window_size (tuple): window size
108
-
109
- Returns:
110
- windows: (num_windows*B, window_size, window_size, C)
111
- """
112
- B, H, W, C = x.shape
113
- x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
114
- windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
115
- return windows
116
-
117
-
118
- def window_reverse(windows, window_size, H, W):
119
- """
120
- Args:
121
- windows: (num_windows*B, window_size, window_size, C)
122
- window_size (tuple): Window size
123
- H (int): Height of image
124
- W (int): Width of image
125
-
126
- Returns:
127
- x: (B, H, W, C)
128
- """
129
- B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
130
- x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
131
- x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
132
- return x
133
-
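# Round-trip sketch (illustrative sizes, assumed): window_partition splits a
# (B, H, W, C) map into (num_windows*B, wh, ww, C) windows and window_reverse
# undoes it, provided H and W are divisible by the window size.
def _window_roundtrip_demo():
    x = torch.randn(2, 16, 16, 60)
    windows = window_partition(x, (8, 8))      # -> (2*2*2, 8, 8, 60) = (8, 8, 8, 60)
    restored = window_reverse(windows, (8, 8), 16, 16)
    print(torch.equal(restored, x))            # expected: True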
134
- class DynamicPosBias(nn.Module):
135
- # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
136
- """ Dynamic Relative Position Bias.
137
- Args:
138
- dim (int): Number of input channels.
139
- num_heads (int): Number of heads for spatial self-correlation.
140
- residual (bool): If True, use a residual strategy to connect the convolutions.
141
- """
142
- def __init__(self, dim, num_heads, residual):
143
- super().__init__()
144
- self.residual = residual
145
- self.num_heads = num_heads
146
- self.pos_dim = dim // 4
147
- self.pos_proj = nn.Linear(2, self.pos_dim)
148
- self.pos1 = nn.Sequential(
149
- nn.LayerNorm(self.pos_dim),
150
- nn.ReLU(inplace=True),
151
- nn.Linear(self.pos_dim, self.pos_dim),
152
- )
153
- self.pos2 = nn.Sequential(
154
- nn.LayerNorm(self.pos_dim),
155
- nn.ReLU(inplace=True),
156
- nn.Linear(self.pos_dim, self.pos_dim)
157
- )
158
- self.pos3 = nn.Sequential(
159
- nn.LayerNorm(self.pos_dim),
160
- nn.ReLU(inplace=True),
161
- nn.Linear(self.pos_dim, self.num_heads)
162
- )
163
- def forward(self, biases):
164
- if self.residual:
165
- pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
166
- pos = pos + self.pos1(pos)
167
- pos = pos + self.pos2(pos)
168
- pos = self.pos3(pos)
169
- else:
170
- pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
171
- return pos
172
-
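# Sketch (assumed 8x8 window and 6 heads): DynamicPosBias maps the
# (2*Wh-1)*(2*Ww-1) relative offsets, given as float (dy, dx) pairs, to a
# per-head bias table, mirroring how SCC below builds its "mother-set".
def _dynamic_pos_bias_demo():
    wh, ww = 8, 8
    ph = torch.arange(1 - wh, wh).float()
    pw = torch.arange(1 - ww, ww).float()
    biases = torch.stack(torch.meshgrid([ph, pw])).flatten(1).transpose(0, 1)  # (225, 2)
    dpb = DynamicPosBias(60 // 4, num_heads=6, residual=False)
    print(dpb(biases).shape)                   # expected: torch.Size([225, 6])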
173
- class SCC(nn.Module):
174
- """ Spatial-Channel Correlation.
175
- Args:
176
- dim (int): Number of input channels.
177
- base_win_size (tuple[int]): The height and width of the base window.
178
- window_size (tuple[int]): The height and width of the window.
179
- num_heads (int): Number of heads for spatial self-correlation.
180
- value_drop (float, optional): Dropout ratio of value. Default: 0.0
181
- proj_drop (float, optional): Dropout ratio of output. Default: 0.0
182
- """
183
-
184
- def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
185
-
186
- super().__init__()
187
- # parameters
188
- self.dim = dim
189
- self.window_size = window_size
190
- self.num_heads = num_heads
191
-
192
- # feature projection
193
- self.qv = DFE(dim, dim)
194
- self.proj = nn.Linear(dim, dim)
195
-
196
- # dropout
197
- self.value_drop = nn.Dropout(value_drop)
198
- self.proj_drop = nn.Dropout(proj_drop)
199
-
200
- # base window size
201
- min_h = min(self.window_size[0], base_win_size[0])
202
- min_w = min(self.window_size[1], base_win_size[1])
203
- self.base_win_size = (min_h, min_w)
204
-
205
- # normalization factor and spatial linear layer for S-SC
206
- head_dim = dim // (2*num_heads)
207
- self.scale = head_dim
208
- self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
209
-
210
- # define a parameter table of relative position bias
211
- self.H_sp, self.W_sp = self.window_size
212
- self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
213
-
214
- def spatial_linear_projection(self, x):
215
- B, num_h, L, C = x.shape
216
- H, W = self.window_size
217
- map_H, map_W = self.base_win_size
218
-
219
- x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
220
- x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
221
- return x
222
-
223
- def spatial_self_correlation(self, q, v):
224
-
225
- B, num_head, L, C = q.shape
226
-
227
- # spatial projection
228
- v = self.spatial_linear_projection(v)
229
-
230
- # compute correlation map
231
- corr_map = (q @ v.transpose(-2,-1)) / self.scale
232
-
233
- # add relative position bias
234
- # generate mother-set
235
- position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
236
- position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
237
- biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
238
- rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
239
- pos = self.pos(rpe_biases)
240
-
241
- # select position bias
242
- coords_h = torch.arange(self.H_sp, device=v.device)
243
- coords_w = torch.arange(self.W_sp, device=v.device)
244
- coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
245
- coords_flatten = torch.flatten(coords, 1)
246
- relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
247
- relative_coords = relative_coords.permute(1, 2, 0).contiguous()
248
- relative_coords[:, :, 0] += self.H_sp - 1
249
- relative_coords[:, :, 1] += self.W_sp - 1
250
- relative_coords[:, :, 0] *= 2 * self.W_sp - 1
251
- relative_position_index = relative_coords.sum(-1)
252
- relative_position_bias = pos[relative_position_index.view(-1)].view(
253
- self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
254
- relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
255
- self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
256
- relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
257
- corr_map = corr_map + relative_position_bias.unsqueeze(0)
258
-
259
- # transformation
260
- v_drop = self.value_drop(v)
261
- x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
262
-
263
- return x
264
-
265
- def channel_self_correlation(self, q, v):
266
-
267
- B, num_head, L, C = q.shape
268
-
269
- # apply single head strategy
270
- q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
271
- v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
272
-
273
- # compute correlation map
274
- corr_map = (q.transpose(-2,-1) @ v) / L
275
-
276
- # transformation
277
- v_drop = self.value_drop(v)
278
- x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
279
-
280
- return x
281
-
282
- def forward(self, x):
283
- """
284
- Args:
285
- x: input features with shape of (B, H, W, C)
286
- """
287
- xB,xH,xW,xC = x.shape
288
- qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
289
-
290
- # window partition
291
- qv = window_partition(qv, self.window_size)
292
- qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
293
-
294
- # qv splitting
295
- B, L, C = qv.shape
296
- qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
297
- q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
298
-
299
- # spatial self-correlation (S-SC)
300
- x_spatial = self.spatial_self_correlation(q, v)
301
- x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
302
- x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
303
-
304
- # channel self-correlation (C-SC)
305
- x_channel = self.channel_self_correlation(q, v)
306
- x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
307
- x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
308
-
309
- # spatial-channel information fusion
310
- x = torch.cat([x_spatial, x_channel], -1)
311
- x = self.proj_drop(self.proj(x))
312
-
313
- return x
314
-
315
- def extra_repr(self) -> str:
316
- return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
317
-
318
-
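# Usage sketch (illustrative sizes, assumed): SCC takes a (B, H, W, C) map whose
# spatial dims are divisible by window_size and returns a map of the same shape,
# with half the output channels from S-SC and half from C-SC before projection.
def _scc_shape_demo():
    scc = SCC(dim=60, base_win_size=(8, 8), window_size=(8, 8), num_heads=6)
    x = torch.randn(2, 16, 16, 60)
    print(scc(x).shape)                        # expected: torch.Size([2, 16, 16, 60])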
319
- class HierarchicalTransformerBlock(nn.Module):
320
- """ Hierarchical Transformer Block.
321
- Args:
322
- dim (int): Number of input channels.
323
- input_resolution (tuple[int]): Input resolution.
324
- num_heads (int): Number of heads for spatial self-correlation.
325
- base_win_size (tuple[int]): The height and width of the base window.
326
- window_size (tuple[int]): The height and width of the window.
327
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
328
- drop (float, optional): Dropout rate. Default: 0.0
329
- value_drop (float, optional): Dropout ratio of value. Default: 0.0
330
- drop_path (float, optional): Stochastic depth rate. Default: 0.0
331
- act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
332
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
- """
334
-
335
- def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
336
- mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
337
- act_layer=nn.GELU, norm_layer=nn.LayerNorm):
338
- super().__init__()
339
- self.dim = dim
340
- self.input_resolution = input_resolution
341
- self.num_heads = num_heads
342
- self.window_size = window_size
343
- self.mlp_ratio = mlp_ratio
344
-
345
- # check window size
346
- if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
347
- assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
348
- assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
349
-
350
-
351
- self.norm1 = norm_layer(dim)
352
- self.correlation = SCC(
353
- dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
354
- value_drop=value_drop, proj_drop=drop)
355
-
356
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
357
- self.norm2 = norm_layer(dim)
358
- mlp_hidden_dim = int(dim * mlp_ratio)
359
- self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
360
- # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
361
-
362
- def check_image_size(self, x, win_size):
363
- x = x.permute(0,3,1,2).contiguous()
364
- _, _, h, w = x.size()
365
- mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
366
- mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
367
-
368
- if mod_pad_h >= h or mod_pad_w >= w:
369
- pad_h, pad_w = h-1, w-1
370
- x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
371
- else:
372
- pad_h, pad_w = 0, 0
373
-
374
- mod_pad_h = mod_pad_h - pad_h
375
- mod_pad_w = mod_pad_w - pad_w
376
-
377
- x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
378
- x = x.permute(0,2,3,1).contiguous()
379
- return x
380
-
381
- def forward(self, x, x_size, win_size):
382
- H, W = x_size
383
- B, L, C = x.shape
384
-
385
- shortcut = x
386
- x = x.view(B, H, W, C)
387
-
388
- # padding
389
- x = self.check_image_size(x, win_size)
390
- _, H_pad, W_pad, _ = x.shape # shape after padding
391
-
392
- x = self.correlation(x)
393
-
394
- # unpad
395
- x = x[:, :H, :W, :].contiguous()
396
-
397
- # norm
398
- x = x.view(B, H * W, C)
399
- x = self.norm1(x)
400
-
401
- # FFN
402
- x = shortcut + self.drop_path(x)
403
- x = x + self.drop_path(self.norm2(self.mlp(x, x_size)))
404
-
405
- return x
406
-
407
- def extra_repr(self) -> str:
408
- return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
409
- f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
410
-
411
-
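# Usage sketch (illustrative sizes, assumed): the block operates on tokens of
# shape (B, H*W, C); the win_size passed to forward is the per-block
# hierarchical window, and check_image_size pads when H or W is not divisible by it.
def _htb_shape_demo():
    blk = HierarchicalTransformerBlock(dim=60, input_resolution=(16, 16),
                                       num_heads=6, base_win_size=(8, 8),
                                       window_size=(8, 8), mlp_ratio=2.)
    tokens = torch.randn(2, 16 * 16, 60)
    print(blk(tokens, (16, 16), (8, 8)).shape)  # expected: torch.Size([2, 256, 60])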
412
- class PatchMerging(nn.Module):
413
- """ Patch Merging Layer.
414
- Args:
415
- input_resolution (tuple[int]): Resolution of input feature.
416
- dim (int): Number of input channels.
417
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
418
- """
419
-
420
- def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
421
- super().__init__()
422
- self.input_resolution = input_resolution
423
- self.dim = dim
424
- self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
425
- self.norm = norm_layer(4 * dim)
426
-
427
- def forward(self, x):
428
- """
429
- x: B, H*W, C
430
- """
431
- H, W = self.input_resolution
432
- B, L, C = x.shape
433
- assert L == H * W, "input feature has wrong size"
434
- assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) is not even."
435
-
436
- x = x.view(B, H, W, C)
437
-
438
- x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
439
- x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
440
- x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
441
- x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
442
- x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
443
- x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
444
-
445
- x = self.norm(x)
446
- x = self.reduction(x)
447
-
448
- return x
449
-
450
- def extra_repr(self) -> str:
451
- return f"input_resolution={self.input_resolution}, dim={self.dim}"
452
-
453
-
454
- class BasicLayer(nn.Module):
455
- """ A basic Hierarchical Transformer layer for one stage.
456
-
457
- Args:
458
- dim (int): Number of input channels.
459
- input_resolution (tuple[int]): Input resolution.
460
- depth (int): Number of blocks.
461
- num_heads (int): Number of heads for spatial self-correlation.
462
- base_win_size (tuple[int]): The height and width of the base window.
463
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
464
- drop (float, optional): Dropout rate. Default: 0.0
465
- value_drop (float, optional): Dropout ratio of value. Default: 0.0
466
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
467
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
469
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470
- hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
471
- """
472
-
473
- def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
474
- mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
475
- downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
476
-
477
- super().__init__()
478
- self.dim = dim
479
- self.input_resolution = input_resolution
480
- self.depth = depth
481
- self.use_checkpoint = use_checkpoint
482
-
483
- self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
484
- self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
485
-
486
- # build blocks
487
- self.blocks = nn.ModuleList([
488
- HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
489
- num_heads=num_heads,
490
- base_win_size=base_win_size,
491
- window_size=(self.win_hs[i], self.win_ws[i]),
492
- mlp_ratio=mlp_ratio,
493
- drop=drop, value_drop=value_drop,
494
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
495
- norm_layer=norm_layer)
496
- for i in range(depth)])
497
-
498
- # patch merging layer
499
- if downsample is not None:
500
- self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
501
- else:
502
- self.downsample = None
503
-
504
- def forward(self, x, x_size):
505
-
506
- i = 0
507
- for blk in self.blocks:
508
- if self.use_checkpoint:
509
- x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
510
- else:
511
- x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
512
- i = i + 1
513
-
514
- if self.downsample is not None:
515
- x = self.downsample(x)
516
- return x
517
-
518
- def extra_repr(self) -> str:
519
- return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
520
-
521
-
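# Sketch (assumed settings): hier_win_ratios scales base_win_size into one
# window per block, e.g. base (8, 8) with ratios [0.5, 1, 2, 4, 6, 8] gives
# windows (4,4), (8,8), (16,16), (32,32), (48,48), (64,64). A shortened
# two-block layer keeps the token shape unchanged.
def _basiclayer_shape_demo():
    layer = BasicLayer(dim=60, input_resolution=(16, 16), depth=2, num_heads=6,
                       base_win_size=(8, 8), mlp_ratio=2.,
                       hier_win_ratios=[0.5, 1])
    tokens = torch.randn(2, 16 * 16, 60)
    print(layer(tokens, (16, 16)).shape)       # expected: torch.Size([2, 256, 60])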
522
- class RHTB(nn.Module):
523
- """Residual Hierarchical Transformer Block (RHTB).
524
- Args:
525
- dim (int): Number of input channels.
526
- input_resolution (tuple[int]): Input resolution.
527
- depth (int): Number of blocks.
528
- num_heads (int): Number of heads for spatial self-correlation.
529
- base_win_size (tuple[int]): The height and width of the base window.
530
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
531
- drop (float, optional): Dropout rate. Default: 0.0
532
- value_drop (float, optional): Dropout ratio of value. Default: 0.0
533
- drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
534
- norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
535
- downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
536
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
537
- img_size: Input image size.
538
- patch_size: Patch size.
539
- resi_connection: The convolutional block before residual connection.
540
- hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
541
- """
542
-
543
- def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
544
- mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
545
- downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
546
- resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
547
- super(RHTB, self).__init__()
548
-
549
- self.dim = dim
550
- self.input_resolution = input_resolution
551
-
552
- self.residual_group = BasicLayer(dim=dim,
553
- input_resolution=input_resolution,
554
- depth=depth,
555
- num_heads=num_heads,
556
- base_win_size=base_win_size,
557
- mlp_ratio=mlp_ratio,
558
- drop=drop, value_drop=value_drop,
559
- drop_path=drop_path,
560
- norm_layer=norm_layer,
561
- downsample=downsample,
562
- use_checkpoint=use_checkpoint,
563
- hier_win_ratios=hier_win_ratios)
564
-
565
- if resi_connection == '1conv':
566
- self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
567
- elif resi_connection == '3conv':
568
- # to save parameters and memory
569
- self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
570
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
571
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
572
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
573
-
574
- self.patch_embed = PatchEmbed(
575
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
576
- norm_layer=None)
577
-
578
- self.patch_unembed = PatchUnEmbed(
579
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
580
- norm_layer=None)
581
-
582
- def forward(self, x, x_size):
583
- return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
584
-
585
-
586
- class PatchEmbed(nn.Module):
587
- r""" Image to Patch Embedding
588
-
589
- Args:
590
- img_size (int): Image size. Default: 224.
591
- patch_size (int): Patch token size. Default: 4.
592
- in_chans (int): Number of input image channels. Default: 3.
593
- embed_dim (int): Number of linear projection output channels. Default: 96.
594
- norm_layer (nn.Module, optional): Normalization layer. Default: None
595
- """
596
-
597
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
598
- super().__init__()
599
- img_size = to_2tuple(img_size)
600
- patch_size = to_2tuple(patch_size)
601
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
602
- self.img_size = img_size
603
- self.patch_size = patch_size
604
- self.patches_resolution = patches_resolution
605
- self.num_patches = patches_resolution[0] * patches_resolution[1]
606
-
607
- self.in_chans = in_chans
608
- self.embed_dim = embed_dim
609
-
610
- if norm_layer is not None:
611
- self.norm = norm_layer(embed_dim)
612
- else:
613
- self.norm = None
614
-
615
- def forward(self, x):
616
- x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
617
- if self.norm is not None:
618
- x = self.norm(x)
619
- return x
620
-
621
-
622
- class PatchUnEmbed(nn.Module):
623
- r""" Image to Patch Unembedding
624
-
625
- Args:
626
- img_size (int): Image size. Default: 224.
627
- patch_size (int): Patch token size. Default: 4.
628
- in_chans (int): Number of input image channels. Default: 3.
629
- embed_dim (int): Number of linear projection output channels. Default: 96.
630
- norm_layer (nn.Module, optional): Normalization layer. Default: None
631
- """
632
-
633
- def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
634
- super().__init__()
635
- img_size = to_2tuple(img_size)
636
- patch_size = to_2tuple(patch_size)
637
- patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
638
- self.img_size = img_size
639
- self.patch_size = patch_size
640
- self.patches_resolution = patches_resolution
641
- self.num_patches = patches_resolution[0] * patches_resolution[1]
642
-
643
- self.in_chans = in_chans
644
- self.embed_dim = embed_dim
645
-
646
- def forward(self, x, x_size):
647
- B, HW, C = x.shape
648
- x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
649
- return x
650
-
651
-
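# Round-trip sketch (illustrative sizes, assumed): with patch_size=1 these two
# modules simply flatten a (B, C, H, W) feature map into (B, H*W, C) tokens and
# restore it, which is how the network moves between conv and transformer layouts.
def _patch_embed_roundtrip_demo():
    pe = PatchEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60)
    pu = PatchUnEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60)
    feat = torch.randn(2, 60, 16, 16)
    tokens = pe(feat)                          # -> (2, 256, 60)
    print(pu(tokens, (16, 16)).shape)          # expected: torch.Size([2, 60, 16, 16])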
652
- class Upsample(nn.Sequential):
653
- """Upsample module.
654
-
655
- Args:
656
- scale (int): Scale factor. Supported scales: 2^n and 3.
657
- num_feat (int): Channel number of intermediate features.
658
- """
659
-
660
- def __init__(self, scale, num_feat):
661
- m = []
662
- if (scale & (scale - 1)) == 0: # scale = 2^n
663
- for _ in range(int(math.log(scale, 2))):
664
- m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
665
- m.append(nn.PixelShuffle(2))
666
- elif scale == 3:
667
- m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
668
- m.append(nn.PixelShuffle(3))
669
- else:
670
- raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
671
- super(Upsample, self).__init__(*m)
672
-
673
-
674
- class UpsampleOneStep(nn.Sequential):
675
- """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
676
- Used in lightweight SR to save parameters.
677
-
678
- Args:
679
- scale (int): Scale factor. Supported scales: 2^n and 3.
680
- num_feat (int): Channel number of intermediate features.
681
-
682
- """
683
-
684
- def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
685
- self.num_feat = num_feat
686
- self.input_resolution = input_resolution
687
- m = []
688
- m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
689
- m.append(nn.PixelShuffle(scale))
690
- super(UpsampleOneStep, self).__init__(*m)
691
-
692
-
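# Shape sketch (illustrative sizes, assumed): Upsample stacks conv + pixel-shuffle
# stages for 2^n (or 3x) scaling and keeps the channel count, while
# UpsampleOneStep maps straight to scale^2 * num_out_ch channels before a single
# pixel-shuffle, which is what the lightweight SR path uses.
def _upsample_shape_demo():
    up = Upsample(scale=4, num_feat=64)
    print(up(torch.randn(2, 64, 16, 16)).shape)        # (2, 64, 64, 64)
    one = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
    print(one(torch.randn(2, 60, 16, 16)).shape)       # (2, 3, 64, 64)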
693
- class HiT_SRF(nn.Module, PyTorchModelHubMixin):
694
- """ HiT-SRF network.
695
-
696
- Args:
697
- img_size (int | tuple(int)): Input image size. Default 64
698
- patch_size (int | tuple(int)): Patch size. Default: 1
699
- in_chans (int): Number of input image channels. Default: 3
700
- embed_dim (int): Patch embedding dimension. Default: 96
701
- depths (tuple(int)): Depth of each Transformer block.
702
- num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
703
- base_win_size (tuple[int]): The height and width of the base window.
704
- mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
705
- drop_rate (float): Dropout rate. Default: 0
706
- value_drop_rate (float): Dropout ratio of value. Default: 0.0
707
- drop_path_rate (float): Stochastic depth rate. Default: 0.1
708
- norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
709
- ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
710
- patch_norm (bool): If True, add normalization after patch embedding. Default: True
711
- use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
712
- upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
713
- img_range (float): Image range. 1. or 255.
714
- upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
715
- resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
716
- hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
717
- """
718
-
719
- def __init__(self, img_size=64, patch_size=1, in_chans=3,
720
- embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
721
- base_win_size=[8,8], mlp_ratio=2.,
722
- drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
723
- norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
724
- use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
725
- hier_win_ratios=[0.5,1,2,4,6,8],
726
- **kwargs):
727
- super(HiT_SRF, self).__init__()
728
- num_in_ch = in_chans
729
- num_out_ch = in_chans
730
- num_feat = 64
731
- self.img_range = img_range
732
- if in_chans == 3:
733
- rgb_mean = (0.4488, 0.4371, 0.4040)
734
- self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
735
- else:
736
- self.mean = torch.zeros(1, 1, 1, 1)
737
- self.upscale = upscale
738
- self.upsampler = upsampler
739
- self.base_win_size = base_win_size
740
-
741
- #####################################################################################################
742
- ################################### 1, shallow feature extraction ###################################
743
- self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
744
-
745
- #####################################################################################################
746
- ################################### 2, deep feature extraction ######################################
747
- self.num_layers = len(depths)
748
- self.embed_dim = embed_dim
749
- self.ape = ape
750
- self.patch_norm = patch_norm
751
- self.num_features = embed_dim
752
- self.mlp_ratio = mlp_ratio
753
-
754
- # split image into non-overlapping patches
755
- self.patch_embed = PatchEmbed(
756
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
757
- norm_layer=norm_layer if self.patch_norm else None)
758
- num_patches = self.patch_embed.num_patches
759
- patches_resolution = self.patch_embed.patches_resolution
760
- self.patches_resolution = patches_resolution
761
-
762
- # merge non-overlapping patches into image
763
- self.patch_unembed = PatchUnEmbed(
764
- img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
765
- norm_layer=norm_layer if self.patch_norm else None)
766
-
767
- # absolute position embedding
768
- if self.ape:
769
- self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
770
- trunc_normal_(self.absolute_pos_embed, std=.02)
771
-
772
- self.pos_drop = nn.Dropout(p=drop_rate)
773
-
774
- # stochastic depth
775
- dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
776
-
777
- # build Residual Hierarchical Transformer blocks (RHTB)
778
- self.layers = nn.ModuleList()
779
- for i_layer in range(self.num_layers):
780
- layer = RHTB(dim=embed_dim,
781
- input_resolution=(patches_resolution[0],
782
- patches_resolution[1]),
783
- depth=depths[i_layer],
784
- num_heads=num_heads[i_layer],
785
- base_win_size=base_win_size,
786
- mlp_ratio=self.mlp_ratio,
787
- drop=drop_rate, value_drop=value_drop_rate,
788
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
789
- norm_layer=norm_layer,
790
- downsample=None,
791
- use_checkpoint=use_checkpoint,
792
- img_size=img_size,
793
- patch_size=patch_size,
794
- resi_connection=resi_connection,
795
- hier_win_ratios=hier_win_ratios
796
- )
797
- self.layers.append(layer)
798
- self.norm = norm_layer(self.num_features)
799
-
800
- # build the last conv layer in deep feature extraction
801
- if resi_connection == '1conv':
802
- self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
803
- elif resi_connection == '3conv':
804
- # to save parameters and memory
805
- self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
806
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
807
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
808
- nn.LeakyReLU(negative_slope=0.2, inplace=True),
809
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
810
-
811
- #####################################################################################################
812
- ################################ 3, high quality image reconstruction ################################
813
- if self.upsampler == 'pixelshuffle':
814
- # for classical SR
815
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
816
- nn.LeakyReLU(inplace=True))
817
- self.upsample = Upsample(upscale, num_feat)
818
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
819
- elif self.upsampler == 'pixelshuffledirect':
820
- # for lightweight SR (to save parameters)
821
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
822
- (patches_resolution[0], patches_resolution[1]))
823
- elif self.upsampler == 'nearest+conv':
824
- # for real-world SR (less artifacts)
825
- assert self.upscale == 4, 'only support x4 now.'
826
- self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
827
- nn.LeakyReLU(inplace=True))
828
- self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
829
- self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
830
- self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
831
- self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
832
- self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
833
- else:
834
- # for image denoising and JPEG compression artifact reduction
835
- self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
836
-
837
- self.apply(self._init_weights)
838
-
839
- def _init_weights(self, m):
840
- if isinstance(m, nn.Linear):
841
- trunc_normal_(m.weight, std=.02)
842
- if isinstance(m, nn.Linear) and m.bias is not None:
843
- nn.init.constant_(m.bias, 0)
844
- elif isinstance(m, nn.LayerNorm):
845
- nn.init.constant_(m.bias, 0)
846
- nn.init.constant_(m.weight, 1.0)
847
-
848
- @torch.jit.ignore
849
- def no_weight_decay(self):
850
- return {'absolute_pos_embed'}
851
-
852
- @torch.jit.ignore
853
- def no_weight_decay_keywords(self):
854
- return {'relative_position_bias_table'}
855
-
856
-
857
- def forward_features(self, x):
858
- x_size = (x.shape[2], x.shape[3])
859
- x = self.patch_embed(x)
860
- if self.ape:
861
- x = x + self.absolute_pos_embed
862
- x = self.pos_drop(x)
863
-
864
- for layer in self.layers:
865
- x = layer(x, x_size)
866
-
867
- x = self.norm(x) # B L C
868
- x = self.patch_unembed(x, x_size)
869
-
870
- return x
871
-
872
- def infer_image(self, image_path, cuda=True):
873
-
874
- io_backend_opt = {'type':'disk'}
875
- self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
876
-
877
- # load lq image
878
- lq_path = image_path
879
- img_bytes = self.file_client.get(lq_path, 'lq')
880
- img_lq = imfrombytes(img_bytes, float32=True)
881
-
882
- # BGR to RGB, HWC to CHW, numpy to tensor
883
- x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...]
884
-
885
- if cuda:
886
- x= x.cuda()
887
-
888
- out = self(x)
889
-
890
- if cuda:
891
- out = out.cpu()
892
-
893
- out = tensor2img(out)
894
-
895
- return out
896
-
897
- def forward(self, x):
898
- H, W = x.shape[2:]
899
-
900
- self.mean = self.mean.type_as(x)
901
- x = (x - self.mean) * self.img_range
902
-
903
- if self.upsampler == 'pixelshuffle':
904
- # for classical SR
905
- x = self.conv_first(x)
906
- x = self.conv_after_body(self.forward_features(x)) + x
907
- x = self.conv_before_upsample(x)
908
- x = self.conv_last(self.upsample(x))
909
- elif self.upsampler == 'pixelshuffledirect':
910
- # for lightweight SR
911
- x = self.conv_first(x)
912
- x = self.conv_after_body(self.forward_features(x)) + x
913
- x = self.upsample(x)
914
- elif self.upsampler == 'nearest+conv':
915
- # for real-world SR
916
- x = self.conv_first(x)
917
- x = self.conv_after_body(self.forward_features(x)) + x
918
- x = self.conv_before_upsample(x)
919
- x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
920
- x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
921
- x = self.conv_last(self.lrelu(self.conv_hr(x)))
922
- else:
923
- # for image denoising and JPEG compression artifact reduction
924
- x_first = self.conv_first(x)
925
- res = self.conv_after_body(self.forward_features(x_first)) + x_first
926
- x = x + self.conv_last(res)
927
-
928
- x = x / self.img_range + self.mean
929
-
930
- return x[:, :, :H*self.upscale, :W*self.upscale]
931
-
932
-
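# Inference sketch (assumed configuration, mirroring the lightweight x4 setup in
# __main__ below; the hub repo id is a placeholder, not a verified checkpoint).
# Loading via from_pretrained comes from PyTorchModelHubMixin; infer_image
# (defined above) wraps file I/O around the same forward pass.
def _hit_srf_inference_demo():
    model = HiT_SRF(upscale=4, in_chans=3, img_size=64, base_win_size=[8, 8],
                    img_range=1., depths=[6, 6, 6, 6], embed_dim=60,
                    num_heads=[6, 6, 6, 6], mlp_ratio=2,
                    upsampler='pixelshuffledirect')
    # model = HiT_SRF.from_pretrained("<hub-repo-id>")  # placeholder id
    lr = torch.rand(1, 3, 64, 64)
    with torch.no_grad():
        sr = model(lr)
    print(sr.shape)                            # expected: torch.Size([1, 3, 256, 256])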
933
- if __name__ == '__main__':
934
- upscale = 4
935
- base_win_size = [8, 8]
936
- height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
937
- width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
938
-
939
- ## HiT-SRF
940
- model = HiT_SRF(upscale=4, img_size=(height, width),
941
- base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
942
- embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
943
-
944
- params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
945
- print("params: ", params_num)
946
-
947
-
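# A possible follow-up for the __main__ block above (sketch, assuming the model,
# height, and width defined there): run a dummy LR tensor of the computed size
# through the model to confirm the x4 output shape alongside the parameter count:
#   x = torch.randn(1, 3, height, width)
#   with torch.no_grad():
#       print(model(x).shape)   # expected: (1, 3, height * 4, width * 4)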
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.utils.checkpoint as checkpoint
6
+ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
7
+
8
+ import numpy as np
9
+ from huggingface_hub import PyTorchModelHubMixin
10
+ from utils import FileClient, imfrombytes, img2tensor, tensor2img
11
+
12
+ class DFE(nn.Module):
13
+ """ Dual Feature Extraction
14
+ Args:
15
+ in_features (int): Number of input channels.
16
+ out_features (int): Number of output channels.
17
+ """
18
+ def __init__(self, in_features, out_features):
19
+ super().__init__()
20
+
21
+ self.out_features = out_features
22
+
23
+ self.conv = nn.Sequential(nn.Conv2d(in_features, in_features // 5, 1, 1, 0),
24
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
25
+ nn.Conv2d(in_features // 5, in_features // 5, 3, 1, 1),
26
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
27
+ nn.Conv2d(in_features // 5, out_features, 1, 1, 0))
28
+
29
+ self.linear = nn.Conv2d(in_features, out_features,1,1,0)
30
+
31
+ def forward(self, x, x_size):
32
+
33
+ B, L, C = x.shape
34
+ H, W = x_size
35
+ x = x.permute(0, 2, 1).contiguous().view(B, C, H, W)
36
+ x = self.conv(x) * self.linear(x)
37
+ x = x.view(B, -1, H*W).permute(0,2,1).contiguous()
38
+
39
+ return x
40
+
41
+ class Mlp(nn.Module):
42
+ """ MLP-based Feed-Forward Network
43
+ Args:
44
+ in_features (int): Number of input channels.
45
+ hidden_features (int | None): Number of hidden channels. Default: None
46
+ out_features (int | None): Number of output channels. Default: None
47
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
48
+ drop (float): Dropout rate. Default: 0.0
49
+ """
50
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
51
+ super().__init__()
52
+ out_features = out_features or in_features
53
+ hidden_features = hidden_features or in_features
54
+ self.fc1 = nn.Linear(in_features, hidden_features)
55
+ self.act = act_layer()
56
+ self.fc2 = nn.Linear(hidden_features, out_features)
57
+ self.drop = nn.Dropout(drop)
58
+
59
+ def forward(self, x):
60
+ x = self.fc1(x)
61
+ x = self.act(x)
62
+ x = self.drop(x)
63
+ x = self.fc2(x)
64
+ x = self.drop(x)
65
+ return x
66
+
67
+
68
+ class dwconv(nn.Module):
69
+ def __init__(self,hidden_features):
70
+ super(dwconv, self).__init__()
71
+ self.depthwise_conv = nn.Sequential(
72
+ nn.Conv2d(hidden_features, hidden_features, kernel_size=5, stride=1, padding=2, dilation=1,
73
+ groups=hidden_features), nn.GELU())
74
+ self.hidden_features = hidden_features
75
+ def forward(self,x,x_size):
76
+ x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous() # b Ph*Pw c
77
+ x = self.depthwise_conv(x)
78
+ x = x.flatten(2).transpose(1, 2).contiguous()
79
+ return x
80
+
81
+ class ConvFFN(nn.Module):
82
+
83
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
84
+ super().__init__()
85
+ out_features = out_features or in_features
86
+ hidden_features = hidden_features or in_features
87
+ self.fc1 = nn.Linear(in_features, hidden_features)
88
+ self.act = act_layer()
89
+ self.dwconv = dwconv(hidden_features=hidden_features)
90
+ self.fc2 = nn.Linear(hidden_features, out_features)
91
+ self.drop = nn.Dropout(drop)
92
+
93
+
94
+ def forward(self, x,x_size):
95
+ x = self.fc1(x)
96
+ x = self.act(x)
97
+ x = x + self.dwconv(x,x_size)
98
+ x = self.drop(x)
99
+ x = self.fc2(x)
100
+ x = self.drop(x)
101
+ return x
102
+
103
+ def window_partition(x, window_size):
104
+ """
105
+ Args:
106
+ x: (B, H, W, C)
107
+ window_size (tuple): window size
108
+
109
+ Returns:
110
+ windows: (num_windows*B, window_size, window_size, C)
111
+ """
112
+ B, H, W, C = x.shape
113
+ x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
114
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
115
+ return windows
116
+
117
+
118
+ def window_reverse(windows, window_size, H, W):
119
+ """
120
+ Args:
121
+ windows: (num_windows*B, window_size, window_size, C)
122
+ window_size (tuple): Window size
123
+ H (int): Height of image
124
+ W (int): Width of image
125
+
126
+ Returns:
127
+ x: (B, H, W, C)
128
+ """
129
+ B = int(windows.shape[0] * (window_size[0] * window_size[1]) / (H * W))
130
+ x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
131
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
132
+ return x
133
+
134
+ class DynamicPosBias(nn.Module):
135
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
136
+ """ Dynamic Relative Position Bias.
137
+ Args:
138
+ dim (int): Number of input channels.
139
+ num_heads (int): Number of heads for spatial self-correlation.
140
+ residual (bool): If True, use residual strage to connect conv.
141
+ """
142
+ def __init__(self, dim, num_heads, residual):
143
+ super().__init__()
144
+ self.residual = residual
145
+ self.num_heads = num_heads
146
+ self.pos_dim = dim // 4
147
+ self.pos_proj = nn.Linear(2, self.pos_dim)
148
+ self.pos1 = nn.Sequential(
149
+ nn.LayerNorm(self.pos_dim),
150
+ nn.ReLU(inplace=True),
151
+ nn.Linear(self.pos_dim, self.pos_dim),
152
+ )
153
+ self.pos2 = nn.Sequential(
154
+ nn.LayerNorm(self.pos_dim),
155
+ nn.ReLU(inplace=True),
156
+ nn.Linear(self.pos_dim, self.pos_dim)
157
+ )
158
+ self.pos3 = nn.Sequential(
159
+ nn.LayerNorm(self.pos_dim),
160
+ nn.ReLU(inplace=True),
161
+ nn.Linear(self.pos_dim, self.num_heads)
162
+ )
163
+ def forward(self, biases):
164
+ if self.residual:
165
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
166
+ pos = pos + self.pos1(pos)
167
+ pos = pos + self.pos2(pos)
168
+ pos = self.pos3(pos)
169
+ else:
170
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
171
+ return pos
172
+
173
+ class SCC(nn.Module):
174
+ """ Spatial-Channel Correlation.
175
+ Args:
176
+ dim (int): Number of input channels.
177
+ base_win_size (tuple[int]): The height and width of the base window.
178
+ window_size (tuple[int]): The height and width of the window.
179
+ num_heads (int): Number of heads for spatial self-correlation.
180
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
181
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
182
+ """
183
+
184
+ def __init__(self, dim, base_win_size, window_size, num_heads, value_drop=0., proj_drop=0.):
185
+
186
+ super().__init__()
187
+ # parameters
188
+ self.dim = dim
189
+ self.window_size = window_size
190
+ self.num_heads = num_heads
191
+
192
+ # feature projection
193
+ self.qv = DFE(dim, dim)
194
+ self.proj = nn.Linear(dim, dim)
195
+
196
+ # dropout
197
+ self.value_drop = nn.Dropout(value_drop)
198
+ self.proj_drop = nn.Dropout(proj_drop)
199
+
200
+ # base window size
201
+ min_h = min(self.window_size[0], base_win_size[0])
202
+ min_w = min(self.window_size[1], base_win_size[1])
203
+ self.base_win_size = (min_h, min_w)
204
+
205
+ # normalization factor and spatial linear layer for S-SC
206
+ head_dim = dim // (2*num_heads)
207
+ self.scale = head_dim
208
+ self.spatial_linear = nn.Linear(self.window_size[0]*self.window_size[1] // (self.base_win_size[0]*self.base_win_size[1]), 1)
209
+
210
+ # define a parameter table of relative position bias
211
+ self.H_sp, self.W_sp = self.window_size
212
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
213
+
214
+ def spatial_linear_projection(self, x):
215
+ B, num_h, L, C = x.shape
216
+ H, W = self.window_size
217
+ map_H, map_W = self.base_win_size
218
+
219
+ x = x.view(B, num_h, map_H, H//map_H, map_W, W//map_W, C).permute(0,1,2,4,6,3,5).contiguous().view(B, num_h, map_H*map_W, C, -1)
220
+ x = self.spatial_linear(x).view(B, num_h, map_H*map_W, C)
221
+ return x
222
+
223
+ def spatial_self_correlation(self, q, v):
224
+
225
+ B, num_head, L, C = q.shape
226
+
227
+ # spatial projection
228
+ v = self.spatial_linear_projection(v)
229
+
230
+ # compute correlation map
231
+ corr_map = (q @ v.transpose(-2,-1)) / self.scale
232
+
233
+ # add relative position bias
234
+ # generate mother-set
235
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp, device=v.device)
236
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp, device=v.device)
237
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
238
+ rpe_biases = biases.flatten(1).transpose(0, 1).contiguous().float()
239
+ pos = self.pos(rpe_biases)
240
+
241
+ # select position bias
242
+ coords_h = torch.arange(self.H_sp, device=v.device)
243
+ coords_w = torch.arange(self.W_sp, device=v.device)
244
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
245
+ coords_flatten = torch.flatten(coords, 1)
246
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
247
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
248
+ relative_coords[:, :, 0] += self.H_sp - 1
249
+ relative_coords[:, :, 1] += self.W_sp - 1
250
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
251
+ relative_position_index = relative_coords.sum(-1)
252
+ relative_position_bias = pos[relative_position_index.view(-1)].view(
253
+ self.window_size[0] * self.window_size[1], self.base_win_size[0], self.window_size[0]//self.base_win_size[0], self.base_win_size[1], self.window_size[1]//self.base_win_size[1], -1) # Wh*Ww,Wh*Ww,nH
254
+ relative_position_bias = relative_position_bias.permute(0,1,3,5,2,4).contiguous().view(
255
+ self.window_size[0] * self.window_size[1], self.base_win_size[0]*self.base_win_size[1], self.num_heads, -1).mean(-1)
256
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
257
+ corr_map = corr_map + relative_position_bias.unsqueeze(0)
258
+
259
+ # transformation
260
+ v_drop = self.value_drop(v)
261
+ x = (corr_map @ v_drop).permute(0,2,1,3).contiguous().view(B, L, -1)
262
+
263
+ return x
264
+
265
+ def channel_self_correlation(self, q, v):
266
+
267
+ B, num_head, L, C = q.shape
268
+
269
+ # apply single head strategy
270
+ q = q.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
271
+ v = v.permute(0,2,1,3).contiguous().view(B, L, num_head*C)
272
+
273
+ # compute correlation map
274
+ corr_map = (q.transpose(-2,-1) @ v) / L
275
+
276
+ # transformation
277
+ v_drop = self.value_drop(v)
278
+ x = (corr_map @ v_drop.transpose(-2,-1)).permute(0,2,1).contiguous().view(B, L, -1)
279
+
280
+ return x
281
+
282
+ def forward(self, x):
283
+ """
284
+ Args:
285
+ x: input features with shape of (B, H, W, C)
286
+ """
287
+ xB,xH,xW,xC = x.shape
288
+ qv = self.qv(x.view(xB,-1,xC), (xH,xW)).view(xB, xH, xW, xC)
289
+
290
+ # window partition
291
+ qv = window_partition(qv, self.window_size)
292
+ qv = qv.view(-1, self.window_size[0]*self.window_size[1], xC)
293
+
294
+ # qv splitting
295
+ B, L, C = qv.shape
296
+ qv = qv.view(B, L, 2, self.num_heads, C // (2*self.num_heads)).permute(2,0,3,1,4).contiguous()
297
+ q, v = qv[0], qv[1] # B, num_heads, L, C//num_heads
298
+
299
+ # spatial self-correlation (S-SC)
300
+ x_spatial = self.spatial_self_correlation(q, v)
301
+ x_spatial = x_spatial.view(-1, self.window_size[0], self.window_size[1], C//2)
302
+ x_spatial = window_reverse(x_spatial, (self.window_size[0],self.window_size[1]), xH, xW) # xB xH xW xC
303
+
304
+ # channel self-correlation (C-SC)
305
+ x_channel = self.channel_self_correlation(q, v)
306
+ x_channel = x_channel.view(-1, self.window_size[0], self.window_size[1], C//2)
307
+ x_channel = window_reverse(x_channel, (self.window_size[0], self.window_size[1]), xH, xW) # xB xH xW xC
308
+
309
+ # spatial-channel information fusion
310
+ x = torch.cat([x_spatial, x_channel], -1)
311
+ x = self.proj_drop(self.proj(x))
312
+
313
+ return x
314
+
315
+ def extra_repr(self) -> str:
316
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
317
+
318
+
319
+ class HierarchicalTransformerBlock(nn.Module):
320
+ """ Hierarchical Transformer Block.
321
+ Args:
322
+ dim (int): Number of input channels.
323
+ input_resolution (tuple[int]): Input resulotion.
324
+ num_heads (int): Number of heads for spatial self-correlation.
325
+ base_win_size (tuple[int]): The height and width of the base window.
326
+ window_size (tuple[int]): The height and width of the window.
327
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
328
+ drop (float, optional): Dropout rate. Default: 0.0
329
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
330
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
331
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
332
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
333
+ """
334
+
335
+ def __init__(self, dim, input_resolution, num_heads, base_win_size, window_size,
336
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0.,
337
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
338
+ super().__init__()
339
+ self.dim = dim
340
+ self.input_resolution = input_resolution
341
+ self.num_heads = num_heads
342
+ self.window_size = window_size
343
+ self.mlp_ratio = mlp_ratio
344
+
345
+ # check window size
346
+ if (window_size[0] > base_win_size[0]) and (window_size[1] > base_win_size[1]):
347
+ assert window_size[0] % base_win_size[0] == 0, "please ensure the window size is smaller than or divisible by the base window size"
348
+ assert window_size[1] % base_win_size[1] == 0, "please ensure the window size is smaller than or divisible by the base window size"
349
+
350
+
351
+ self.norm1 = norm_layer(dim)
352
+ self.correlation = SCC(
353
+ dim, base_win_size=base_win_size, window_size=self.window_size, num_heads=num_heads,
354
+ value_drop=value_drop, proj_drop=drop)
355
+
356
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
357
+ self.norm2 = norm_layer(dim)
358
+ mlp_hidden_dim = int(dim * mlp_ratio)
359
+ self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
360
+ # self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
361
+
362
+ def check_image_size(self, x, win_size):
363
+ x = x.permute(0,3,1,2).contiguous()
364
+ _, _, h, w = x.size()
365
+ mod_pad_h = (win_size[0] - h % win_size[0]) % win_size[0]
366
+ mod_pad_w = (win_size[1] - w % win_size[1]) % win_size[1]
367
+
368
+ if mod_pad_h >= h or mod_pad_w >= w:
369
+ pad_h, pad_w = h-1, w-1
370
+ x = F.pad(x, (0, pad_w, 0, pad_h), 'reflect')
371
+ else:
372
+ pad_h, pad_w = 0, 0
373
+
374
+ mod_pad_h = mod_pad_h - pad_h
375
+ mod_pad_w = mod_pad_w - pad_w
376
+
377
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
378
+ x = x.permute(0,2,3,1).contiguous()
379
+ return x
380
+
381
+ def forward(self, x, x_size, win_size):
382
+ H, W = x_size
383
+ B, L, C = x.shape
384
+
385
+ shortcut = x
386
+ x = x.view(B, H, W, C)
387
+
388
+ # padding
389
+ x = self.check_image_size(x, win_size)
390
+ _, H_pad, W_pad, _ = x.shape # shape after padding
391
+
392
+ x = self.correlation(x)
393
+
394
+ # unpad
395
+ x = x[:, :H, :W, :].contiguous()
396
+
397
+ # norm
398
+ x = x.view(B, H * W, C)
399
+ x = self.norm1(x)
400
+
401
+ # FFN
402
+ x = shortcut + self.drop_path(x)
403
+ x = x + self.drop_path(self.norm2(self.mlp(x, x_size)))
404
+
405
+ return x
406
+
407
+ def extra_repr(self) -> str:
408
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
409
+ f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}"
410
+
411
+
412
+ class PatchMerging(nn.Module):
413
+ """ Patch Merging Layer.
414
+ Args:
415
+ input_resolution (tuple[int]): Resolution of input feature.
416
+ dim (int): Number of input channels.
417
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
418
+ """
419
+
420
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
421
+ super().__init__()
422
+ self.input_resolution = input_resolution
423
+ self.dim = dim
424
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
425
+ self.norm = norm_layer(4 * dim)
426
+
427
+ def forward(self, x):
428
+ """
429
+ x: B, H*W, C
430
+ """
431
+ H, W = self.input_resolution
432
+ B, L, C = x.shape
433
+ assert L == H * W, "input feature has wrong size"
434
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
435
+
436
+ x = x.view(B, H, W, C)
437
+
438
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
439
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
440
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
441
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
442
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
443
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
444
+
445
+ x = self.norm(x)
446
+ x = self.reduction(x)
447
+
448
+ return x
449
+
450
+ def extra_repr(self) -> str:
451
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
452
+
453
+
454
+ class BasicLayer(nn.Module):
455
+ """ A basic Hierarchical Transformer layer for one stage.
456
+
457
+ Args:
458
+ dim (int): Number of input channels.
459
+ input_resolution (tuple[int]): Input resolution.
460
+ depth (int): Number of blocks.
461
+ num_heads (int): Number of heads for spatial self-correlation.
462
+ base_win_size (tuple[int]): The height and width of the base window.
463
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
464
+ drop (float, optional): Dropout rate. Default: 0.0
465
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
466
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
467
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
468
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
469
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
470
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
471
+ """
472
+
473
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
474
+ mlp_ratio=4., drop=0., value_drop=0.,drop_path=0., norm_layer=nn.LayerNorm,
475
+ downsample=None, use_checkpoint=False, hier_win_ratios=[0.5,1,2,4,6,8]):
476
+
477
+ super().__init__()
478
+ self.dim = dim
479
+ self.input_resolution = input_resolution
480
+ self.depth = depth
481
+ self.use_checkpoint = use_checkpoint
482
+
483
+ self.win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
484
+ self.win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
485
+
486
+ # build blocks
487
+ self.blocks = nn.ModuleList([
488
+ HierarchicalTransformerBlock(dim=dim, input_resolution=input_resolution,
489
+ num_heads=num_heads,
490
+ base_win_size=base_win_size,
491
+ window_size=(self.win_hs[i], self.win_ws[i]),
492
+ mlp_ratio=mlp_ratio,
493
+ drop=drop, value_drop=value_drop,
494
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
495
+ norm_layer=norm_layer)
496
+ for i in range(depth)])
497
+
498
+ # patch merging layer
499
+ if downsample is not None:
500
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
501
+ else:
502
+ self.downsample = None
503
+
504
+ def forward(self, x, x_size):
505
+
506
+ i = 0
507
+ for blk in self.blocks:
508
+ if self.use_checkpoint:
509
+ x = checkpoint.checkpoint(blk, x, x_size, (self.win_hs[i], self.win_ws[i]))
510
+ else:
511
+ x = blk(x, x_size, (self.win_hs[i], self.win_ws[i]))
512
+ i = i + 1
513
+
514
+ if self.downsample is not None:
515
+ x = self.downsample(x)
516
+ return x
517
+
518
+ def extra_repr(self) -> str:
519
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
520
+
521
+
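# Illustrative sketch (not part of the original commit): BasicLayer derives one window
# size per block from base_win_size and hier_win_ratios, so with the defaults the six
# blocks of a stage see progressively larger windows.
base_win_size = (8, 8)
hier_win_ratios = [0.5, 1, 2, 4, 6, 8]
win_hs = [int(base_win_size[0] * ratio) for ratio in hier_win_ratios]
win_ws = [int(base_win_size[1] * ratio) for ratio in hier_win_ratios]
print(list(zip(win_hs, win_ws)))   # [(4, 4), (8, 8), (16, 16), (32, 32), (48, 48), (64, 64)]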
522
+ class RHTB(nn.Module):
523
+ """Residual Hierarchical Transformer Block (RHTB).
524
+ Args:
525
+ dim (int): Number of input channels.
526
+ input_resolution (tuple[int]): Input resolution.
527
+ depth (int): Number of blocks.
528
+ num_heads (int): Number of heads for spatial self-correlation.
529
+ base_win_size (tuple[int]): The height and width of the base window.
530
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
531
+ drop (float, optional): Dropout rate. Default: 0.0
532
+ value_drop (float, optional): Dropout ratio of value. Default: 0.0
533
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
534
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
535
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
536
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
537
+ img_size: Input image size.
538
+ patch_size: Patch size.
539
+ resi_connection: The convolutional block before residual connection.
540
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
541
+ """
542
+
543
+ def __init__(self, dim, input_resolution, depth, num_heads, base_win_size,
544
+ mlp_ratio=4., drop=0., value_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
545
+ downsample=None, use_checkpoint=False, img_size=224, patch_size=4,
546
+ resi_connection='1conv', hier_win_ratios=[0.5,1,2,4,6,8]):
547
+ super(RHTB, self).__init__()
548
+
549
+ self.dim = dim
550
+ self.input_resolution = input_resolution
551
+
552
+ self.residual_group = BasicLayer(dim=dim,
553
+ input_resolution=input_resolution,
554
+ depth=depth,
555
+ num_heads=num_heads,
556
+ base_win_size=base_win_size,
557
+ mlp_ratio=mlp_ratio,
558
+ drop=drop, value_drop=value_drop,
559
+ drop_path=drop_path,
560
+ norm_layer=norm_layer,
561
+ downsample=downsample,
562
+ use_checkpoint=use_checkpoint,
563
+ hier_win_ratios=hier_win_ratios)
564
+
565
+ if resi_connection == '1conv':
566
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
567
+ elif resi_connection == '3conv':
568
+ # to save parameters and memory
569
+ self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
570
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
571
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
572
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
573
+
574
+ self.patch_embed = PatchEmbed(
575
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
576
+ norm_layer=None)
577
+
578
+ self.patch_unembed = PatchUnEmbed(
579
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim,
580
+ norm_layer=None)
581
+
582
+ def forward(self, x, x_size):
583
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
584
+
585
+
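# Illustrative sketch (not part of the original commit): the '3conv' residual option
# squeezes the middle channels to dim // 4, which roughly halves the parameter cost of
# the plain 3x3 '1conv' block for the default dim of 60.
import torch.nn as nn
dim = 60
one_conv = nn.Conv2d(dim, dim, 3, 1, 1)
three_conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(0.2, True),
                           nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(0.2, True),
                           nn.Conv2d(dim // 4, dim, 3, 1, 1))
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(one_conv), count(three_conv))   # 32460 vs 16515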
586
+ class PatchEmbed(nn.Module):
587
+ r""" Image to Patch Embedding
588
+
589
+ Args:
590
+ img_size (int): Image size. Default: 224.
591
+ patch_size (int): Patch token size. Default: 4.
592
+ in_chans (int): Number of input image channels. Default: 3.
593
+ embed_dim (int): Number of linear projection output channels. Default: 96.
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
595
+ """
596
+
597
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
598
+ super().__init__()
599
+ img_size = to_2tuple(img_size)
600
+ patch_size = to_2tuple(patch_size)
601
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
602
+ self.img_size = img_size
603
+ self.patch_size = patch_size
604
+ self.patches_resolution = patches_resolution
605
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
606
+
607
+ self.in_chans = in_chans
608
+ self.embed_dim = embed_dim
609
+
610
+ if norm_layer is not None:
611
+ self.norm = norm_layer(embed_dim)
612
+ else:
613
+ self.norm = None
614
+
615
+ def forward(self, x):
616
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
617
+ if self.norm is not None:
618
+ x = self.norm(x)
619
+ return x
620
+
621
+
622
+ class PatchUnEmbed(nn.Module):
623
+ r""" Image to Patch Unembedding
624
+
625
+ Args:
626
+ img_size (int): Image size. Default: 224.
627
+ patch_size (int): Patch token size. Default: 4.
628
+ in_chans (int): Number of input image channels. Default: 3.
629
+ embed_dim (int): Number of linear projection output channels. Default: 96.
630
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
631
+ """
632
+
633
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
634
+ super().__init__()
635
+ img_size = to_2tuple(img_size)
636
+ patch_size = to_2tuple(patch_size)
637
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
638
+ self.img_size = img_size
639
+ self.patch_size = patch_size
640
+ self.patches_resolution = patches_resolution
641
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
642
+
643
+ self.in_chans = in_chans
644
+ self.embed_dim = embed_dim
645
+
646
+ def forward(self, x, x_size):
647
+ B, HW, C = x.shape
648
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B C Ph Pw
649
+ return x
650
+
651
+
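# Illustrative sketch (not part of the original commit): with patch_size=1 these two
# modules just flatten a feature map into tokens and restore it, so a round trip is
# lossless (the dimensions below are arbitrary).
import torch
embed = PatchEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60)
unembed = PatchUnEmbed(img_size=64, patch_size=1, in_chans=60, embed_dim=60)
feat = torch.randn(2, 60, 48, 40)                      # B, C, H, W
tokens = embed(feat)                                   # B, H*W, C
print(tokens.shape)                                    # torch.Size([2, 1920, 60])
print(torch.equal(unembed(tokens, (48, 40)), feat))    # True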
652
+ class Upsample(nn.Sequential):
653
+ """Upsample module.
654
+
655
+ Args:
656
+ scale (int): Scale factor. Supported scales: 2^n and 3.
657
+ num_feat (int): Channel number of intermediate features.
658
+ """
659
+
660
+ def __init__(self, scale, num_feat):
661
+ m = []
662
+ if (scale & (scale - 1)) == 0: # scale = 2^n
663
+ for _ in range(int(math.log(scale, 2))):
664
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
665
+ m.append(nn.PixelShuffle(2))
666
+ elif scale == 3:
667
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
668
+ m.append(nn.PixelShuffle(3))
669
+ else:
670
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
671
+ super(Upsample, self).__init__(*m)
672
+
673
+
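# Illustrative sketch (not part of the original commit): a x4 Upsample stacks two
# (Conv2d -> PixelShuffle(2)) stages; each conv expands channels 4x so the shuffle can
# trade them for a 2x larger spatial grid.
import torch
up = Upsample(scale=4, num_feat=64)
x = torch.randn(1, 64, 16, 16)
print(up(x).shape)   # torch.Size([1, 64, 64, 64])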
674
+ class UpsampleOneStep(nn.Sequential):
675
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
676
+ Used in lightweight SR to save parameters.
677
+
678
+ Args:
679
+ scale (int): Scale factor. Supported scales: 2^n and 3.
680
+ num_feat (int): Channel number of intermediate features.
+ num_out_ch (int): Channel number of the output image.
+ input_resolution (tuple[int] | None, optional): Input resolution. Default: None
681
+
682
+ """
683
+
684
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
685
+ self.num_feat = num_feat
686
+ self.input_resolution = input_resolution
687
+ m = []
688
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
689
+ m.append(nn.PixelShuffle(scale))
690
+ super(UpsampleOneStep, self).__init__(*m)
691
+
692
+
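# Illustrative sketch (not part of the original commit): the single conv of
# UpsampleOneStep maps embed_dim straight to scale**2 * output channels, which is far
# cheaper than the conv_before_upsample + Upsample + conv_last head used for classical
# SR (rough comparison for embed_dim=60, num_feat=64, x4).
import torch
one_step = UpsampleOneStep(scale=4, num_feat=60, num_out_ch=3)
classical_head = torch.nn.Sequential(torch.nn.Conv2d(60, 64, 3, 1, 1),
                                     Upsample(4, 64),
                                     torch.nn.Conv2d(64, 3, 3, 1, 1))
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(one_step), count(classical_head))   # 25968 vs 331779 parameters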
693
+ class HiT_SRF(nn.Module, PyTorchModelHubMixin):
694
+ """ HiT-SRF network.
695
+
696
+ Args:
697
+ img_size (int | tuple(int)): Input image size. Default: 64
698
+ patch_size (int | tuple(int)): Patch size. Default: 1
699
+ in_chans (int): Number of input image channels. Default: 3
700
+ embed_dim (int): Patch embedding dimension. Default: 60
701
+ depths (tuple(int)): Depth of each Transformer block.
702
+ num_heads (tuple(int)): Number of heads for spatial self-correlation in different layers.
703
+ base_win_size (tuple[int]): The height and width of the base window.
704
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 2
705
+ drop_rate (float): Dropout rate. Default: 0
706
+ value_drop_rate (float): Dropout ratio of value. Default: 0.0
707
+ drop_path_rate (float): Stochastic depth rate. Default: 0.0
708
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
709
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
710
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
711
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
712
+ upscale (int): Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
713
+ img_range (float): Image range. 1. or 255.
714
+ upsampler (str): The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
715
+ resi_connection (str): The convolutional block before residual connection. '1conv'/'3conv'
716
+ hier_win_ratios (list): hierarchical window ratios for a transformer block. Default: [0.5,1,2,4,6,8].
717
+ """
718
+
719
+ def __init__(self, img_size=64, patch_size=1, in_chans=3,
720
+ embed_dim=60, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6],
721
+ base_win_size=[8,8], mlp_ratio=2.,
722
+ drop_rate=0., value_drop_rate=0., drop_path_rate=0.,
723
+ norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
724
+ use_checkpoint=False, upscale=4, img_range=1., upsampler='pixelshuffledirect', resi_connection='1conv',
725
+ hier_win_ratios=[0.5,1,2,4,6,8],
726
+ **kwargs):
727
+ super(HiT_SRF, self).__init__()
728
+ num_in_ch = in_chans
729
+ num_out_ch = in_chans
730
+ num_feat = 64
731
+ self.img_range = img_range
732
+ if in_chans == 3:
733
+ rgb_mean = (0.4488, 0.4371, 0.4040)
734
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
735
+ else:
736
+ self.mean = torch.zeros(1, 1, 1, 1)
737
+ self.upscale = upscale
738
+ self.upsampler = upsampler
739
+ self.base_win_size = base_win_size
740
+
741
+ #####################################################################################################
742
+ ################################### 1, shallow feature extraction ###################################
743
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
744
+
745
+ #####################################################################################################
746
+ ################################### 2, deep feature extraction ######################################
747
+ self.num_layers = len(depths)
748
+ self.embed_dim = embed_dim
749
+ self.ape = ape
750
+ self.patch_norm = patch_norm
751
+ self.num_features = embed_dim
752
+ self.mlp_ratio = mlp_ratio
753
+
754
+ # split image into non-overlapping patches
755
+ self.patch_embed = PatchEmbed(
756
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
757
+ norm_layer=norm_layer if self.patch_norm else None)
758
+ num_patches = self.patch_embed.num_patches
759
+ patches_resolution = self.patch_embed.patches_resolution
760
+ self.patches_resolution = patches_resolution
761
+
762
+ # merge non-overlapping patches into image
763
+ self.patch_unembed = PatchUnEmbed(
764
+ img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
765
+ norm_layer=norm_layer if self.patch_norm else None)
766
+
767
+ # absolute position embedding
768
+ if self.ape:
769
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
770
+ trunc_normal_(self.absolute_pos_embed, std=.02)
771
+
772
+ self.pos_drop = nn.Dropout(p=drop_rate)
773
+
774
+ # stochastic depth
775
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
776
+
777
+ # build Residual Hierarchical Transformer blocks (RHTB)
778
+ self.layers = nn.ModuleList()
779
+ for i_layer in range(self.num_layers):
780
+ layer = RHTB(dim=embed_dim,
781
+ input_resolution=(patches_resolution[0],
782
+ patches_resolution[1]),
783
+ depth=depths[i_layer],
784
+ num_heads=num_heads[i_layer],
785
+ base_win_size=base_win_size,
786
+ mlp_ratio=self.mlp_ratio,
787
+ drop=drop_rate, value_drop=value_drop_rate,
788
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
789
+ norm_layer=norm_layer,
790
+ downsample=None,
791
+ use_checkpoint=use_checkpoint,
792
+ img_size=img_size,
793
+ patch_size=patch_size,
794
+ resi_connection=resi_connection,
795
+ hier_win_ratios=hier_win_ratios
796
+ )
797
+ self.layers.append(layer)
798
+ self.norm = norm_layer(self.num_features)
799
+
800
+ # build the last conv layer in deep feature extraction
801
+ if resi_connection == '1conv':
802
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
803
+ elif resi_connection == '3conv':
804
+ # to save parameters and memory
805
+ self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
806
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
807
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
808
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
809
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
810
+
811
+ #####################################################################################################
812
+ ################################ 3, high quality image reconstruction ################################
813
+ if self.upsampler == 'pixelshuffle':
814
+ # for classical SR
815
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
816
+ nn.LeakyReLU(inplace=True))
817
+ self.upsample = Upsample(upscale, num_feat)
818
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
819
+ elif self.upsampler == 'pixelshuffledirect':
820
+ # for lightweight SR (to save parameters)
821
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
822
+ (patches_resolution[0], patches_resolution[1]))
823
+ elif self.upsampler == 'nearest+conv':
824
+ # for real-world SR (less artifacts)
825
+ assert self.upscale == 4, 'only x4 is supported for nearest+conv now.'
826
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1),
827
+ nn.LeakyReLU(inplace=True))
828
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
829
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
830
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
831
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
832
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
833
+ else:
834
+ # for image denoising and JPEG compression artifact reduction
835
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
836
+
837
+ self.apply(self._init_weights)
838
+
839
+ def _init_weights(self, m):
840
+ if isinstance(m, nn.Linear):
841
+ trunc_normal_(m.weight, std=.02)
842
+ if isinstance(m, nn.Linear) and m.bias is not None:
843
+ nn.init.constant_(m.bias, 0)
844
+ elif isinstance(m, nn.LayerNorm):
845
+ nn.init.constant_(m.bias, 0)
846
+ nn.init.constant_(m.weight, 1.0)
847
+
848
+ @torch.jit.ignore
849
+ def no_weight_decay(self):
850
+ return {'absolute_pos_embed'}
851
+
852
+ @torch.jit.ignore
853
+ def no_weight_decay_keywords(self):
854
+ return {'relative_position_bias_table'}
855
+
856
+
857
+ def forward_features(self, x):
858
+ x_size = (x.shape[2], x.shape[3])
859
+ x = self.patch_embed(x)
860
+ if self.ape:
861
+ x = x + self.absolute_pos_embed
862
+ x = self.pos_drop(x)
863
+
864
+ for layer in self.layers:
865
+ x = layer(x, x_size)
866
+
867
+ x = self.norm(x) # B L C
868
+ x = self.patch_unembed(x, x_size)
869
+
870
+ return x
871
+
872
+ def infer_image(self, image_path, device):
873
+
874
+ io_backend_opt = {'type':'disk'}
875
+ self.file_client = FileClient(io_backend_opt.pop('type'), **io_backend_opt)
876
+
877
+ # load lq image
878
+ lq_path = image_path
879
+ img_bytes = self.file_client.get(lq_path, 'lq')
880
+ img_lq = imfrombytes(img_bytes, float32=True)
881
+
882
+ # BGR to RGB, HWC to CHW, numpy to tensor
883
+ x = img2tensor(img_lq, bgr2rgb=True, float32=True)[None,...]
884
+
885
+ x = x.to(device)
886
+
887
+ out = self(x)
888
+
889
+ out = out.cpu()
890
+
891
+ out = tensor2img(out)
892
+
893
+ return out
894
+
895
+ def forward(self, x):
896
+ H, W = x.shape[2:]
897
+
898
+ self.mean = self.mean.type_as(x)
899
+ x = (x - self.mean) * self.img_range
900
+
901
+ if self.upsampler == 'pixelshuffle':
902
+ # for classical SR
903
+ x = self.conv_first(x)
904
+ x = self.conv_after_body(self.forward_features(x)) + x
905
+ x = self.conv_before_upsample(x)
906
+ x = self.conv_last(self.upsample(x))
907
+ elif self.upsampler == 'pixelshuffledirect':
908
+ # for lightweight SR
909
+ x = self.conv_first(x)
910
+ x = self.conv_after_body(self.forward_features(x)) + x
911
+ x = self.upsample(x)
912
+ elif self.upsampler == 'nearest+conv':
913
+ # for real-world SR
914
+ x = self.conv_first(x)
915
+ x = self.conv_after_body(self.forward_features(x)) + x
916
+ x = self.conv_before_upsample(x)
917
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
918
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
919
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
920
+ else:
921
+ # for image denoising and JPEG compression artifact reduction
922
+ x_first = self.conv_first(x)
923
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
924
+ x = x + self.conv_last(res)
925
+
926
+ x = x / self.img_range + self.mean
927
+
928
+ return x[:, :, :H*self.upscale, :W*self.upscale]
929
+
930
+
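# Illustrative sketch (not part of the original commit): the self-test below sizes the
# LR input as (n // win + 1) * win, i.e. it bumps the resolution up to a multiple of
# base_win_size (always adding at least one extra window) so the window partitions tile
# the feature map exactly; for a 1024x720 HR target at x4 this gives a 264x184 LR input.
upscale, base = 4, (8, 8)
height = (1024 // upscale // base[0] + 1) * base[0]
width = (720 // upscale // base[1] + 1) * base[1]
print(height, width)   # 264 184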
931
+ if __name__ == '__main__':
932
+ upscale = 4
933
+ base_win_size = [8, 8]
934
+ height = (1024 // upscale // base_win_size[0] + 1) * base_win_size[0]
935
+ width = (720 // upscale // base_win_size[1] + 1) * base_win_size[1]
936
+
937
+ ## HiT-SRF
938
+ model = HiT_SRF(upscale=4, img_size=(height, width),
939
+ base_win_size=base_win_size, img_range=1., depths=[6, 6, 6, 6],
940
+ embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect')
941
+
942
+ params_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
943
+ print("params: ", params_num)
944
+
945
+
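# Illustrative sketch (not part of the original commit): because HiT_SRF inherits from
# PyTorchModelHubMixin, a checkpoint pushed to the Hub can be restored with
# from_pretrained() and applied to a low-resolution image via infer_image(). The repo id
# and image path are placeholders, not assets referenced by this file.
import torch
from hit_srf_arch import HiT_SRF

device = "cuda" if torch.cuda.is_available() else "cpu"
model = HiT_SRF.from_pretrained("your-namespace/hit-srf-x4")    # placeholder repo id
model = model.to(device).eval()
with torch.no_grad():
    sr = model.infer_image("path/to/lr_image.png", device)      # HWC numpy image from tensor2img
print(sr.shape)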