kittendev committed on
Commit 650351e
Parent: 1748dd5

Update core/networks.py

Files changed (1)
  1. core/networks.py +360 -355
core/networks.py CHANGED
@@ -1,355 +1,360 @@
-# Copyright (C) 2021 * Ltd. All rights reserved.
-# author : Sanghyeon Jo <[email protected]>
-
-import math
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from torchvision import models
-import torch.utils.model_zoo as model_zoo
-
-from .arch_resnet import resnet
-from .arch_resnest import resnest
-from .abc_modules import ABC_Model
-
-from .deeplab_utils import ASPP, Decoder
-from .aff_utils import PathIndex
-from .puzzle_utils import tile_features, merge_features
-
-from tools.ai.torch_utils import resize_for_tensors
-
-#######################################################################
-# Normalization
-#######################################################################
-from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
-
-class FixedBatchNorm(nn.BatchNorm2d):
-    def forward(self, x):
-        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)
-
-def group_norm(features):
-    return nn.GroupNorm(4, features)
-#######################################################################
-
-class Backbone(nn.Module, ABC_Model):
-    def __init__(self, model_name, state_path, num_classes=20, mode='fix', segmentation=False):
-        super().__init__()
-
-        self.mode = mode
-
-        if self.mode == 'fix':
-            self.norm_fn = FixedBatchNorm
-        else:
-            self.norm_fn = nn.BatchNorm2d
-
-        if 'resnet' in model_name:
-            self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1), batch_norm_fn=self.norm_fn)
-
-            state_dict = torch.load(state_path)
-            self.model.load_state_dict(state_dict, strict=False)
-        else:
-            if segmentation:
-                dilation, dilated = 4, True
-            else:
-                dilation, dilated = 2, False
-
-            self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation, norm_layer=self.norm_fn)
-
-            del self.model.avgpool
-            del self.model.fc
-
-        self.stage1 = nn.Sequential(self.model.conv1,
-                                    self.model.bn1,
-                                    self.model.relu,
-                                    self.model.maxpool)
-        self.stage2 = nn.Sequential(self.model.layer1)
-        self.stage3 = nn.Sequential(self.model.layer2)
-        self.stage4 = nn.Sequential(self.model.layer3)
-        self.stage5 = nn.Sequential(self.model.layer4)
-
-class Classifier(Backbone):
-    def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
-        super().__init__(model_name, state_path, num_classes, mode)
-
-        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
-        self.num_classes = num_classes
-
-        self.initialize([self.classifier])
-
-    def forward(self, x, with_cam=False):
-        x = self.stage1(x)
-        x = self.stage2(x)
-        x = self.stage3(x)
-        x = self.stage4(x)
-        x = self.stage5(x)
-
-        if with_cam:
-            features = self.classifier(x)
-            logits = self.global_average_pooling_2d(features)
-            return logits, features
-        else:
-            x = self.global_average_pooling_2d(x, keepdims=True)
-            logits = self.classifier(x).view(-1, self.num_classes)
-            return logits
-
-class Classifier_For_Positive_Pooling(Backbone):
-    def __init__(self, model_name, num_classes=20, mode='fix'):
-        super().__init__(model_name, num_classes, mode)
-
-        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
-        self.num_classes = num_classes
-
-        self.initialize([self.classifier])
-
-    def forward(self, x, with_cam=False):
-        x = self.stage1(x)
-        x = self.stage2(x)
-        x = self.stage3(x)
-        x = self.stage4(x)
-        x = self.stage5(x)
-
-        if with_cam:
-            features = self.classifier(x)
-            logits = self.global_average_pooling_2d(features)
-            return logits, features
-        else:
-            x = self.global_average_pooling_2d(x, keepdims=True)
-            logits = self.classifier(x).view(-1, self.num_classes)
-            return logits
-
-class Classifier_For_Puzzle(Classifier):
-    def __init__(self, model_name, num_classes=20, mode='fix'):
-        super().__init__(model_name, num_classes, mode)
-
-    def forward(self, x, num_pieces=1, level=-1):
-        batch_size = x.size()[0]
-
-        output_dic = {}
-        layers = [self.stage1, self.stage2, self.stage3, self.stage4, self.stage5, self.classifier]
-
-        for l, layer in enumerate(layers):
-            l += 1
-            if level == l:
-                x = tile_features(x, num_pieces)
-
-            x = layer(x)
-            output_dic['stage%d'%l] = x
-
-        output_dic['logits'] = self.global_average_pooling_2d(output_dic['stage6'])
-
-        for l in range(len(layers)):
-            l += 1
-            if l >= level:
-                output_dic['stage%d'%l] = merge_features(output_dic['stage%d'%l], num_pieces, batch_size)
-
-        if level is not None:
-            output_dic['merged_logits'] = self.global_average_pooling_2d(output_dic['stage6'])
-
-        return output_dic
-
-class AffinityNet(Backbone):
-    def __init__(self, model_name, path_index=None):
-        super().__init__(model_name, None, 'fix')
-
-        if '50' in model_name:
-            fc_edge1_features = 64
-        else:
-            fc_edge1_features = 128
-
-        self.fc_edge1 = nn.Sequential(
-            nn.Conv2d(fc_edge1_features, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge2 = nn.Sequential(
-            nn.Conv2d(256, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge3 = nn.Sequential(
-            nn.Conv2d(512, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge4 = nn.Sequential(
-            nn.Conv2d(1024, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge5 = nn.Sequential(
-            nn.Conv2d(2048, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True)
-
-        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5])
-        self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6])
-
-        if path_index is not None:
-            self.path_index = path_index
-            self.n_path_lengths = len(self.path_index.path_indices)
-            for i, pi in enumerate(self.path_index.path_indices):
-                self.register_buffer("path_indices_" + str(i), torch.from_numpy(pi))
-
-    def train(self, mode=True):
-        super().train(mode)
-        self.backbone.eval()
-
-    def forward(self, x, with_affinity=False):
-        x1 = self.stage1(x).detach()
-        x2 = self.stage2(x1).detach()
-        x3 = self.stage3(x2).detach()
-        x4 = self.stage4(x3).detach()
-        x5 = self.stage5(x4).detach()
-
-        edge1 = self.fc_edge1(x1)
-        edge2 = self.fc_edge2(x2)
-        edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)]
-        edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)]
-        edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)]
-
-        edge = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1))
-
-        if with_affinity:
-            return edge, self.to_affinity(torch.sigmoid(edge))
-        else:
-            return edge
-
-    def get_edge(self, x, image_size=512, stride=4):
-        feat_size = (x.size(2)-1)//stride+1, (x.size(3)-1)//stride+1
-
-        x = F.pad(x, [0, image_size-x.size(3), 0, image_size-x.size(2)])
-        edge_out = self.forward(x)
-        edge_out = edge_out[..., :feat_size[0], :feat_size[1]]
-        edge_out = torch.sigmoid(edge_out[0]/2 + edge_out[1].flip(-1)/2)
-
-        return edge_out
-
-    """
-    aff = self.to_affinity(torch.sigmoid(edge_out))
-    pos_aff_loss = (-1) * torch.log(aff + 1e-5)
-    neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff)
-    """
-    def to_affinity(self, edge):
-        aff_list = []
-        edge = edge.view(edge.size(0), -1)
-
-        for i in range(self.n_path_lengths):
-            ind = self._buffers["path_indices_" + str(i)]
-            ind_flat = ind.view(-1)
-            dist = torch.index_select(edge, dim=-1, index=ind_flat)
-            dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2))
-            aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2)
-            aff_list.append(aff)
-        aff_cat = torch.cat(aff_list, dim=1)
-        return aff_cat
-
-class DeepLabv3_Plus(Backbone):
-    def __init__(self, model_name, num_classes=21, mode='fix', use_group_norm=False):
-        super().__init__(model_name, num_classes, mode, segmentation=False)
-
-        if use_group_norm:
-            norm_fn_for_extra_modules = group_norm
-        else:
-            norm_fn_for_extra_modules = self.norm_fn
-
-        self.aspp = ASPP(output_stride=16, norm_fn=norm_fn_for_extra_modules)
-        self.decoder = Decoder(num_classes, 256, norm_fn_for_extra_modules)
-
-    def forward(self, x, with_cam=False):
-        inputs = x
-
-        x = self.stage1(x)
-        x = self.stage2(x)
-        x_low_level = x
-
-        x = self.stage3(x)
-        x = self.stage4(x)
-        x = self.stage5(x)
-
-        x = self.aspp(x)
-        x = self.decoder(x, x_low_level)
-        x = resize_for_tensors(x, inputs.size()[2:], align_corners=True)
-
-        return x
-
-class Seg_Model(Backbone):
-    def __init__(self, model_name, num_classes=21):
-        super().__init__(model_name, num_classes, mode='fix', segmentation=False)
-
-        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
-
-    def forward(self, inputs):
-        x = self.stage1(inputs)
-        x = self.stage2(x)
-        x = self.stage3(x)
-        x = self.stage4(x)
-        x = self.stage5(x)
-
-        logits = self.classifier(x)
-        # logits = resize_for_tensors(logits, inputs.size()[2:], align_corners=False)
-
-        return logits
-
-class CSeg_Model(Backbone):
-    def __init__(self, model_name, num_classes=21):
-        super().__init__(model_name, num_classes, 'fix')
-
-        if '50' in model_name:
-            fc_edge1_features = 64
-        else:
-            fc_edge1_features = 128
-
-        self.fc_edge1 = nn.Sequential(
-            nn.Conv2d(fc_edge1_features, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge2 = nn.Sequential(
-            nn.Conv2d(256, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge3 = nn.Sequential(
-            nn.Conv2d(512, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge4 = nn.Sequential(
-            nn.Conv2d(1024, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge5 = nn.Sequential(
-            nn.Conv2d(2048, 32, 1, bias=False),
-            nn.GroupNorm(4, 32),
-            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
-            nn.ReLU(inplace=True),
-        )
-        self.fc_edge6 = nn.Conv2d(160, num_classes, 1, bias=True)
-
-    def forward(self, x):
-        x1 = self.stage1(x)
-        x2 = self.stage2(x1)
-        x3 = self.stage3(x2)
-        x4 = self.stage4(x3)
-        x5 = self.stage5(x4)
-
-        edge1 = self.fc_edge1(x1)
-        edge2 = self.fc_edge2(x2)
-        edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)]
-        edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)]
-        edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)]
-
-        logits = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1))
-        # logits = resize_for_tensors(logits, x.size()[2:], align_corners=True)
-
-        return logits
+# Copyright (C) 2021 * Ltd. All rights reserved.
+# author : Sanghyeon Jo <[email protected]>
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torchvision import models
+import torch.utils.model_zoo as model_zoo
+
+from .arch_resnet import resnet
+from .arch_resnest import resnest
+from .abc_modules import ABC_Model
+
+from .deeplab_utils import ASPP, Decoder
+from .aff_utils import PathIndex
+from .puzzle_utils import tile_features, merge_features
+
+from tools.ai.torch_utils import resize_for_tensors
+
+#######################################################################
+# Normalization
+#######################################################################
+from .sync_batchnorm.batchnorm import SynchronizedBatchNorm2d
+
+class FixedBatchNorm(nn.BatchNorm2d):
+    def forward(self, x):
+        return F.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias, training=False, eps=self.eps)
+
+def group_norm(features):
+    return nn.GroupNorm(4, features)
+#######################################################################
+
+class Backbone(nn.Module, ABC_Model):
+    def __init__(self, model_name, num_classes=20, mode='fix', segmentation=False):
+        super().__init__()
+
+        self.mode = mode
+
+        if self.mode == 'fix':
+            self.norm_fn = FixedBatchNorm
+        else:
+            self.norm_fn = nn.BatchNorm2d
+
+        if 'resnet' in model_name:
+            self.model = resnet.ResNet(resnet.Bottleneck, resnet.layers_dic[model_name], strides=(2, 2, 2, 1),
+                                       batch_norm_fn=self.norm_fn)
+
+            state_dict = model_zoo.load_url(resnet.urls_dic[model_name])
+            state_dict.pop('fc.weight')
+            state_dict.pop('fc.bias')
+
+            self.model.load_state_dict(state_dict)
+        else:
+            if segmentation:
+                dilation, dilated = 4, True
+            else:
+                dilation, dilated = 2, False
+
+            self.model = eval("resnest." + model_name)(pretrained=True, dilated=dilated, dilation=dilation,
+                                                       norm_layer=self.norm_fn)
+
+            del self.model.avgpool
+            del self.model.fc
+
+        self.stage1 = nn.Sequential(self.model.conv1,
+                                    self.model.bn1,
+                                    self.model.relu,
+                                    self.model.maxpool)
+        self.stage2 = nn.Sequential(self.model.layer1)
+        self.stage3 = nn.Sequential(self.model.layer2)
+        self.stage4 = nn.Sequential(self.model.layer3)
+        self.stage5 = nn.Sequential(self.model.layer4)
+
+class Classifier(Backbone):
+    def __init__(self, model_name, state_path, num_classes=20, mode='fix'):
+        super().__init__(model_name, state_path, num_classes, mode)
+
+        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
+        self.num_classes = num_classes
+
+        self.initialize([self.classifier])
+
+    def forward(self, x, with_cam=False):
+        x = self.stage1(x)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.stage4(x)
+        x = self.stage5(x)
+
+        if with_cam:
+            features = self.classifier(x)
+            logits = self.global_average_pooling_2d(features)
+            return logits, features
+        else:
+            x = self.global_average_pooling_2d(x, keepdims=True)
+            logits = self.classifier(x).view(-1, self.num_classes)
+            return logits
+
+class Classifier_For_Positive_Pooling(Backbone):
+    def __init__(self, model_name, num_classes=20, mode='fix'):
+        super().__init__(model_name, num_classes, mode)
+
+        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
+        self.num_classes = num_classes
+
+        self.initialize([self.classifier])
+
+    def forward(self, x, with_cam=False):
+        x = self.stage1(x)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.stage4(x)
+        x = self.stage5(x)
+
+        if with_cam:
+            features = self.classifier(x)
+            logits = self.global_average_pooling_2d(features)
+            return logits, features
+        else:
+            x = self.global_average_pooling_2d(x, keepdims=True)
+            logits = self.classifier(x).view(-1, self.num_classes)
+            return logits
+
+class Classifier_For_Puzzle(Classifier):
+    def __init__(self, model_name, num_classes=20, mode='fix'):
+        super().__init__(model_name, num_classes, mode)
+
+    def forward(self, x, num_pieces=1, level=-1):
+        batch_size = x.size()[0]
+
+        output_dic = {}
+        layers = [self.stage1, self.stage2, self.stage3, self.stage4, self.stage5, self.classifier]
+
+        for l, layer in enumerate(layers):
+            l += 1
+            if level == l:
+                x = tile_features(x, num_pieces)
+
+            x = layer(x)
+            output_dic['stage%d'%l] = x
+
+        output_dic['logits'] = self.global_average_pooling_2d(output_dic['stage6'])
+
+        for l in range(len(layers)):
+            l += 1
+            if l >= level:
+                output_dic['stage%d'%l] = merge_features(output_dic['stage%d'%l], num_pieces, batch_size)
+
+        if level is not None:
+            output_dic['merged_logits'] = self.global_average_pooling_2d(output_dic['stage6'])
+
+        return output_dic
+
+class AffinityNet(Backbone):
+    def __init__(self, model_name, path_index=None):
+        super().__init__(model_name, None, 'fix')
+
+        if '50' in model_name:
+            fc_edge1_features = 64
+        else:
+            fc_edge1_features = 128
+
+        self.fc_edge1 = nn.Sequential(
+            nn.Conv2d(fc_edge1_features, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge2 = nn.Sequential(
+            nn.Conv2d(256, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge3 = nn.Sequential(
+            nn.Conv2d(512, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge4 = nn.Sequential(
+            nn.Conv2d(1024, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge5 = nn.Sequential(
+            nn.Conv2d(2048, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge6 = nn.Conv2d(160, 1, 1, bias=True)
+
+        self.backbone = nn.ModuleList([self.stage1, self.stage2, self.stage3, self.stage4, self.stage5])
+        self.edge_layers = nn.ModuleList([self.fc_edge1, self.fc_edge2, self.fc_edge3, self.fc_edge4, self.fc_edge5, self.fc_edge6])
+
+        if path_index is not None:
+            self.path_index = path_index
+            self.n_path_lengths = len(self.path_index.path_indices)
+            for i, pi in enumerate(self.path_index.path_indices):
+                self.register_buffer("path_indices_" + str(i), torch.from_numpy(pi))
+
+    def train(self, mode=True):
+        super().train(mode)
+        self.backbone.eval()
+
+    def forward(self, x, with_affinity=False):
+        x1 = self.stage1(x).detach()
+        x2 = self.stage2(x1).detach()
+        x3 = self.stage3(x2).detach()
+        x4 = self.stage4(x3).detach()
+        x5 = self.stage5(x4).detach()
+
+        edge1 = self.fc_edge1(x1)
+        edge2 = self.fc_edge2(x2)
+        edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)]
+        edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)]
+        edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)]
+
+        edge = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1))
+
+        if with_affinity:
+            return edge, self.to_affinity(torch.sigmoid(edge))
+        else:
+            return edge
+
+    def get_edge(self, x, image_size=512, stride=4):
+        feat_size = (x.size(2)-1)//stride+1, (x.size(3)-1)//stride+1
+
+        x = F.pad(x, [0, image_size-x.size(3), 0, image_size-x.size(2)])
+        edge_out = self.forward(x)
+        edge_out = edge_out[..., :feat_size[0], :feat_size[1]]
+        edge_out = torch.sigmoid(edge_out[0]/2 + edge_out[1].flip(-1)/2)
+
+        return edge_out
+
+    """
+    aff = self.to_affinity(torch.sigmoid(edge_out))
+    pos_aff_loss = (-1) * torch.log(aff + 1e-5)
+    neg_aff_loss = (-1) * torch.log(1. + 1e-5 - aff)
+    """
+    def to_affinity(self, edge):
+        aff_list = []
+        edge = edge.view(edge.size(0), -1)
+
+        for i in range(self.n_path_lengths):
+            ind = self._buffers["path_indices_" + str(i)]
+            ind_flat = ind.view(-1)
+            dist = torch.index_select(edge, dim=-1, index=ind_flat)
+            dist = dist.view(dist.size(0), ind.size(0), ind.size(1), ind.size(2))
+            aff = torch.squeeze(1 - F.max_pool2d(dist, (dist.size(2), 1)), dim=2)
+            aff_list.append(aff)
+        aff_cat = torch.cat(aff_list, dim=1)
+        return aff_cat
+
+class DeepLabv3_Plus(Backbone):
+    def __init__(self, model_name, num_classes=21, mode='fix', use_group_norm=False):
+        super().__init__(model_name, num_classes, mode, segmentation=False)
+
+        if use_group_norm:
+            norm_fn_for_extra_modules = group_norm
+        else:
+            norm_fn_for_extra_modules = self.norm_fn
+
+        self.aspp = ASPP(output_stride=16, norm_fn=norm_fn_for_extra_modules)
+        self.decoder = Decoder(num_classes, 256, norm_fn_for_extra_modules)
+
+    def forward(self, x, with_cam=False):
+        inputs = x
+
+        x = self.stage1(x)
+        x = self.stage2(x)
+        x_low_level = x
+
+        x = self.stage3(x)
+        x = self.stage4(x)
+        x = self.stage5(x)
+
+        x = self.aspp(x)
+        x = self.decoder(x, x_low_level)
+        x = resize_for_tensors(x, inputs.size()[2:], align_corners=True)
+
+        return x
+
+class Seg_Model(Backbone):
+    def __init__(self, model_name, num_classes=21):
+        super().__init__(model_name, num_classes, mode='fix', segmentation=False)
+
+        self.classifier = nn.Conv2d(2048, num_classes, 1, bias=False)
+
+    def forward(self, inputs):
+        x = self.stage1(inputs)
+        x = self.stage2(x)
+        x = self.stage3(x)
+        x = self.stage4(x)
+        x = self.stage5(x)
+
+        logits = self.classifier(x)
+        # logits = resize_for_tensors(logits, inputs.size()[2:], align_corners=False)
+
+        return logits
+
+class CSeg_Model(Backbone):
+    def __init__(self, model_name, num_classes=21):
+        super().__init__(model_name, num_classes, 'fix')
+
+        if '50' in model_name:
+            fc_edge1_features = 64
+        else:
+            fc_edge1_features = 128
+
+        self.fc_edge1 = nn.Sequential(
+            nn.Conv2d(fc_edge1_features, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge2 = nn.Sequential(
+            nn.Conv2d(256, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge3 = nn.Sequential(
+            nn.Conv2d(512, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge4 = nn.Sequential(
+            nn.Conv2d(1024, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge5 = nn.Sequential(
+            nn.Conv2d(2048, 32, 1, bias=False),
+            nn.GroupNorm(4, 32),
+            nn.Upsample(scale_factor=4, mode='bilinear', align_corners=False),
+            nn.ReLU(inplace=True),
+        )
+        self.fc_edge6 = nn.Conv2d(160, num_classes, 1, bias=True)
+
+    def forward(self, x):
+        x1 = self.stage1(x)
+        x2 = self.stage2(x1)
+        x3 = self.stage3(x2)
+        x4 = self.stage4(x3)
+        x5 = self.stage5(x4)
+
+        edge1 = self.fc_edge1(x1)
+        edge2 = self.fc_edge2(x2)
+        edge3 = self.fc_edge3(x3)[..., :edge2.size(2), :edge2.size(3)]
+        edge4 = self.fc_edge4(x4)[..., :edge2.size(2), :edge2.size(3)]
+        edge5 = self.fc_edge5(x5)[..., :edge2.size(2), :edge2.size(3)]
+
+        logits = self.fc_edge6(torch.cat([edge1, edge2, edge3, edge4, edge5], dim=1))
+        # logits = resize_for_tensors(logits, x.size()[2:], align_corners=True)
+
+        return logits
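
The functional change in this commit is the ResNet branch of Backbone.__init__: instead of loading a local checkpoint with torch.load(state_path) under strict=False, the backbone now fetches ImageNet-pretrained weights with model_zoo.load_url(resnet.urls_dic[model_name]), pops the fc.weight and fc.bias entries, and loads the rest with the default strict behaviour. The snippet below is a minimal standalone sketch of that loading pattern; it uses torchvision's stock resnet50 and its published weight URL as stand-ins for the repo's resnet.ResNet and resnet.urls_dic (both stand-ins are assumptions, not part of this commit), and keeps strict=False only because the stand-in model still defines an fc head.

# Minimal sketch of the new weight-loading pattern. Stand-ins (assumed, not
# from the commit): torchvision's resnet50 and its published weight URL,
# in place of the repo's resnet.ResNet / resnet.urls_dic.
import torch.utils.model_zoo as model_zoo
from torchvision.models import resnet50

RESNET50_URL = "https://download.pytorch.org/models/resnet50-19c8e357.pth"

# Download (and cache) the ImageNet-pretrained weights.
state_dict = model_zoo.load_url(RESNET50_URL)

# Drop the 1000-way classification head before loading, as the commit does.
state_dict.pop('fc.weight')
state_dict.pop('fc.bias')

model = resnet50()
# strict=False only because this stand-in model still defines an fc layer;
# the committed code loads into its headless backbone with the default strict load.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print(missing)     # ['fc.weight', 'fc.bias'] for the stand-in model
print(unexpected)  # []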