xmutly committed on
Commit 5df2892 · verified · 1 Parent(s): 602c614

Upload 3 files

Files changed (3)
  1. IPG/IPG_arch.py +1242 -0
  2. IPG/arch_util.py +315 -0
  3. IPG/ipg_kit.py +199 -0
IPG/IPG_arch.py ADDED
@@ -0,0 +1,1242 @@
1
+ # Copyright 2024 Huawei Technologies Co., Ltd
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ============================================================================
15
+
16
+
17
+ import math, os
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.utils.checkpoint as checkpoint
21
+ import torch.nn.functional as F
22
+ from IPG.arch_util import to_2tuple, trunc_normal_
23
+ import numpy as np
24
+ import einops
25
+
26
+ from IPG.ipg_kit import flex, cossim, local_sampling, global_sampling
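+ # helpers from ipg_kit (as inferred from their usage below): cossim computes
+ # cosine similarity, local_sampling/global_sampling gather dense-window /
+ # strided token samples, and flex builds the per-node edge-budget mask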
27
+
28
+ list_to_save = list()
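+ # module-level buffer for dumping intermediate tensors when output_folder is set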
29
+
30
+
31
+ class ChannelAttention(nn.Module):
32
+ """Channel attention used in RCAN.
33
+ Args:
34
+ num_feat (int): Channel number of intermediate features.
35
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
36
+ """
37
+
38
+ def __init__(self, num_feat, squeeze_factor=16):
39
+ super(ChannelAttention, self).__init__()
40
+ self.attention = nn.Sequential(
41
+ nn.AdaptiveAvgPool2d(1),
42
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
43
+ nn.ReLU(inplace=True),
44
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
45
+ nn.Sigmoid())
46
+
47
+ def forward(self, x):
48
+ y = self.attention(x)
49
+ return x * y
50
+
51
+
52
+ class CAB(nn.Module):
53
+
54
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30, conv_type=''):
55
+ super(CAB, self).__init__()
56
+ self.num_feat, self.compress_ratio, self.squeeze_factor = num_feat, compress_ratio, squeeze_factor
+ self.conv_type = conv_type # needed by flops(); previously unset on the default branch
57
+ if conv_type == '':
58
+ self.cab = nn.Sequential(
59
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
60
+ nn.GELU(),
61
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
62
+ ChannelAttention(num_feat, squeeze_factor)
63
+ )
64
+ else:
65
+ self.cab = nn.Sequential(*self.block_selection(conv_type))
66
+
67
+ def block_selection(self, conv_type: str):
68
+ '''
69
+ only support post-ca; max conv num 2
70
+ '''
71
+ self.conv_type = conv_type
72
+ conv_types = conv_type.split('-')
73
+ keep_dim = ('dw' in conv_type) or (conv_type.count('conv') < 2)
74
+
75
+ dims = [self.num_feat, self.num_feat // (self.compress_ratio if not keep_dim else 1), self.num_feat]
76
+ conv_num = 0
77
+ blocks = list()
78
+ for name in conv_types:
79
+ if name == 'ca':
80
+ break
81
+ elif name == 'gelu':
82
+ blocks.append(nn.GELU())
83
+ elif name.startswith('conv'):
84
+ blocks.append(nn.Conv2d(dims[conv_num], dims[conv_num + 1], int(name[-1]), 1, (int(name[-1]) - 1) // 2))
85
+ conv_num += 1
86
+ elif name.startswith('dwconv'):
87
+ blocks.append(nn.Conv2d(dims[conv_num], dims[conv_num + 1], int(name[-1]), 1, (int(name[-1]) - 1) // 2,
88
+ groups=dims[conv_num]))
89
+ conv_num += 1
90
+
91
+ blocks.append(ChannelAttention(self.num_feat, self.squeeze_factor))
92
+
93
+ return blocks
94
+
95
+ def forward(self, x):
96
+ ''' x: (b c h w)
97
+ output: (b c h w)
98
+ '''
99
+ return self.cab(x)
100
+
101
+ def flops(self, n):
102
+ flops = 0
103
+ if self.conv_type == 'dwconv3-gelu-conv1-ca':
104
+ flops += self.num_feat * 9 * n + self.num_feat * self.num_feat * 1 * n
105
+ elif self.conv_type == 'conv3-gelu-conv3-ca':
106
+ flops += 2 * self.num_feat * (self.num_feat // self.compress_ratio) * 9 * n
107
+ else:
108
+ flops += 2 * self.num_feat * 1 * 9 * n # two convs in cab
110
+ flops += 2 * (self.num_feat // self.squeeze_factor) * self.num_feat * 1 * 1 * 1 # channel_attention: 2 convs
111
+ return flops
112
+
113
+
114
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
115
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
116
+
117
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
118
+ """
119
+ if drop_prob == 0. or not training:
120
+ return x
121
+ keep_prob = 1 - drop_prob
122
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
123
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
124
+ random_tensor.floor_() # binarize
125
+ output = x.div(keep_prob) * random_tensor
126
+ return output
127
+
128
+
129
+ class DropPath(nn.Module):
130
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
131
+
132
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
133
+ """
134
+
135
+ def __init__(self, drop_prob=None):
136
+ super(DropPath, self).__init__()
137
+ self.drop_prob = drop_prob
138
+
139
+ def forward(self, x):
140
+ return drop_path(x, self.drop_prob, self.training)
141
+
142
+
143
+ class Mlp(nn.Module):
144
+
145
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
146
+ super().__init__()
147
+ out_features = out_features or in_features
148
+ hidden_features = hidden_features or in_features
149
+ self.fc1 = nn.Linear(in_features, hidden_features)
150
+ self.act = act_layer()
151
+ self.fc2 = nn.Linear(hidden_features, out_features)
152
+ self.drop = nn.Dropout(drop)
153
+
154
+ def forward(self, x):
155
+ x = self.fc1(x)
156
+ x = self.act(x)
157
+ x = self.drop(x)
158
+ x = self.fc2(x)
159
+ x = self.drop(x)
160
+ return x
161
+
162
+
163
+ class dwconv(nn.Module):
164
+ def __init__(self, hidden_features, tp='dwconv5'):
165
+ super(dwconv, self).__init__()
166
+ self.depthwise_conv = nn.Sequential(
167
+ nn.Conv2d(hidden_features, hidden_features, kernel_size=int(tp[-1]), stride=1,
168
+ padding=(int(tp[-1]) - 1) // 2, dilation=1,
169
+ groups=hidden_features if tp.startswith('dw') else 1), nn.GELU())
170
+ self.hidden_features = hidden_features
171
+
172
+ def forward(self, x, x_size):
173
+ x = x.transpose(1, 2).view(x.shape[0], self.hidden_features, x_size[0], x_size[1]).contiguous() # b Ph*Pw c
174
+ x = self.depthwise_conv(x)
175
+ x = x.flatten(2).transpose(1, 2).contiguous()
176
+ return x
177
+
178
+
179
+ class ConvFFN(nn.Module):
180
+
181
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., **kwargs):
182
+ super().__init__()
183
+ out_features = out_features or in_features
184
+ hidden_features = hidden_features or in_features
185
+ self.in_features, self.hidden_features = in_features, hidden_features
186
+ self.fc1 = nn.Linear(in_features, hidden_features)
187
+ self.act = act_layer()
188
+ self.before_add = nn.Identity()
189
+ self.after_add = nn.Identity()
190
+ if kwargs.get('FFNtype') is None:
191
+ self.kernel_size = 5
192
+ self.dwconv = dwconv(hidden_features=hidden_features)
193
+ elif kwargs.get('FFNtype') == 'none':
194
+ self.kernel_size = 0
195
+ self.dwconv = nn.Identity()
196
+ elif kwargs.get('FFNtype').startswith('basic'):
197
+ self.kernel_size = int(kwargs.get('FFNtype')[-1]) # figure out kernel size
198
+ self.dwconv = dwconv(hidden_features=hidden_features, tp=kwargs.get('FFNtype').split('-')[-1])
199
+ else:
200
+ raise NotImplementedError(f'FFNtype {kwargs.get("FFNtype")} not implemented!')
201
+ self.fc2 = nn.Linear(hidden_features, out_features)
202
+ self.drop = nn.Dropout(drop)
203
+
204
+ def forward(self, x, x_size):
205
+ x = self.fc1(x)
206
+ x = self.act(x)
207
+ x = self.before_add(x)
208
+ if self.kernel_size > 0:
209
+ x = x + self.dwconv(x, x_size)
210
+ x = self.after_add(x)
211
+ x = self.drop(x)
212
+ x = self.fc2(x)
213
+ x = self.drop(x)
214
+ return x
215
+
216
+ def flops(self, n):
217
+ flops = 2 * n * self.in_features * self.hidden_features # fc1, fc2
218
+ flops += n * self.kernel_size * self.kernel_size * self.hidden_features # dwconv
219
+ return flops
220
+
221
+
222
+ class IPG_Grapher(nn.Module):
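+ """Graph aggregation operator: masked cosine-similarity attention between
+ group tokens (proj_group) and sampled key/value tokens (proj_sample), with a
+ continuous relative position bias (MLP over log-spaced coordinates) added
+ before the softmax."""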
223
+
224
+ def __init__(self, dim, window_size, num_heads, bias=True, proj_drop=0.,
225
+ unfold_dict=None, head_wise=None, top_k=None, **kwargs):
226
+
227
+ super().__init__()
228
+ self.dim = dim
229
+ self.group_size = window_size
230
+ self.num_heads = num_heads
231
+
232
+ # graph_related
233
+ self.unfold_dict = unfold_dict
234
+ self.head_wise = head_wise
235
+ self.top_k = top_k
236
+ self.sample_size = unfold_dict['kernel_size']
237
+ self.graph_switch = kwargs.get('graph_switch', True)
238
+
239
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
240
+
241
+ self.proj_group = nn.Linear(dim, dim, bias=bias)
242
+ self.proj_sample = nn.Linear(dim, dim * 2, bias=bias)
243
+
244
+ self.proj = nn.Linear(dim, dim)
245
+
246
+ # rel pos bias
247
+ self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True),
248
+ nn.ReLU(inplace=True),
249
+ nn.Linear(512, num_heads, bias=False))
250
+
251
+ # get relative_coords_table
252
+ relative_coords_h = torch.arange(-(self.sample_size[0] - 1), self.group_size[0], dtype=torch.float32)
253
+ relative_coords_w = torch.arange(-(self.sample_size[1] - 1), self.group_size[1], dtype=torch.float32)
254
+ relative_coords_table = torch.stack(
255
+ torch.meshgrid([relative_coords_h,
256
+ relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
257
+
258
+ relative_coords_table[:, :, :, 0] /= (self.group_size[0] - 1)
259
+ relative_coords_table[:, :, :, 1] /= (self.group_size[1] - 1)
260
+ relative_coords_table *= 8 # normalize to -8, 8
261
+ relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
262
+ torch.abs(relative_coords_table) + 1.0) / np.log2(8)
263
+
264
+ self.register_buffer("relative_coords_table", relative_coords_table)
265
+
266
+ relative_position_index = self.get_rel_pos_index()
267
+ self.register_buffer("relative_position_index", relative_position_index)
268
+
269
+ self.relative_position_bias_table = None
270
+
271
+ def get_rel_pos_index(self):
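+ # index of every (group token, sampled token) pair into the flattened
+ # (Gh+Sh-1) x (Gw+Sw-1) relative-coordinates table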
272
+ group_size = self.group_size
273
+ sample_size = self.unfold_dict['kernel_size']
274
+
275
+ coords_grid = torch.stack(torch.meshgrid([torch.arange(group_size[0]), torch.arange(group_size[1])]))
276
+ coords_grid_flatten = torch.flatten(coords_grid, 1)
277
+
278
+ coords_sample = torch.stack(torch.meshgrid([torch.arange(sample_size[0]), torch.arange(sample_size[1])]))
279
+ coords_sample_flatten = torch.flatten(coords_sample, 1)
280
+
281
+ relative_coords = coords_sample_flatten[:, None, :] - coords_grid_flatten[:, :, None]
282
+
283
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
284
+ relative_coords[:, :, 0] += group_size[0] - sample_size[0] + 1
285
+ relative_coords[:, :, 0] *= group_size[1] + sample_size[1] - 1
286
+ relative_coords[:, :, 1] += group_size[1] - sample_size[1] + 1
287
+
288
+ relative_position_index = relative_coords.sum(-1)
289
+ return relative_position_index
290
+
291
+ def rel_pos_bias(self):
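+ # cache the CPB-MLP output at eval time; clear the cache whenever the
+ # module is back in training mode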
292
+ if self.training and self.relative_position_bias_table is not None:
293
+ self.relative_position_bias_table = None # clear
294
+
295
+ if not self.training and self.relative_position_bias_table is not None:
296
+ relative_position_bias_table = self.relative_position_bias_table
297
+ else:
298
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
299
+ # store
300
+ if not self.training and self.relative_position_bias_table is None:
301
+ self.relative_position_bias_table = relative_position_bias_table
302
+
303
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
304
+ self.group_size[0] * self.group_size[1], self.sample_size[0] * self.sample_size[1], -1)
305
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
306
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
307
+ return relative_position_bias.unsqueeze(0)
308
+
309
+ def get_correlation(self, x1, x2, graph):
310
+ scale = torch.exp(torch.clamp(self.logit_scale, max=4.6052))
311
+ if self.graph_switch:
312
+ assert (x1.size(-2) == graph.size(-2)) and (x2.size(-2) == graph.size(-1))
313
+
314
+ sim = cossim(x1, x2, graph=graph if self.graph_switch else None)
315
+
316
+ sim = sim * scale + self.rel_pos_bias()
317
+
318
+ sim = F.softmax(sim, dim=-1)
319
+
320
+ return sim
321
+
322
+ def forward(self, x_complete, graph=None, sampling_method=0):
323
+
324
+ if sampling_method == 0:
325
+ x = local_sampling(x_complete, group_size=self.group_size, unfold_dict=None, output=0, tp='bhwc')
326
+ else:
327
+ x = global_sampling(x_complete, group_size=self.group_size, sample_size=None, output=0, tp='bhwc')
328
+
329
+ b_, n, c = x.shape
330
+ x1 = einops.rearrange(self.proj_group(x), 'b n (h c) -> b h n c', b=b_, n=n, h=self.num_heads)
331
+
332
+ if sampling_method == 0:
333
+ x_sampled = local_sampling(self.proj_sample(x_complete), group_size=self.group_size,
334
+ unfold_dict=self.unfold_dict, output=1, tp='bhwc')
335
+ else:
336
+ x_sampled = global_sampling(self.proj_sample(x_complete), group_size=self.group_size,
337
+ sample_size=self.sample_size, output=1, tp='bhwc')
338
+
339
+ x2, feat = einops.rearrange(x_sampled, 'b n (div h c) -> div b h n c', div=2, h=self.num_heads,
340
+ c=c // self.num_heads)
341
+
342
+ corr = self.get_correlation(x1, x2, graph)
343
+
344
+ x = (corr @ feat).transpose(1, 2).reshape(b_, n, c)
345
+ x = self.proj(x)
346
+
347
+ return x
348
+
349
+ def extra_repr(self) -> str:
350
+ return f'dim={self.dim}, top_k={self.top_k}, ' \
351
+ f'sample_size={self.sample_size}'
352
+
353
+ def flops(self, N):
354
+ # calculate theoretical flops for graph aggregation
355
+ flops = 0
356
+ # parametrized similarity
357
+ flops += N * self.dim * 2 * self.dim
358
+ # self mapping
359
+ flops += N * self.dim * self.dim
360
+ # sim calc
361
+ flops += N * self.dim * self.top_k
362
+ flops += self.num_heads * N * self.sample_size[0] * self.sample_size[1] # relative pos
363
+ # aggregation
364
+ flops += N * self.dim * self.top_k
365
+ # project
366
+ flops += N * self.dim * self.dim
367
+ return flops
368
+
369
+
370
+ class GAL(nn.Module):
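+ """Graph Aggregation Layer: IPG_Grapher followed by a ConvFFN, each behind a
+ post-norm residual connection, plus an optional CAB branch weighted by
+ conv_scale."""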
371
+
372
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, sampling_method=0,
373
+ mlp_ratio=4., bias=True, drop=0., drop_path=0.,
374
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, **kwargs):
375
+ super().__init__()
376
+ self.dim = dim
377
+ self.input_resolution = input_resolution
378
+ self.num_heads = num_heads
379
+ self.window_size = window_size
380
+ self.sampling_method = sampling_method
381
+ self.mlp_ratio = mlp_ratio
382
+
383
+ self.norm1 = norm_layer(dim)
384
+ self.grapher = IPG_Grapher(
385
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
386
+ bias=bias, proj_drop=drop, **kwargs)
387
+
388
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
389
+ self.norm2 = norm_layer(dim)
390
+ mlp_hidden_dim = int(dim * mlp_ratio)
391
+ self.mlp = ConvFFN(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, **kwargs)
392
+ attn_mask = None
393
+
394
+ self.register_buffer("attn_mask", attn_mask)
395
+
396
+ '''CAB related'''
397
+ self.conv_scale = kwargs.get('conv_scale') or 0
398
+ compress_ratio = kwargs.get('compress_ratio') or 3
399
+ squeeze_factor = kwargs.get('squeeze_factor') or 30
400
+ conv_type = kwargs.get('conv_type') or ''
401
+ self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor,
402
+ conv_type=conv_type) if self.conv_scale != 0 else None
403
+
404
+ def forward(self, x, x_size, graph):
405
+ H, W = x_size
406
+ B, _, C = x.shape
407
+
408
+ shortcut = x
409
+ x = x.view(B, H, W, C)
410
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1).contiguous().view(B, H * W,
411
+ C) if self.conv_scale != 0 else 0
412
+
413
+ x = self.grapher(x, graph=graph[0] if self.sampling_method == 0 else graph[1],
414
+ sampling_method=self.sampling_method)
415
+
416
+ # regroup
417
+ if self.sampling_method:
418
+ x = einops.rearrange(x, '(b numh numw) (sh sw) c -> b (sh numh sw numw) c', numh=H // self.window_size,
419
+ numw=W // self.window_size, sh=self.window_size, sw=self.window_size)
420
+ else:
421
+ x = einops.rearrange(x, '(b numh numw) (sh sw) c -> b (numh sh numw sw) c', numh=H // self.window_size,
422
+ numw=W // self.window_size, sh=self.window_size, sw=self.window_size)
423
+
424
+ x = shortcut + self.drop_path(self.norm1(x)) + conv_x * self.conv_scale # Channel Attention
425
+
426
+ # FFN
427
+ x = x + self.drop_path(self.norm2(self.mlp(x, x_size)))
428
+
429
+ return x
430
+
431
+ def extra_repr(self) -> str:
432
+ return f"dim={self.dim}, sampling_method={self.sampling_method}, mlp_ratio={self.mlp_ratio}"
433
+
434
+ def flops(self):
435
+ flops = 0
436
+ H, W = self.input_resolution
437
+ # norm1
438
+ flops += self.dim * H * W
439
+ # graph aggregation
440
+ flops += self.grapher.flops(H * W)
441
+ # Channel Attn
442
+ if self.conv_scale != 0:
443
+ nW = (H // self.window_size) * (W // self.window_size) # number of windows
+ flops += nW * self.conv_block.flops(self.window_size * self.window_size)
444
+
445
+ flops += self.mlp.flops(H * W)
446
+ # norm2
447
+ flops += self.dim * H * W
448
+ return flops
449
+
450
+
451
+ class PatchMerging(nn.Module):
452
+ r""" Patch Merging Layer.
453
+
454
+ Args:
455
+ input_resolution (tuple[int]): Resolution of input feature.
456
+ dim (int): Number of input channels.
457
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
458
+ """
459
+
460
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
461
+ super().__init__()
462
+ self.input_resolution = input_resolution
463
+ self.dim = dim
464
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
465
+ self.norm = norm_layer(4 * dim)
466
+
467
+ def forward(self, x):
468
+ """
469
+ x: b, h*w, c
470
+ """
471
+ h, w = self.input_resolution
472
+ b, seq_len, c = x.shape
473
+ assert seq_len == h * w, 'input feature has wrong size'
474
+ assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) must be even.'
475
+
476
+ x = x.view(b, h, w, c)
477
+
478
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
479
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
480
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
481
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
482
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
483
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
484
+
485
+ x = self.norm(x)
486
+ x = self.reduction(x)
487
+
488
+ return x
489
+
490
+ def extra_repr(self) -> str:
491
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
492
+
493
+ def flops(self):
494
+ h, w = self.input_resolution
495
+ flops = h * w * self.dim
496
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
497
+ return flops
498
+
499
+
500
+ class BasicLayer(nn.Module):
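+ """One stage: a stack of GAL blocks where kwargs['stage_spec'][stage_idx]
+ selects 'GN' (local/dense sampling) or 'GS' (global/sparse sampling) per
+ block, optionally followed by a downsample layer."""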
501
+
502
+ def __init__(self,
503
+ dim,
504
+ input_resolution,
505
+ depth,
506
+ num_heads,
507
+ window_size,
508
+ mlp_ratio=4.,
509
+ bias=True,
510
+ drop=0.,
511
+ drop_path=0.,
512
+ norm_layer=nn.LayerNorm,
513
+ downsample=None,
514
+ use_checkpoint=False, stage_idx=None, **kwargs):
515
+
516
+ super().__init__()
517
+ self.dim = dim
518
+ self.input_resolution = input_resolution
519
+ self.depth = depth
520
+ self.use_checkpoint = use_checkpoint
521
+
522
+ stages = kwargs.get('stage_spec')[stage_idx]
523
+
524
+ blocks = []
525
+ for i in range(depth):
526
+ if stages[i] == 'GN':
527
+ block = GAL(
528
+ dim=dim,
529
+ input_resolution=input_resolution,
530
+ num_heads=num_heads,
531
+ window_size=window_size,
532
+ sampling_method=0, # flag controlling local/global sampling
533
+ mlp_ratio=mlp_ratio,
534
+ bias=bias,
535
+ drop=drop,
536
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
537
+ norm_layer=norm_layer, **kwargs
538
+ )
539
+ elif stages[i] == 'GS':
540
+ block = GAL(
541
+ dim=dim,
542
+ input_resolution=input_resolution,
543
+ num_heads=num_heads,
544
+ window_size=window_size,
545
+ sampling_method=1, # flag controlling dense/sparse
546
+ mlp_ratio=mlp_ratio,
547
+ bias=bias,
548
+ drop=drop,
549
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
550
+ norm_layer=norm_layer, **kwargs
551
+ )
552
+ else:
+ raise NotImplementedError(f'block type {stages[i]} not supported')
+
553
+ blocks.append(block)
554
+ self.blocks = nn.ModuleList(blocks)
555
+
556
+ # patch merging layer
557
+ if downsample is not None:
558
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
559
+ else:
560
+ self.downsample = None
561
+
562
+ def forward(self, x, x_size, graph):
563
+ for blk in self.blocks:
564
+ if self.use_checkpoint:
565
+ x = checkpoint.checkpoint(blk, x, x_size, graph)
566
+ else:
567
+ x = blk(x, x_size, graph)
568
+ if self.downsample is not None:
569
+ x = self.downsample(x)
570
+ return x
571
+
572
+ def extra_repr(self) -> str:
573
+ return f'dim={self.dim}, depth={self.depth}'
574
+
575
+ def flops(self):
576
+ flops = 0
577
+ for blk in self.blocks:
578
+ flops += blk.flops()
579
+ if self.downsample is not None:
580
+ flops += self.downsample.flops()
581
+ return flops
582
+
583
+
584
+ class MGB(nn.Module):
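+ """Residual graph stage: builds the stage graphs via calc_graph (or reuses
+ prev_graph when graph_flags disables it), runs a BasicLayer of GAL blocks,
+ and closes with a residual convolution between patch unembed/embed."""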
585
+
586
+ def __init__(self,
587
+ dim,
588
+ input_resolution,
589
+ depth,
590
+ num_heads,
591
+ window_size,
592
+ mlp_ratio=4.,
593
+ bias=True,
594
+ drop=0.,
595
+ drop_path=0.,
596
+ norm_layer=nn.LayerNorm,
597
+ downsample=None,
598
+ use_checkpoint=False,
599
+ img_size=224,
600
+ patch_size=4,
601
+ resi_connection='1conv', stage_idx=None, **kwargs):
602
+ super(MGB, self).__init__()
603
+ self.kwargs = kwargs
604
+
605
+ self.dim = dim
606
+ self.input_resolution = input_resolution
607
+
608
+ self.window_size = window_size
609
+ self.sample_size = kwargs.get('sample_size')
610
+ self.padding_size = (self.sample_size - self.window_size) // 2
611
+ self.unfold_dict = dict(kernel_size=(self.sample_size, self.sample_size), stride=(window_size, window_size),
612
+ padding=(self.padding_size, self.padding_size))
613
+
614
+ # graph related
615
+ self.num_head = num_heads
616
+ self.graph_flag = kwargs.get('graph_flags')[stage_idx]
617
+ self.head_wise = kwargs.get('head_wise', 0)
618
+ self.dist_type = kwargs.get('dist_type')
619
+
620
+ self.fast_graph = kwargs.get('fast_graph', 1)
621
+
622
+ self.dist = cossim
623
+ self.top_k = kwargs.get('top_k')[stage_idx] if isinstance(kwargs.get('top_k'), list) else kwargs.get('top_k')
624
+ # flex graph
625
+ self.flex_type = kwargs.get('flex_type')
626
+ self.graph_switch = kwargs.get('graph_switch')
627
+
628
+ self.stage_idx = stage_idx
629
+ self.output_folder = kwargs.get('output_folder')
630
+
631
+ # interdiff diff_scale: control ratio mean/variance of final budget
632
+ self.diff_scale = kwargs.get('diff_scales')[stage_idx] if kwargs.get(
633
+ 'diff_scales') is not None else None # if diff_scale is 0: X_diff scaling not activated
634
+
635
+ self.residual_group = BasicLayer(
636
+ dim=dim,
637
+ input_resolution=input_resolution,
638
+ depth=depth,
639
+ num_heads=num_heads,
640
+ window_size=window_size,
641
+ mlp_ratio=mlp_ratio,
642
+ bias=bias,
643
+ drop=drop,
644
+ drop_path=drop_path,
645
+ norm_layer=norm_layer,
646
+ downsample=downsample,
647
+ use_checkpoint=use_checkpoint, stage_idx=stage_idx, unfold_dict=self.unfold_dict, **kwargs)
648
+
649
+ if resi_connection == '1conv':
650
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
651
+ elif resi_connection == '3conv':
652
+ # to save parameters and memory
653
+ self.conv = nn.Sequential(
654
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
655
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
656
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
657
+
658
+ self.patch_embed = PatchEmbed(
659
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
660
+
661
+ self.patch_unembed = PatchUnEmbed(
662
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
663
+
664
+ self.tensors = None
665
+ self.tolerance = kwargs.get('tolerance', 8)
666
+
667
+ def diff(self, x, shape=(80, 80), scale=2, he=1):
668
+ ''' x: (B, H*W, C) -> diff: (B, H, W)
+ High-frequency map: |x - up(down(x))| summed over channels, i.e. the detail
+ lost by bilinear down/up-sampling at the given scale.
+ '''
671
+ B, _, C = x.shape
672
+ H, W = shape
673
+ x_rs = x.view(B, H, W, C // he, he).mean(-1).permute(0, 3, 1, 2)
674
+ return (x_rs - F.interpolate(
675
+ F.interpolate(x_rs, (H // scale, W // scale), mode='bilinear', align_corners=False), (H, W),
676
+ mode='bilinear', align_corners=False)).abs().sum(dim=1)
677
+
678
+ @torch.no_grad()
679
+ def calc_graph(self, x_, x_size, sim_matric=None):
680
+ if self.output_folder is not None:
681
+ list_to_save.append(x_.cpu())
682
+ if not self.graph_switch:
683
+ return None, None
684
+
685
+ # prepare const tensors
686
+ if self.fast_graph and self.tensors is None:
687
+ self.tensors = (
688
+ torch.tensor([
689
+ [0.5, 1., 0.],
690
+ [0., 0., 0.],
691
+ [0.5, 0., 1.],
692
+ ], dtype=torch.float32).to(x_.device),
693
+ torch.tensor([
694
+ [0.5, 0., 1.],
695
+ [0.5, 1., 0.],
696
+ [0., 0., 0.],
697
+ ], dtype=torch.float32).to(x_.device)
698
+ )
699
+
700
+ ''' Added: x_diff for interdiff_plain'''
701
+ X_diff = [None, None]
702
+ if self.flex_type.startswith('interdiff'):
703
+ X_diff = self.diff(x_, x_size) # (b h w) do var on C dimension
704
+ if (self.diff_scale is not None) and (self.diff_scale != 0): # perform X_diff scaling
705
+ # affine transform
706
+ mu = X_diff.mean(dim=(-2, -1), keepdim=True) # (b 1 1)
707
+ X_diff = mu + (X_diff - mu) / self.diff_scale
708
+
709
+
710
+ # modulate X_diff by the provided similarity matrix
+ if sim_matric is not None:
+ X_diff = X_diff * sim_matric.detach()
713
+
714
+ X_diff = [
715
+ einops.rearrange(X_diff, 'b (numh wh) (numw ww)-> (b numh numw) (wh ww)', wh=self.window_size,
716
+ ww=self.window_size),
717
+ einops.rearrange(X_diff, 'b (sh numh) (sw numw) -> (b numh numw) (sh sw)', sh=self.window_size,
718
+ sw=self.window_size)
719
+ ]
720
+
721
+ graph0 = self.calc_graph_(x_, x_size, sampling_method=0, X_diff=X_diff[0])
722
+ graph1 = self.calc_graph_(x_, x_size, sampling_method=1, X_diff=X_diff[1])
723
+ return (graph0, graph1)
724
+
725
+ @torch.no_grad()
726
+ def calc_graph_(self, x_, x_size, sampling_method=0, X_diff=None):
727
+ ''' x_: (b, h*w, c), rearranged internally to (b, c, h, w)
+ '''
729
+ # head_wise: not implemented
730
+ he = self.num_head if self.head_wise else 1
731
+ x = einops.rearrange(x_, 'b (h w) c -> b c h w', h=x_size[0], w=x_size[1])
732
+ # cyclic shift
733
+ if sampling_method: # sparse global
734
+ X_sample, Y_sample = global_sampling(x, group_size=self.window_size, sample_size=self.sample_size, output=2,
735
+ tp='bchw')
736
+ else: # dense local
737
+ X_sample, Y_sample = local_sampling(x, group_size=self.window_size, unfold_dict=self.unfold_dict, output=2,
738
+ tp='bchw')
739
+
740
+ assert X_sample.size(0) == Y_sample.size(0)
741
+
742
+ D = self.dist(X_sample.unsqueeze(1), Y_sample.unsqueeze(1)).squeeze(1) # (b m n)
743
+
744
+ if self.fast_graph: # Fast graph construction
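+ # per-node edge budgets proportional to X_diff, then a bisection search for
+ # a per-node similarity threshold: MAT holds [wall, lower, upper]; over-budget
+ # nodes raise the wall to (wall+upper)/2 (via tensors[0]), under-budget nodes
+ # lower it to (wall+lower)/2 (via tensors[1])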
745
+ maskarray = (X_diff / X_diff.sum(dim=-1, keepdim=True)) * D.size(1) * self.top_k
746
+ maskarray = torch.clamp(maskarray, 1, D.size(-1))
747
+
748
+ # search for threshold
749
+ minbound = torch.min(D, dim=-1, keepdim=True)[0]
750
+ maxbound = torch.ones_like(minbound) # D.max(dim=-1, keepdim=True)
751
+ wall = D.mean(dim=-1, keepdim=True)
752
+ MAT = torch.cat([wall, minbound, maxbound], dim=-1)
753
+
754
+ for _ in range(self.tolerance):
755
+ allocated = (D > MAT[..., 0:1]).sum(dim=-1)
756
+ MAT = torch.where(
757
+ (allocated > maskarray).unsqueeze(-1),
758
+ MAT @ self.tensors[0],
759
+ MAT @ self.tensors[1],
760
+ )
761
+
762
+ graph = (D > MAT[..., 0:1]).unsqueeze(1) # add head dim
763
+ else:
764
+ val, idx = D.sort(dim=-1, descending=True) # (b m n)
765
+ b, m, n = idx.shape
766
+
767
+ mask = flex(D, X_sample, idx, self.flex_type, self.top_k, self.kwargs['model'].current_iter,
768
+ self.kwargs['model'].total_iters, X_diff, fast=True) # TODO: calc mask
769
+
770
+ if not self.head_wise: # expand for each head
771
+ idx = idx.unsqueeze(1).expand(b, 1, m, n) # b he m n
772
+ mask = mask.unsqueeze(1).expand(b, 1, m, n) # b he m n
773
+ else:
774
+ idx = einops.rearrange(idx, '(b he) m n -> b he m n', he=he)
775
+ mask = einops.rearrange(mask, '(b he) m n -> b he m n', he=he)
776
+ original_shape = idx.shape
777
+ b_coord = torch.arange(idx.size(0), device=idx.device).int().view(-1, 1, 1, 1) * np.prod(original_shape[1:])
778
+ he_coord = torch.arange(idx.size(1), device=idx.device).int().view(1, -1, 1, 1) * np.prod(
779
+ original_shape[2:])
780
+ m_coord = torch.arange(idx.size(2), device=idx.device).int().view(1, 1, -1, 1) * original_shape[3]
781
+
782
+ overall_coord = b_coord + he_coord + m_coord + idx
783
+ selected_coord = torch.masked_select(overall_coord, mask)
784
+ graph = torch.ones_like(idx).bool()
785
+ graph.view(-1)[selected_coord] = False # turned off connections
786
+ '''save graph'''
787
+ if self.output_folder is not None:
788
+ list_to_save.append(graph.cpu())
789
+
790
+ return graph
791
+
792
+ def forward(self, x, x_size, prev_graph=None, sim_matric=None):
793
+ graph = self.calc_graph(x, x_size, sim_matric) if self.graph_flag else prev_graph
794
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size, graph), x_size))) + x, graph
795
+
796
+ def flops(self):
797
+ flops = 0
798
+ h, w = self.input_resolution
799
+ # self added: graph flops (2 graphs)
800
+ if self.graph_switch:
801
+ # interdiff_plain:
802
+ if self.flex_type == 'interdiff_plain':
803
+ flops += h // 2 * w // 2 * 4 * self.dim
804
+ flops += h * w * 4 * self.dim
805
+ flops += 2 * h * w * self.dim * self.sample_size * self.sample_size # matrix mul for GRAM (B, wH*wW, dim) * (B, dim, oH*oW); two graphs
806
+ if self.fast_graph:
807
+ sort_flops = 2 * self.tolerance * 3 * 3
808
+ else:
809
+ sort_flops = round(self.sample_size * self.sample_size * math.log2(self.sample_size * self.sample_size))
810
+ # print('SORT FLOPS:', sort_flops * h * w)
811
+ flops += sort_flops * h * w
812
+ flops += self.residual_group.flops()
813
+ flops += h * w * self.dim * self.dim * 9
814
+ flops += self.patch_embed.flops()
815
+ flops += self.patch_unembed.flops()
816
+
817
+ return flops
818
+
819
+
820
+ class PatchEmbed(nn.Module):
821
+ r""" Image to Patch Embedding
822
+
823
+ Args:
824
+ img_size (int): Image size. Default: 224.
825
+ patch_size (int): Patch token size. Default: 4.
826
+ in_chans (int): Number of input image channels. Default: 3.
827
+ embed_dim (int): Number of linear projection output channels. Default: 96.
828
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
829
+ """
830
+
831
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
832
+ super().__init__()
833
+ img_size = to_2tuple(img_size)
834
+ patch_size = to_2tuple(patch_size)
835
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
836
+ self.img_size = img_size
837
+ self.patch_size = patch_size
838
+ self.patches_resolution = patches_resolution
839
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
840
+
841
+ self.in_chans = in_chans
842
+ self.embed_dim = embed_dim
843
+
844
+ if norm_layer is not None:
845
+ self.norm = norm_layer(embed_dim)
846
+ else:
847
+ self.norm = None
848
+
849
+ def forward(self, x):
850
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
851
+ if self.norm is not None:
852
+ x = self.norm(x)
853
+ return x
854
+
855
+ def flops(self):
856
+ flops = 0
857
+ h, w = self.img_size
858
+ if self.norm is not None:
859
+ flops += h * w * self.embed_dim
860
+ return flops
861
+
862
+
863
+ class PatchUnEmbed(nn.Module):
864
+ r""" Image to Patch Unembedding
865
+
866
+ Args:
867
+ img_size (int): Image size. Default: 224.
868
+ patch_size (int): Patch token size. Default: 4.
869
+ in_chans (int): Number of input image channels. Default: 3.
870
+ embed_dim (int): Number of linear projection output channels. Default: 96.
871
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
872
+ """
873
+
874
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
875
+ super().__init__()
876
+ img_size = to_2tuple(img_size)
877
+ patch_size = to_2tuple(patch_size)
878
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
879
+ self.img_size = img_size
880
+ self.patch_size = patch_size
881
+ self.patches_resolution = patches_resolution
882
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
883
+
884
+ self.in_chans = in_chans
885
+ self.embed_dim = embed_dim
886
+
887
+ def forward(self, x, x_size):
888
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
889
+ return x
890
+
891
+ def flops(self): # self added
892
+ return 0
893
+
894
+
895
+ class Upsample(nn.Sequential):
896
+ """Upsample module.
897
+
898
+ Args:
899
+ scale (int): Scale factor. Supported scales: 2^n and 3.
900
+ num_feat (int): Channel number of intermediate features.
901
+ """
902
+
903
+ def __init__(self, scale, num_feat):
904
+ self.scale = scale
905
+ self.num_feat = num_feat
906
+ m = []
907
+ if (scale & (scale - 1)) == 0: # scale = 2^n
908
+ for _ in range(int(math.log(scale, 2))):
909
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
910
+ m.append(nn.PixelShuffle(2))
911
+ elif scale == 3:
912
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
913
+ m.append(nn.PixelShuffle(3))
914
+ else:
915
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
916
+ super(Upsample, self).__init__(*m)
917
+
918
+ def flops(self, n):
919
+ flops = 0
920
+ scale = self.scale
921
+ num_feat = self.num_feat
922
+ this_n = n
923
+ if (scale & (scale - 1)) == 0: # scale = 2^n
924
+ for _ in range(int(math.log(scale, 2))):
925
+ flops += num_feat * 4 * num_feat * 3 * 3 * this_n
926
+ this_n *= 4
927
+ elif scale == 3:
928
+ flops += num_feat * 9 * num_feat * 3 * 3 * n
929
+ # print('Upsampler flops (G): ',flops//1e9)
930
+ return flops
931
+
932
+
933
+ class UpsampleOneStep(nn.Sequential):
934
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
935
+ Used in lightweight SR to save parameters.
936
+
937
+ Args:
938
+ scale (int): Scale factor. Supported scales: 2^n and 3.
939
+ num_feat (int): Channel number of intermediate features.
940
+
941
+ """
942
+
943
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
944
+ self.num_feat = num_feat
945
+ self.input_resolution = input_resolution
946
+ m = []
947
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
948
+ m.append(nn.PixelShuffle(scale))
949
+ super(UpsampleOneStep, self).__init__(*m)
950
+
951
+ def flops(self):
952
+ h, w = self.input_resolution
953
+ flops = h * w * self.num_feat * 3 * 9
954
+ return flops
955
+
956
+
957
+ class IPG(nn.Module):
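+ """IPG restoration network: shallow conv feature extraction, a stack of
+ graph-based MGB stages for deep feature extraction, and a reconstruction
+ head chosen by `upsampler` (pixelshuffle, pixelshuffledirect, nearest+conv,
+ sam, or a plain conv for denoising/JPEG artifact reduction)."""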
958
+
959
+ def __init__(self,
960
+ img_size=64,
961
+ patch_size=1,
962
+ in_chans=3,
963
+ out_chans=32,
964
+ embed_dim=96,
965
+ depths=(6, 6, 6, 6),
966
+ num_heads=(6, 6, 6, 6),
967
+ window_size=7,
968
+ mlp_ratio=4.,
969
+ bias=True,
970
+ drop_rate=0.,
971
+ attn_drop_rate=0.,
972
+ drop_path_rate=0.1,
973
+ norm_layer=nn.LayerNorm,
974
+ ape=False,
975
+ patch_norm=True,
976
+ use_checkpoint=False,
977
+ upscale=2,
978
+ img_range=1.,
979
+ upsampler='',
980
+ resi_connection='1conv',
981
+ **kwargs):
982
+ super(IPG, self).__init__()
983
+ num_in_ch = in_chans
984
+ num_out_ch = out_chans
985
+ num_feat = 64
986
+ self.img_range = img_range
987
+ if in_chans == 3:
988
+ rgb_mean = (0.4488, 0.4371, 0.4040)
989
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
990
+ else:
991
+ self.mean = torch.zeros(1, 1, 1, 1)
992
+ self.upscale = upscale
993
+ self.upsampler = upsampler
994
+
995
+ # ------------------------- 1, shallow feature extraction ------------------------- #
996
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
997
+
998
+ # ------------------------- 2, deep feature extraction ------------------------- #
999
+ self.num_layers = len(depths)
1000
+ self.embed_dim = embed_dim
1001
+ self.ape = ape
1002
+ self.patch_norm = patch_norm
1003
+ self.num_features = embed_dim
1004
+ self.mlp_ratio = mlp_ratio
1005
+
1006
+ # split image into non-overlapping patches
1007
+ self.patch_embed = PatchEmbed(
1008
+ img_size=img_size,
1009
+ patch_size=patch_size,
1010
+ in_chans=embed_dim,
1011
+ embed_dim=embed_dim,
1012
+ norm_layer=norm_layer if self.patch_norm else None)
1013
+ num_patches = self.patch_embed.num_patches
1014
+ patches_resolution = self.patch_embed.patches_resolution
1015
+ self.patches_resolution = patches_resolution
1016
+
1017
+ # merge non-overlapping patches into image
1018
+ self.patch_unembed = PatchUnEmbed(
1019
+ img_size=img_size,
1020
+ patch_size=patch_size,
1021
+ in_chans=embed_dim,
1022
+ embed_dim=embed_dim,
1023
+ norm_layer=norm_layer if self.patch_norm else None)
1024
+
1025
+ # absolute position embedding
1026
+ if self.ape:
1027
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
1028
+ trunc_normal_(self.absolute_pos_embed, std=.02)
1029
+
1030
+ self.pos_drop = nn.Dropout(p=drop_rate)
1031
+
1032
+ # stochastic depth
1033
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
1034
+
1035
+ ''' Intermediate outputs '''
1036
+ self.output_folder = kwargs.get('output_folder')
1037
+
1038
+ self.layers = nn.ModuleList()
1039
+ for i_layer in range(self.num_layers):
1040
+ layer = MGB(
1041
+ dim=embed_dim,
1042
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1043
+ depth=depths[i_layer],
1044
+ num_heads=num_heads[i_layer],
1045
+ window_size=window_size,
1046
+ mlp_ratio=self.mlp_ratio,
1047
+ bias=bias,
1048
+ drop=drop_rate,
1049
+ attn_drop=attn_drop_rate,
1050
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
1051
+ norm_layer=norm_layer,
1052
+ downsample=None,
1053
+ use_checkpoint=use_checkpoint,
1054
+ img_size=img_size,
1055
+ patch_size=patch_size,
1056
+ resi_connection=resi_connection, stage_idx=i_layer, **kwargs)
1057
+ self.layers.append(layer)
1058
+ self.norm = norm_layer(self.num_features)
1059
+
1060
+ self.proj = nn.Linear(embed_dim, 1024)
1061
+ self.proj2 = nn.Linear(64, 1)
1062
+
1063
+ # build the last conv layer in deep feature extraction
1064
+ if resi_connection == '1conv':
1065
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1066
+ elif resi_connection == '3conv':
1067
+ # to save parameters and memory
1068
+ self.conv_after_body = nn.Sequential(
1069
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1070
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
1071
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
1072
+
1073
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1074
+ if self.upsampler == 'pixelshuffle':
1075
+ # for classical SR
1076
+ self.conv_before_upsample = nn.Sequential(
1077
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1078
+ self.upsample = Upsample(upscale, num_feat)
1079
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1080
+ elif self.upsampler == 'pixelshuffledirect':
1081
+ # for lightweight SR (to save parameters)
1082
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
1083
+ (patches_resolution[0], patches_resolution[1]))
1084
+ elif self.upsampler == 'nearest+conv':
1085
+ # for real-world SR (less artifacts)
1086
+ assert self.upscale == 4, 'only support x4 now.'
1087
+ self.conv_before_upsample = nn.Sequential(
1088
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
1089
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1090
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1091
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1092
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1093
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1094
+ else:
1095
+ # for image denoising and JPEG compression artifact reduction
1096
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1097
+
1098
+ self.apply(self._init_weights)
1099
+
1100
+ def _init_weights(self, m):
1101
+ if isinstance(m, nn.Linear):
1102
+ trunc_normal_(m.weight, std=.02)
1103
+ if isinstance(m, nn.Linear) and m.bias is not None:
1104
+ nn.init.constant_(m.bias, 0)
1105
+ elif isinstance(m, nn.LayerNorm):
1106
+ nn.init.constant_(m.bias, 0)
1107
+ nn.init.constant_(m.weight, 1.0)
1108
+
1109
+ @torch.jit.ignore
1110
+ def no_weight_decay(self):
1111
+ return {'absolute_pos_embed'}
1112
+
1113
+ @torch.jit.ignore
1114
+ def no_weight_decay_keywords(self):
1115
+ return {'relative_position_bias_table'}
1116
+
1117
+ def forward_features(self, x, sim_matric=None):
1118
+ x_size = (x.shape[2], x.shape[3])
1119
+ x = self.patch_embed(x)
1120
+ if self.ape:
1121
+ x = x + self.absolute_pos_embed
1122
+ x = self.pos_drop(x)
1123
+ prev_graph = None
1124
+ for layer in self.layers:
1125
+ x, prev_graph = layer(x, x_size, prev_graph, sim_matric)
1126
+
1127
+ x = self.norm(x) # b seq_len c
1128
+ x = self.patch_unembed(x, x_size)
1129
+
1130
+ return x
1131
+
1132
+ def forward(self, x, sim_matric=None):
1133
+ '''
1134
+ Set index & save input
1135
+ '''
1136
+ if (self.output_folder is not None):
1137
+ global list_to_save
1138
+ if not os.path.isdir(self.output_folder):
1139
+ os.makedirs(self.output_folder, exist_ok=True)
1140
+ if len(os.listdir(self.output_folder)) > 0:
1141
+ output_idx = max([int(i[:-4]) if i.endswith('.pkl') and i[:-4].isdecimal() else -1 for i in
1142
+ os.listdir(self.output_folder)]) + 1
1143
+ else:
1144
+ output_idx = 0
1145
+ list_to_save.append(x.cpu())
1146
+ self.mean = self.mean.type_as(x)
1147
+ x = (x - self.mean) * self.img_range
1148
+
1149
+
1150
+ if self.upsampler == 'pixelshuffle':
1151
+ # for classical SR
1152
+ x = self.conv_first(x)
1153
+ x = self.conv_after_body(self.forward_features(x)) + x
1154
+ x = self.conv_before_upsample(x)
1155
+ x = self.conv_last(self.upsample(x))
1156
+
1157
+ elif self.upsampler == 'sam':
1158
+ # x = self.conv_first(x)
1159
+ x = self.conv_after_body(self.forward_features(x, sim_matric)) + x
+ x = self.proj2(x.flatten(2, 3))
+ x = x.permute(0, 2, 1)
+ x = self.proj(x)
1163
+ # x = self.conv_before_upsample(x)
1164
+ # x = self.conv_last(self.upsample(x))
1165
+ elif self.upsampler == 'pixelshuffledirect':
1166
+ # for lightweight SR
1167
+ x = self.conv_first(x)
1168
+ x = self.conv_after_body(self.forward_features(x)) + x
1169
+ x = self.upsample(x)
1170
+ elif self.upsampler == 'nearest+conv':
1171
+ # for real-world SR
1172
+ x = self.conv_first(x)
1173
+ x = self.conv_after_body(self.forward_features(x)) + x
1174
+ x = self.conv_before_upsample(x)
1175
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1176
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1177
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1178
+ else:
1179
+ # for image denoising and JPEG compression artifact reduction
1180
+ x_first = self.conv_first(x)
1181
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1182
+ x = x + self.conv_last(res)
1183
+
1184
+ # x = x / self.img_range + self.mean
1185
+ # ''' Save '''
1186
+ # if (self.output_folder is not None):
1187
+ # list_to_save.append(x.cpu())
1188
+ # torch.save(list_to_save, os.path.join(self.output_folder, str(output_idx) + '.pkl'))
1189
+ # list_to_save = list()
1190
+
1191
+ return x
1192
+
1193
+ def flops(self):
1194
+ flops = 0
1195
+ h, w = self.patches_resolution
1196
+ flops += h * w * 3 * self.embed_dim * 9
1197
+ flops += self.patch_embed.flops()
1198
+ for layer in self.layers:
1199
+ flops += layer.flops()
1200
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
1201
+ flops += self.upsample.flops(h * w)
1202
+ return flops
1203
+
1204
+
1205
+ if __name__ == '__main__':
1206
+ upscale = 4
1207
+ height = (512 // upscale)
1208
+ width = (512 // upscale)
1209
+ model = IPG(
1210
+ upscale=4,
1211
+ in_chans=3,
1212
+ img_size=(height, width),
1213
+ window_size=16,
1214
+ img_range=1.,
1215
+ depths=[6, 6, 6, 6, 6, 6],
1216
+ embed_dim=180,
1217
+ num_heads=[6, 6, 6, 6, 6, 6],
1218
+ mlp_ratio=4,
1219
+ upsampler='pixelshuffle',
1220
+ resi_connection='1conv',
1221
+ graph_flags=[1, 1, 1, 1, 1, 1],
1222
+ stage_spec=[['GN', 'GS', 'GN', 'GS', 'GN', 'GS'], ['GN', 'GS', 'GN', 'GS', 'GN', 'GS'],
1223
+ ['GN', 'GS', 'GN', 'GS', 'GN', 'GS'], ['GN', 'GS', 'GN', 'GS', 'GN', 'GS'],
1224
+ ['GN', 'GS', 'GN', 'GS', 'GN', 'GS'], ['GN', 'GS', 'GN', 'GS', 'GN', 'GS']],
1225
+ dist_type='cossim',
1226
+ top_k=256,
1227
+ head_wise=0,
1228
+ sample_size=32,
1229
+ graph_switch=1,
1230
+ flex_type='interdiff_plain',
1231
+ FFNtype='basic-dwconv3',
1232
+ conv_scale=0,
1233
+ conv_type='dwconv3-gelu-conv1-ca',
1234
+ diff_scales=[10, 1.5, 1.5, 1.5, 1.5, 1.5],
1235
+ fast_graph=1
1236
+ )
1237
+ print(model)
1238
+ print(height, width, model.flops() / 1e9)
1239
+
1240
+ x = torch.randn((1, 3, height, width))
1241
+ x = model(x)
1242
+ print(x.shape)
IPG/arch_util.py ADDED
@@ -0,0 +1,315 @@
1
+ import collections.abc
2
+ import math
3
+ import torch
4
+ import warnings
5
+ from itertools import repeat
6
+ from torch import nn as nn
7
+ from torch.nn import functional as F
8
+ from torch.nn import init as init
9
+ from torch.nn.modules.batchnorm import _BatchNorm
10
+
11
+ # from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv
12
+
13
+
14
+ @torch.no_grad()
15
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
16
+ """Initialize network weights.
17
+
18
+ Args:
19
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
20
+ scale (float): Scale initialized weights, especially for residual
21
+ blocks. Default: 1.
22
+ bias_fill (float): The value to fill bias. Default: 0
23
+ kwargs (dict): Other arguments for initialization function.
24
+ """
25
+ if not isinstance(module_list, list):
26
+ module_list = [module_list]
27
+ for module in module_list:
28
+ for m in module.modules():
29
+ if isinstance(m, nn.Conv2d):
30
+ init.kaiming_normal_(m.weight, **kwargs)
31
+ m.weight.data *= scale
32
+ if m.bias is not None:
33
+ m.bias.data.fill_(bias_fill)
34
+ elif isinstance(m, nn.Linear):
35
+ init.kaiming_normal_(m.weight, **kwargs)
36
+ m.weight.data *= scale
37
+ if m.bias is not None:
38
+ m.bias.data.fill_(bias_fill)
39
+ elif isinstance(m, _BatchNorm):
40
+ init.constant_(m.weight, 1)
41
+ if m.bias is not None:
42
+ m.bias.data.fill_(bias_fill)
43
+
44
+
45
+ def make_layer(basic_block, num_basic_block, **kwarg):
46
+ """Make layers by stacking the same blocks.
47
+
48
+ Args:
49
+ basic_block (nn.module): nn.module class for basic block.
50
+ num_basic_block (int): number of blocks.
51
+
52
+ Returns:
53
+ nn.Sequential: Stacked blocks in nn.Sequential.
54
+ """
55
+ layers = []
56
+ for _ in range(num_basic_block):
57
+ layers.append(basic_block(**kwarg))
58
+ return nn.Sequential(*layers)
59
+
60
+
61
+ class ResidualBlockNoBN(nn.Module):
62
+ """Residual block without BN.
63
+
64
+ It has a style of:
65
+ ---Conv-ReLU-Conv-+-
66
+ |________________|
67
+
68
+ Args:
69
+ num_feat (int): Channel number of intermediate features.
70
+ Default: 64.
71
+ res_scale (float): Residual scale. Default: 1.
72
+ pytorch_init (bool): If set to True, use pytorch default init,
73
+ otherwise, use default_init_weights. Default: False.
74
+ """
75
+
76
+ def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
77
+ super(ResidualBlockNoBN, self).__init__()
78
+ self.res_scale = res_scale
79
+ self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
80
+ self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
81
+ self.relu = nn.ReLU(inplace=True)
82
+
83
+ if not pytorch_init:
84
+ default_init_weights([self.conv1, self.conv2], 0.1)
85
+
86
+ def forward(self, x):
87
+ identity = x
88
+ out = self.conv2(self.relu(self.conv1(x)))
89
+ return identity + out * self.res_scale
90
+
91
+
92
+ class Upsample(nn.Sequential):
93
+ """Upsample module.
94
+
95
+ Args:
96
+ scale (int): Scale factor. Supported scales: 2^n and 3.
97
+ num_feat (int): Channel number of intermediate features.
98
+ """
99
+
100
+ def __init__(self, scale, num_feat):
101
+ m = []
102
+ if (scale & (scale - 1)) == 0: # scale = 2^n
103
+ for _ in range(int(math.log(scale, 2))):
104
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
105
+ m.append(nn.PixelShuffle(2))
106
+ elif scale == 3:
107
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
108
+ m.append(nn.PixelShuffle(3))
109
+ else:
110
+ raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.')
111
+ super(Upsample, self).__init__(*m)
112
+
113
+
114
+ def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
115
+ """Warp an image or feature map with optical flow.
116
+
117
+ Args:
118
+ x (Tensor): Tensor with size (n, c, h, w).
119
+ flow (Tensor): Tensor with size (n, h, w, 2), normal value.
120
+ interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
121
+ padding_mode (str): 'zeros' or 'border' or 'reflection'.
122
+ Default: 'zeros'.
123
+ align_corners (bool): Before pytorch 1.3, the default value is
124
+ align_corners=True. After pytorch 1.3, the default value is
125
+ align_corners=False. Here, we use the True as default.
126
+
127
+ Returns:
128
+ Tensor: Warped image or feature map.
129
+ """
130
+ assert x.size()[-2:] == flow.size()[1:3]
131
+ _, _, h, w = x.size()
132
+ # create mesh grid
133
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h).type_as(x), torch.arange(0, w).type_as(x))
134
+ grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2
135
+ grid.requires_grad = False
136
+
137
+ vgrid = grid + flow
138
+ # scale grid to [-1,1]
139
+ vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
140
+ vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
141
+ vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
142
+ output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)
143
+
144
+ # TODO, what if align_corners=False
145
+ return output
146
+
147
+
148
+ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
149
+ """Resize a flow according to ratio or shape.
150
+
151
+ Args:
152
+ flow (Tensor): Precomputed flow. shape [N, 2, H, W].
153
+ size_type (str): 'ratio' or 'shape'.
154
+ sizes (list[int | float]): the ratio for resizing or the final output
155
+ shape.
156
+ 1) The order of ratio should be [ratio_h, ratio_w]. For
157
+ downsampling, the ratio should be smaller than 1.0 (i.e., ratio
158
+ < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
159
+ ratio > 1.0).
160
+ 2) The order of output_size should be [out_h, out_w].
161
+ interp_mode (str): The mode of interpolation for resizing.
162
+ Default: 'bilinear'.
163
+ align_corners (bool): Whether align corners. Default: False.
164
+
165
+ Returns:
166
+ Tensor: Resized flow.
167
+ """
168
+ _, _, flow_h, flow_w = flow.size()
169
+ if size_type == 'ratio':
170
+ output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
171
+ elif size_type == 'shape':
172
+ output_h, output_w = sizes[0], sizes[1]
173
+ else:
174
+ raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')
175
+
176
+ input_flow = flow.clone()
177
+ ratio_h = output_h / flow_h
178
+ ratio_w = output_w / flow_w
179
+ input_flow[:, 0, :, :] *= ratio_w
180
+ input_flow[:, 1, :, :] *= ratio_h
181
+ resized_flow = F.interpolate(
182
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
183
+ return resized_flow
184
+
185
+
186
+ # TODO: may write a cpp file
187
+ def pixel_unshuffle(x, scale):
188
+ """ Pixel unshuffle.
189
+
190
+ Args:
191
+ x (Tensor): Input feature with shape (b, c, hh, hw).
192
+ scale (int): Downsample ratio.
193
+
194
+ Returns:
195
+ Tensor: the pixel unshuffled feature.
196
+ """
197
+ b, c, hh, hw = x.size()
198
+ out_channel = c * (scale**2)
199
+ assert hh % scale == 0 and hw % scale == 0
200
+ h = hh // scale
201
+ w = hw // scale
202
+ x_view = x.view(b, c, h, scale, w, scale)
203
+ return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w)
204
+
205
+
206
+ # class DCNv2Pack(ModulatedDeformConvPack):
+ #     """Modulated deformable conv for deformable alignment.
+ #
+ #     Different from the official DCNv2Pack, which generates offsets and masks
+ #     from the preceding features, this DCNv2Pack takes different features to
+ #     generate offsets and masks.
+ #
+ #     Ref:
+ #         Delving Deep into Deformable Alignment in Video Super-Resolution.
+ #     """
+ #
+ #     def forward(self, x, feat):
+ #         out = self.conv_offset(feat)
+ #         o1, o2, mask = torch.chunk(out, 3, dim=1)
+ #         offset = torch.cat((o1, o2), dim=1)
+ #         mask = torch.sigmoid(mask)
+ #
+ #         offset_absmean = torch.mean(torch.abs(offset))
+ #         if offset_absmean > 50:
+ #             logger = get_root_logger()
+ #             logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
+ #
+ #         if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
+ #             return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
+ #                                                  self.dilation, mask)
+ #         else:
+ #             return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
+ #                                          self.dilation, self.groups, self.deformable_groups)
+
+
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
+     # From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+     # Cut & paste from Pytorch official master until it's in a few official releases - RW
+     # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
+     def norm_cdf(x):
+         # Computes standard normal cumulative distribution function
+         return (1. + math.erf(x / math.sqrt(2.))) / 2.
+
+     if (mean < a - 2 * std) or (mean > b + 2 * std):
+         warnings.warn(
+             'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
+             'The distribution of values may be incorrect.',
+             stacklevel=2)
+
+     with torch.no_grad():
+         # Values are generated by using a truncated uniform distribution and
+         # then using the inverse CDF for the normal distribution.
+         # Get upper and lower cdf values
+         low = norm_cdf((a - mean) / std)
+         up = norm_cdf((b - mean) / std)
+
+         # Uniformly fill tensor with values from [low, up], then translate to
+         # [2l-1, 2u-1].
+         tensor.uniform_(2 * low - 1, 2 * up - 1)
+
+         # Use inverse cdf transform for normal distribution to get truncated
+         # standard normal
+         tensor.erfinv_()
+
+         # Transform to proper mean, std
+         tensor.mul_(std * math.sqrt(2.))
+         tensor.add_(mean)
+
+         # Clamp to ensure it's in the proper range
+         tensor.clamp_(min=a, max=b)
+         return tensor
+
+
+ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+     r"""Fills the input Tensor with values drawn from a truncated
+     normal distribution.
+
+     From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/weight_init.py
+
+     The values are effectively drawn from the
+     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
+     with values outside :math:`[a, b]` redrawn until they are within
+     the bounds. The method used for generating the random values works
+     best when :math:`a \leq \text{mean} \leq b`.
+
+     Args:
+         tensor: an n-dimensional `torch.Tensor`
+         mean: the mean of the normal distribution
+         std: the standard deviation of the normal distribution
+         a: the minimum cutoff value
+         b: the maximum cutoff value
+
+     Examples:
+         >>> w = torch.empty(3, 5)
+         >>> nn.init.trunc_normal_(w)
+     """
+     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
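+
+
+ # Usage sketch (illustrative; not part of the original file): initialise a
+ # weight with std 0.02; samples are clamped to the default bounds [-2, 2].
+ #
+ #     w = torch.empty(64, 64)
+ #     trunc_normal_(w, std=.02)
+ #     assert w.min() >= -2. and w.max() <= 2.
+
+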
+ # From Pytorch
+ def _ntuple(n):
+
+     def parse(x):
+         if isinstance(x, collections.abc.Iterable):
+             return x
+         return tuple(repeat(x, n))
+
+     return parse
+
+
+ to_1tuple = _ntuple(1)
+ to_2tuple = _ntuple(2)
+ to_3tuple = _ntuple(3)
+ to_4tuple = _ntuple(4)
+ to_ntuple = _ntuple
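+
+
+ # Usage sketch (illustrative; not part of the original file): scalars are
+ # repeated, iterables pass through unchanged.
+ #
+ #     to_2tuple(7)       # -> (7, 7)
+ #     to_2tuple((3, 5))  # -> (3, 5)
+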
IPG/ipg_kit.py ADDED
@@ -0,0 +1,199 @@
+ # Copyright 2024 Huawei Technologies Co., Ltd
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ============================================================================
+
+ import torch
+ import einops
+ import torch.nn.functional as F
+
+
+ def get_mask(idx, array):
+     '''
+     idx: (b m n) sorted indices; only n = idx.size(-1) is used here
+     array: (b m), records # of elements to be masked per pixel
+     returns: (b m n) boolean mask, True at positions 0 .. array[b, m] - 1
+     '''
+     b, m = array.shape
+     n = idx.size(-1)
+     A = torch.arange(n, dtype=idx.dtype, device=idx.device).unsqueeze(0).unsqueeze(0).expand(b, m, n)  # 1 1 n -> b m n
+     mask = A < array.unsqueeze(-1)
+     return mask
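+
+
+ # Usage sketch (illustrative; not part of the original file): per-row prefix mask.
+ #
+ #     idx = torch.zeros(1, 2, 4, dtype=torch.long)   # only the last dim size matters
+ #     array = torch.tensor([[1, 3]])                 # keep 1 and 3 entries
+ #     get_mask(idx, array)
+ #     # tensor([[[ True, False, False, False],
+ #     #          [ True,  True,  True, False]]])
+
+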
+ def alloc(var, rest, budget, tp, maximum, times=0, fast=False):
+     '''
+     var: (b m) variance of each pixel POSITIVE VALUE
+     rest: (b m) already allocated budgets
+     budget: (b) remaining to be allocated
+     tp: allocation type, 'plain' or 'softmax'
+     maximum: maximum budget for each pixel
+     '''
+     b, m = var.shape
+     if tp == 'plain':
+         var_p = var * (rest < maximum)
+         var_sum = var_p.sum(dim=-1, keepdim=True)  # b 1
+         proportion = var_p / var_sum  # b m
+     elif tp == 'softmax':
+         var_p = var.clone()
+         var_p[rest >= maximum] = -float('inf')  # pixels already at maximum get zero share
+         proportion = torch.nn.functional.softmax(var_p, dim=-1)  # b m
+     allocation = torch.round(proportion * budget.unsqueeze(1))  # b m
+     new_rest = torch.clamp(rest + allocation, 0, maximum)  # b m
+     remain_budget = budget - (new_rest - rest).sum(dim=-1)  # (b,) budget left after this round
+     negative_remain = (remain_budget < 0)
+     while negative_remain.sum() > 0:
+         offset = torch.eye(m, device=rest.device)[
+             torch.randint(m, (negative_remain.sum().int().item(),), device=rest.device)]
+         new_rest[negative_remain] = torch.clamp(new_rest[negative_remain] - offset, 1, maximum)  # reduce by one
+
+         # update remain budget
+         remain_budget = budget - (new_rest - rest).sum(dim=-1)
+         negative_remain = (remain_budget < 0)
+
+     if (remain_budget > 0).sum() > 0:
+         if times < 3:
+             new_rest[remain_budget > 0] = alloc(var[remain_budget > 0], new_rest[remain_budget > 0],
+                                                 remain_budget[remain_budget > 0], tp, maximum, times + 1, fast=fast)
+         elif not fast:  # precise budget allocation
+             positive_remain = (remain_budget > 0)
+             while positive_remain.sum() > 0:
+                 offset = torch.eye(m, device=rest.device)[
+                     torch.randint(m, (positive_remain.sum().int().item(),), device=rest.device)]
+                 new_rest[positive_remain] = torch.clamp(new_rest[positive_remain] + offset, 1, maximum)  # add by one
+                 # update remain budget
+                 remain_budget = budget - (new_rest - rest).sum(dim=-1)
+                 positive_remain = (remain_budget > 0)
+     return new_rest
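+
+
+ # Usage sketch (illustrative; not part of the original file): distribute an
+ # extra budget of 4 over 3 pixels in proportion to their variance.
+ #
+ #     var = torch.tensor([[1., 1., 2.]])   # (b=1, m=3)
+ #     rest = torch.ones(1, 3)              # one unit already allocated each
+ #     budget = torch.tensor([4.])          # 4 more to hand out
+ #     alloc(var, rest, budget, tp='plain', maximum=4)
+ #     # -> tensor([[2., 2., 3.]]), i.e. shares 1/4, 1/4, 2/4 of the budget
+
+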
+ def flex(D_: torch.Tensor, X: torch.Tensor, idx: torch.Tensor, flex_type, topk_, current_iter, total_iters, X_diff,
+          fast=False, return_maskarray=False):
+     '''
+     D_: (b m n) Gram matrix, sorted on last dim, descending
+     X: (b numh numw he) c (sh sw) X_data
+     idx: (b m n) sorted index of D_
+     flex_type: degree-allocation strategy, none / gsort / interdiff_plain
+     topk_: average number of connections kept per pixel
+     X_diff: (b m) per-pixel detail measure, consumed by 'interdiff_plain'
+     current_iter, total_iters: kept for interface compatibility (unused here)
+     OUT: (b m n) Binary mask
+     '''
+     b, m, n = D_.shape
+     if flex_type is None or flex_type == 'none':
+         mask_array = topk_ * torch.ones((b, m), dtype=torch.int, device=D_.device)
+
+     elif flex_type == 'gsort':
+         D = D_.clone()
+         D -= (D == D.max(dim=-1, keepdim=True).values) * 100000  # neglect max position
+         val, g_idx = torch.sort(D.view(b, -1), dim=-1, descending=True)  # global sort
+         # g_idx: (b m*n)
+         g_idx += m * n * torch.arange(b, dtype=g_idx.dtype, device=g_idx.device).unsqueeze(-1)  # b 1
+         non_topk_idx = g_idx[:, topk_ * (m - 1):]  # select top k, neglect max
+
+         mask_ = torch.ones_like(D).bool()
+         mask_.view(-1)[non_topk_idx.reshape(-1)] = False  # set to negative value
+         mask_array = mask_.sum(dim=-1)
+         mask_array += 1  # include max, ensure each pixel has at least one match
+
+     elif flex_type == 'interdiff_plain':  # interpolate and diff
+         rest = torch.ones_like(X_diff)
+         budget = torch.ones(b, dtype=torch.int, device=idx.device) * (topk_ - 1) * idx.size(1)
+         mask_array = alloc(X_diff, rest, budget, tp='plain', maximum=idx.size(-1), fast=fast)
+     else:
+         raise NotImplementedError(f'Graph type {flex_type} not implemented...')
+
+     if return_maskarray:
+         return mask_array
+
+     mask = ~get_mask(idx, mask_array)  # negated
+
+     return mask
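+
+
+ # Usage sketch (illustrative; not part of the original file): with
+ # flex_type='none' every pixel keeps exactly topk_ neighbors, so the
+ # returned (negated) mask is True on the n - topk_ pruned columns per row.
+ #
+ #     D = torch.randn(2, 16, 9).sort(dim=-1, descending=True).values
+ #     idx = torch.zeros(2, 16, 9, dtype=torch.long)   # placeholder sorted indices
+ #     mask = flex(D, None, idx, 'none', topk_=4, current_iter=0,
+ #                 total_iters=1, X_diff=None)          # (2, 16, 9) bool
+
+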
+ def cossim(X_sample, Y_sample, graph=None):
+     if graph is not None:
+         return torch.einsum('a b m c, a b n c -> a b m n', F.normalize(X_sample, dim=-1),
+                             F.normalize(Y_sample, dim=-1)) + (-100.) * (~graph)
+     return torch.einsum('a b m c, a b n c -> a b m n', F.normalize(X_sample, dim=-1), F.normalize(Y_sample, dim=-1))
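+
+
+ # Usage sketch (illustrative; not part of the original file): pairwise cosine
+ # similarity between two token sets; an optional boolean graph pushes
+ # disallowed pairs to a large negative value.
+ #
+ #     X = torch.randn(2, 4, 16, 32)   # a b m c
+ #     Y = torch.randn(2, 4, 9, 32)    # a b n c
+ #     S = cossim(X, Y)                # (2, 4, 16, 9), values in [-1, 1]
+
+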
+ def local_sampling(x, group_size, unfold_dict, output=0, tp='bhwc'):
+     '''
+     output:
+         x (grouped) [B, nn, c]
+         x_unfold [B, NN, C]
+         0/1/2: grouped, sampled, both
+     '''
+     if isinstance(group_size, int):
+         group_size = (group_size, group_size)
+
+     if output != 1:
+         if tp == 'bhwc':
+             x_grouped = einops.rearrange(x, 'b (numh sh) (numw sw) c -> (b numh numw) (sh sw) c', sh=group_size[0],
+                                          sw=group_size[1])
+         elif tp == 'bchw':
+             x_grouped = einops.rearrange(x, 'b c (numh sh) (numw sw) -> (b numh numw) (sh sw) c', sh=group_size[0],
+                                          sw=group_size[1])
+
+         if output == 0:
+             return x_grouped
+
+     if tp == 'bhwc':
+         x = einops.rearrange(x, 'b h w c -> b c h w')
+
+     x_sampled = einops.rearrange(F.unfold(x, **unfold_dict), 'b (c k0 k1) l -> (b l) (k0 k1) c',
+                                  k0=unfold_dict['kernel_size'][0], k1=unfold_dict['kernel_size'][1])
+
+     if output == 1:
+         return x_sampled
+
+     assert x_grouped.size(0) == x_sampled.size(0)
+     return x_grouped, x_sampled
+ def global_sampling(x, group_size, sample_size, output=0, tp='bhwc'):
+     '''
+     output:
+         x (grouped) [B, nn, c]
+         x_unfold [B, NN, C]
+         0/1/2: grouped, sampled, both
+     '''
+     if isinstance(group_size, int):
+         group_size = (group_size, group_size)
+     if isinstance(sample_size, int):
+         sample_size = (sample_size, sample_size)
+
+     if output != 1:
+         if tp == 'bchw':
+             x_grouped = einops.rearrange(x, 'b c (sh numh) (sw numw) -> (b numh numw) (sh sw) c', sh=group_size[0],
+                                          sw=group_size[1])
+         elif tp == 'bhwc':
+             x_grouped = einops.rearrange(x, 'b (sh numh) (sw numw) c -> (b numh numw) (sh sw) c', sh=group_size[0],
+                                          sw=group_size[1])
+
+         if output == 0:
+             return x_grouped
+
+     if tp == 'bchw':
+         x_sampled = einops.rearrange(x, 'b c (sh extrah numh) (sw extraw numw) -> b extrah numh extraw numw c sh sw',
+                                      sh=sample_size[0], sw=sample_size[1], extrah=1, extraw=1)
+     elif tp == 'bhwc':
+         x_sampled = einops.rearrange(x, 'b (sh extrah numh) (sw extraw numw) c -> b extrah numh extraw numw c sh sw',
+                                      sh=sample_size[0], sw=sample_size[1], extrah=1, extraw=1)
+     b_y, _, numh, _, numw, c_y, sh_y, sw_y = x_sampled.shape
+     ratio_h, ratio_w = sample_size[0] // group_size[0], sample_size[1] // group_size[1]
+     x_sampled = x_sampled.expand(b_y, ratio_h, numh, ratio_w, numw, c_y, sh_y, sw_y).reshape(
+         -1, c_y, sh_y * sw_y).permute(0, 2, 1)
+
+     if output == 1:
+         return x_sampled
+
+     assert x_grouped.size(0) == x_sampled.size(0)
+     return x_grouped, x_sampled
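+
+
+ # Usage sketch (illustrative; not part of the original file): group an 8x8 map
+ # into four 4x4 windows and pair each with a shared 8x8 global sample (here
+ # the whole map, so ratio_h = ratio_w = 2 and every group sees the same tokens).
+ #
+ #     x = torch.randn(1, 8, 8, 32)              # bhwc
+ #     g, s = global_sampling(x, 4, 8, output=2)
+ #     # g: (4, 16, 32); s: (4, 64, 32), the same 64 global tokens per group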