josedolot committed
Commit 7b056a9 · 1 Parent(s): 549e090

Upload hybridnets/model.py

Files changed (1):
  1. hybridnets/model.py +800 -0
hybridnets/model.py ADDED
@@ -0,0 +1,800 @@
import torch.nn as nn
import torch
from torchvision.ops.boxes import nms as nms_torch
import torch.nn.functional as F
import math
from functools import partial


def nms(dets, thresh):
    # dets: [N, 5] boxes as (x1, y1, x2, y2, score)
    return nms_torch(dets[:, :4], dets[:, 4], thresh)


class SeparableConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False):
        super(SeparableConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels

        # Q: should a separable conv share one bias between depthwise_conv and
        #    pointwise_conv, or should only pointwise_conv apply a bias?
        # A: Confirmed, only pointwise_conv applies a bias; depthwise_conv has no bias.

        self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels,
                                                      kernel_size=3, stride=1, groups=in_channels, bias=False)
        self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1)

        self.norm = norm
        if self.norm:
            # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow
            self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3)

        self.activation = activation
        if self.activation:
            self.swish = MemoryEfficientSwish() if not onnx_export else Swish()

    def forward(self, x):
        x = self.depthwise_conv(x)
        x = self.pointwise_conv(x)

        if self.norm:
            x = self.bn(x)

        if self.activation:
            x = self.swish(x)

        return x


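# A quick sense of why this separable block is used everywhere below: for a
# 64 -> 64 channel 3x3 layer, a dense conv holds 64*64*9 weights, while
# depthwise (64*9) plus pointwise (64*64 + bias) is roughly 8x smaller. The
# helper name and channel count here are illustrative, not part of HybridNets.
def _separable_conv_param_demo(channels=64):
    dense = nn.Conv2d(channels, channels, 3, bias=False)
    separable = SeparableConvBlock(channels, channels, norm=False, activation=False)
    dense_n = sum(p.numel() for p in dense.parameters())
    sep_n = sum(p.numel() for p in separable.parameters())
    return dense_n, sep_n  # (36864, 4736) for channels=64

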
class BiFPN(nn.Module):
    def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True,
                 use_p8=False):
        """
        Args:
            num_channels: number of channels used for every BiFPN feature map
            conv_channels: channels of the backbone feature maps [C3, C4, C5]
            first_time: whether the input comes directly from the efficientnet;
                if True, downchannel it first, and downsample P5 to generate P6 then P7
            epsilon: epsilon of the fast weighted attention sum of BiFPN, not the BN's epsilon
            onnx_export: if True, use Swish instead of MemoryEfficientSwish
        """
        super(BiFPN, self).__init__()
        self.epsilon = epsilon
        self.use_p8 = use_p8

        # Conv layers
        self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv3_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)
        if use_p8:
            self.conv7_up = SeparableConvBlock(num_channels, onnx_export=onnx_export)
            self.conv8_down = SeparableConvBlock(num_channels, onnx_export=onnx_export)

        # Feature scaling layers
        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self.p4_downsample = MaxPool2dStaticSamePadding(3, 2)
        self.p5_downsample = MaxPool2dStaticSamePadding(3, 2)
        self.p6_downsample = MaxPool2dStaticSamePadding(3, 2)
        self.p7_downsample = MaxPool2dStaticSamePadding(3, 2)
        if use_p8:
            self.p7_upsample = nn.Upsample(scale_factor=2, mode='nearest')
            self.p8_downsample = MaxPool2dStaticSamePadding(3, 2)

        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()

        self.first_time = first_time
        if self.first_time:
            self.p5_down_channel = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
            )
            self.p4_down_channel = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
            )
            self.p3_down_channel = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[0], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
            )

            self.p5_to_p6 = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
                MaxPool2dStaticSamePadding(3, 2)
            )
            self.p6_to_p7 = nn.Sequential(
                MaxPool2dStaticSamePadding(3, 2)
            )
            if use_p8:
                self.p7_to_p8 = nn.Sequential(
                    MaxPool2dStaticSamePadding(3, 2)
                )

            self.p4_down_channel_2 = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[1], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
            )
            self.p5_down_channel_2 = nn.Sequential(
                Conv2dStaticSamePadding(conv_channels[2], num_channels, 1),
                nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3),
            )

        # Weight
        self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p5_w1_relu = nn.ReLU()
        self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p4_w1_relu = nn.ReLU()
        self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p3_w1_relu = nn.ReLU()

        self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p4_w2_relu = nn.ReLU()
        self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p5_w2_relu = nn.ReLU()
        self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True)
        self.p6_w2_relu = nn.ReLU()
        self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True)
        self.p7_w2_relu = nn.ReLU()

        self.attention = attention

    def forward(self, inputs):
        """
        Illustration of a minimal BiFPN unit:

            P7_0 -------------------------> P7_2 -------->
               |-------------|                ↑
                             ↓                |
            P6_0 ---------> P6_1 ---------> P6_2 -------->
               |-------------|--------------↑ ↑
                             ↓                |
            P5_0 ---------> P5_1 ---------> P5_2 -------->
               |-------------|--------------↑ ↑
                             ↓                |
            P4_0 ---------> P4_1 ---------> P4_2 -------->
               |-------------|--------------↑ ↑
                             |--------------↓ |
            P3_0 -------------------------> P3_2 -------->
        """
        # downsample channels using same-padding conv2d to target phase's if not the same
        # judge: same phase as target,
        #     if same, pass;
        #     elif earlier phase, downsample to target phase's by pooling
        #     elif later phase, upsample to target phase's by nearest interpolation

        if self.attention:
            outs = self._forward_fast_attention(inputs)
        else:
            outs = self._forward(inputs)

        return outs

    def _forward_fast_attention(self, inputs):
        if self.first_time:
            p3, p4, p5 = inputs

            p6_in = self.p5_to_p6(p5)
            p7_in = self.p6_to_p7(p6_in)

            p3_in = self.p3_down_channel(p3)
            p4_in = self.p4_down_channel(p4)
            p5_in = self.p5_down_channel(p5)

        else:
            # P3_0, P4_0, P5_0, P6_0 and P7_0
            p3_in, p4_in, p5_in, p6_in, p7_in = inputs

        # P7_0 to P7_2

        # Weights for P6_0 and P7_0 to P6_1
        p6_w1 = self.p6_w1_relu(self.p6_w1)
        weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon)
        # Connections for P6_0 and P7_0 to P6_1 respectively
        p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in)))

        # Weights for P5_0 and P6_1 to P5_1
        p5_w1 = self.p5_w1_relu(self.p5_w1)
        weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon)
        # Connections for P5_0 and P6_1 to P5_1 respectively
        p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up)))

        # Weights for P4_0 and P5_1 to P4_1
        p4_w1 = self.p4_w1_relu(self.p4_w1)
        weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon)
        # Connections for P4_0 and P5_1 to P4_1 respectively
        p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up)))

        # Weights for P3_0 and P4_1 to P3_2
        p3_w1 = self.p3_w1_relu(self.p3_w1)
        weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon)
        # Connections for P3_0 and P4_1 to P3_2 respectively
        p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up)))

        if self.first_time:
            p4_in = self.p4_down_channel_2(p4)
            p5_in = self.p5_down_channel_2(p5)

        # Weights for P4_0, P4_1 and P3_2 to P4_2
        p4_w2 = self.p4_w2_relu(self.p4_w2)
        weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon)
        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
        p4_out = self.conv4_down(
            self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out)))

        # Weights for P5_0, P5_1 and P4_2 to P5_2
        p5_w2 = self.p5_w2_relu(self.p5_w2)
        weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon)
        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
        p5_out = self.conv5_down(
            self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out)))

        # Weights for P6_0, P6_1 and P5_2 to P6_2
        p6_w2 = self.p6_w2_relu(self.p6_w2)
        weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon)
        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
        p6_out = self.conv6_down(
            self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * self.p6_downsample(p5_out)))

        # Weights for P7_0 and P6_2 to P7_2
        p7_w2 = self.p7_w2_relu(self.p7_w2)
        weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon)
        # Connections for P7_0 and P6_2 to P7_2
        p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out)))

        return p3_out, p4_out, p5_out, p6_out, p7_out

    def _forward(self, inputs):
        if self.first_time:
            p3, p4, p5 = inputs

            p6_in = self.p5_to_p6(p5)
            p7_in = self.p6_to_p7(p6_in)
            if self.use_p8:
                p8_in = self.p7_to_p8(p7_in)

            p3_in = self.p3_down_channel(p3)
            p4_in = self.p4_down_channel(p4)
            p5_in = self.p5_down_channel(p5)

        else:
            if self.use_p8:
                # P3_0, P4_0, P5_0, P6_0, P7_0 and P8_0
                p3_in, p4_in, p5_in, p6_in, p7_in, p8_in = inputs
            else:
                # P3_0, P4_0, P5_0, P6_0 and P7_0
                p3_in, p4_in, p5_in, p6_in, p7_in = inputs

        if self.use_p8:
            # P8_0 to P8_2

            # Connections for P7_0 and P8_0 to P7_1 respectively
            p7_up = self.conv7_up(self.swish(p7_in + self.p7_upsample(p8_in)))

            # Connections for P6_0 and P7_1 to P6_1 respectively
            p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_up)))
        else:
            # P7_0 to P7_2

            # Connections for P6_0 and P7_0 to P6_1 respectively
            p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in)))

        # Connections for P5_0 and P6_1 to P5_1 respectively
        p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up)))

        # Connections for P4_0 and P5_1 to P4_1 respectively
        p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up)))

        # Connections for P3_0 and P4_1 to P3_2 respectively
        p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up)))

        if self.first_time:
            p4_in = self.p4_down_channel_2(p4)
            p5_in = self.p5_down_channel_2(p5)

        # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively
        p4_out = self.conv4_down(
            self.swish(p4_in + p4_up + self.p4_downsample(p3_out)))

        # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively
        p5_out = self.conv5_down(
            self.swish(p5_in + p5_up + self.p5_downsample(p4_out)))

        # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively
        p6_out = self.conv6_down(
            self.swish(p6_in + p6_up + self.p6_downsample(p5_out)))

        if self.use_p8:
            # Connections for P7_0, P7_1 and P6_2 to P7_2 respectively
            p7_out = self.conv7_down(
                self.swish(p7_in + p7_up + self.p7_downsample(p6_out)))

            # Connections for P8_0 and P7_2 to P8_2
            p8_out = self.conv8_down(self.swish(p8_in + self.p8_downsample(p7_out)))

            return p3_out, p4_out, p5_out, p6_out, p7_out, p8_out
        else:
            # Connections for P7_0 and P6_2 to P7_2
            p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out)))

            return p3_out, p4_out, p5_out, p6_out, p7_out


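# The "fast attention" used above is the fast normalized fusion from the
# EfficientDet paper: per-input scalar weights are passed through ReLU and
# normalized by their sum plus epsilon, instead of a softmax. A standalone
# sketch of that single step, with illustrative names:
def _fast_normalized_fusion(features, weights, epsilon=1e-4):
    w = F.relu(weights)                      # keep fusion weights non-negative
    w = w / (torch.sum(w, dim=0) + epsilon)  # normalize so they sum to ~1
    return sum(w[i] * f for i, f in enumerate(features))

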
class Regressor(nn.Module):
    def __init__(self, in_channels, num_anchors, num_layers, pyramid_levels=5, onnx_export=False):
        super(Regressor, self).__init__()
        self.num_layers = num_layers

        self.conv_list = nn.ModuleList(
            [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
        self.bn_list = nn.ModuleList(
            [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
             range(pyramid_levels)])
        self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False)
        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()

    def forward(self, inputs):
        feats = []
        for feat, bn_list in zip(inputs, self.bn_list):
            for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
                feat = conv(feat)
                feat = bn(feat)
                feat = self.swish(feat)
            feat = self.header(feat)

            feat = feat.permute(0, 2, 3, 1)
            feat = feat.contiguous().view(feat.shape[0], -1, 4)

            feats.append(feat)

        feats = torch.cat(feats, dim=1)

        return feats


class Conv3x3BNSwish(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()

        self.swish = Swish()
        self.upsample = upsample

        self.block = nn.Sequential(
            # Conv2dStaticSamePadding computes its own padding; the explicit
            # padding kwarg is accepted but ignored by that wrapper.
            Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels, momentum=0.01, eps=1e-3),
        )

        self.conv_sp = SeparableConvBlock(out_channels, onnx_export=False)

        # self.block = nn.Sequential(
        #     nn.Conv2d(in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False),
        #     nn.GroupNorm(32, out_channels),
        #     nn.ReLU(inplace=True),
        # )

    def forward(self, x):
        x = self.conv_sp(self.swish(self.block(x)))
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class SegmentationBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()

        blocks = [Conv3x3BNSwish(in_channels, out_channels, upsample=bool(n_upsamples))]

        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3BNSwish(out_channels, out_channels, upsample=True))

        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


class MergeBlock(nn.Module):
    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
            )
        self.policy = policy

    def forward(self, x):
        if self.policy == 'add':
            return sum(x)
        elif self.policy == 'cat':
            return torch.cat(x, dim=1)
        else:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
            )


class BiFPNDecoder(nn.Module):
    def __init__(
            self,
            encoder_depth=5,
            pyramid_channels=64,
            segmentation_channels=64,
            dropout=0.2,
            merge_policy="add",
    ):
        super().__init__()

        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(pyramid_channels, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [5, 4, 3, 2, 1]
        ])

        self.seg_p2 = SegmentationBlock(32, 64, n_upsamples=0)

        self.merge = MergeBlock(merge_policy)

        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, inputs):
        p2, p3, p4, p5, p6, p7 = inputs

        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, [p7, p6, p5, p4, p3])]

        p2 = self.seg_p2(p2)

        p3, p4, p5, p6, p7 = feature_pyramid

        x = self.merge((p2, p3, p4, p5, p6, p7))

        x = self.dropout(x)

        return x


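# Each SegmentationBlock above upsamples its pyramid level back to P2's
# stride-4 grid (P7 by 2^5, P6 by 2^4, ..., P3 by 2^1), which is what lets
# the 'add' merge see six tensors of identical shape. A shape sketch with
# illustrative sizes (roughly what a 512x512 input would produce):
def _bifpn_decoder_shape_demo():
    decoder = BiFPNDecoder()
    p2 = torch.randn(1, 32, 128, 128)  # stride-4 backbone feature
    p3_to_p7 = [torch.randn(1, 64, 128 // 2 ** i, 128 // 2 ** i) for i in range(1, 6)]
    out = decoder((p2, *p3_to_p7))
    return tuple(out.shape)  # (1, 64, 128, 128)

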
class Classifier(nn.Module):
    def __init__(self, in_channels, num_anchors, num_classes, num_layers, pyramid_levels=5, onnx_export=False):
        super(Classifier, self).__init__()
        self.num_anchors = num_anchors
        self.num_classes = num_classes
        self.num_layers = num_layers
        self.conv_list = nn.ModuleList(
            [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)])
        self.bn_list = nn.ModuleList(
            [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in
             range(pyramid_levels)])
        self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False)
        self.swish = MemoryEfficientSwish() if not onnx_export else Swish()

    def forward(self, inputs):
        feats = []
        for feat, bn_list in zip(inputs, self.bn_list):
            for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list):
                feat = conv(feat)
                feat = bn(feat)
                feat = self.swish(feat)
            feat = self.header(feat)

            feat = feat.permute(0, 2, 3, 1)
            feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors,
                                          self.num_classes)
            feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes)

            feats.append(feat)

        feats = torch.cat(feats, dim=1)
        feats = feats.sigmoid()

        return feats


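# Both heads emit one row per anchor: each level's [B, A*K, H, W] map is
# permuted and reshaped to [B, H*W*A, K] (K = 4 box offsets in the Regressor,
# K = num_classes here), then the pyramid levels are concatenated so rows
# align one-to-one with the anchor list. A quick check, sizes illustrative:
def _classifier_shape_demo():
    head = Classifier(in_channels=64, num_anchors=9, num_classes=3, num_layers=3)
    feats = [torch.randn(2, 64, s, s) for s in (64, 32, 16, 8, 4)]
    out = head(feats)
    return tuple(out.shape)  # (2, (64*64 + 32*32 + 16*16 + 8*8 + 4*4) * 9, 3)

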
class SwishImplementation(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i * torch.sigmoid(i)
        ctx.save_for_backward(i)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        i = ctx.saved_tensors[0]
        sigmoid_i = torch.sigmoid(i)
        return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i)))


class MemoryEfficientSwish(nn.Module):
    def forward(self, x):
        return SwishImplementation.apply(x)


class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)


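# MemoryEfficientSwish saves only the input and recomputes sigmoid(x) during
# backward (d/dx x*sigmoid(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x)))),
# trading a little compute for activation memory. The hand-written gradient
# can be verified against autograd; a sketch:
def _check_swish_grad():
    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    return torch.autograd.gradcheck(SwishImplementation.apply, (x,))  # True if correct

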
def drop_connect(inputs, p, training):
    """Drop connect (stochastic depth): zero out whole samples of the batch
    with probability p and rescale survivors by 1 / keep_prob; identity at eval."""
    if not training:
        return inputs
    batch_size = inputs.shape[0]
    keep_prob = 1 - p
    random_tensor = keep_prob
    random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
    binary_tensor = torch.floor(random_tensor)  # 1 with probability keep_prob, else 0
    output = inputs / keep_prob * binary_tensor
    return output


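# The division by keep_prob is what preserves the expectation: a sample
# survives with probability 1 - p and is scaled up by 1 / (1 - p), so the
# mean output matches the input and eval mode needs no correction. Sketch:
def _drop_connect_expectation_demo(p=0.3, n=100000):
    x = torch.ones(n, 1, 1, 1)
    return drop_connect(x, p, training=True).mean().item()  # ~1.0

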
def get_same_padding_conv2d(image_size=None):
    """Chooses static padding if you have specified an image size, and dynamic padding otherwise.
    Static padding is necessary for ONNX exporting of models."""
    if image_size is None:
        return Conv2dDynamicSamePadding
    else:
        return partial(Conv2dStaticSamePadding, image_size=image_size)


class Conv2dDynamicSamePadding(nn.Conv2d):
    """2D convolutions like TensorFlow, for a dynamic image size."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias)
        self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2

    def forward(self, x):
        ih, iw = x.size()[-2:]
        kh, kw = self.weight.size()[-2:]
        sh, sw = self.stride
        oh, ow = math.ceil(ih / sh), math.ceil(iw / sw)
        pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0)
        pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0)
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)


class MBConvBlock(nn.Module):
    """
    Mobile Inverted Residual Bottleneck Block

    Args:
        block_args (namedtuple): BlockArgs, as defined in the EfficientNet utils
        global_params (namedtuple): GlobalParams, as defined in the EfficientNet utils

    Attributes:
        has_se (bool): Whether the block contains a Squeeze and Excitation layer.
    """

    def __init__(self, block_args, global_params):
        super().__init__()
        self._block_args = block_args
        self._bn_mom = 1 - global_params.batch_norm_momentum
        self._bn_eps = global_params.batch_norm_epsilon
        self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip  # skip connection and drop connect

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Expansion phase
        inp = self._block_args.input_filters  # number of input channels
        oup = self._block_args.input_filters * self._block_args.expand_ratio  # number of output channels
        if self._block_args.expand_ratio != 1:
            self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False)
            self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Depthwise convolution phase
        k = self._block_args.kernel_size
        s = self._block_args.stride
        self._depthwise_conv = Conv2d(
            in_channels=oup, out_channels=oup, groups=oup,  # groups makes it depthwise
            kernel_size=k, stride=s, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps)

        # Squeeze and Excitation layer, if desired
        if self.has_se:
            num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio))
            self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1)
            self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1)

        # Output phase
        final_oup = self._block_args.output_filters
        self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False)
        self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps)
        self._swish = MemoryEfficientSwish()

    def forward(self, inputs, drop_connect_rate=None):
        """
        :param inputs: input tensor
        :param drop_connect_rate: drop connect rate (float, between 0 and 1)
        :return: output of block
        """

        # Expansion and Depthwise Convolution
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._expand_conv(inputs)
            x = self._bn0(x)
            x = self._swish(x)

        x = self._depthwise_conv(x)
        x = self._bn1(x)
        x = self._swish(x)

        # Squeeze and Excitation
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_reduce(x_squeezed)
            x_squeezed = self._swish(x_squeezed)
            x_squeezed = self._se_expand(x_squeezed)
            x = torch.sigmoid(x_squeezed) * x

        x = self._project_conv(x)
        x = self._bn2(x)

        # Skip connection and drop connect
        input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters
        if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters:
            if drop_connect_rate:
                x = drop_connect(x, p=drop_connect_rate, training=self.training)
            x = x + inputs  # skip connection
        return x

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)."""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()


class Conv2dStaticSamePadding(nn.Module):
    """
    Conv2d with TensorFlow/Keras-style 'SAME' padding, computed in forward()
    from the static kernel size and stride.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs):
        super().__init__()
        # Extra kwargs (e.g. image_size from get_same_padding_conv2d) are
        # accepted for interface compatibility and ignored. Note that the
        # dilation argument is recorded but not forwarded to the conv; every
        # call site in this file uses the default dilation of 1.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride,
                              bias=bias, groups=groups)
        self.stride = self.conv.stride
        self.kernel_size = self.conv.kernel_size
        self.dilation = self.conv.dilation

        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        elif len(self.stride) == 1:
            self.stride = [self.stride[0]] * 2

        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        elif len(self.kernel_size) == 1:
            self.kernel_size = [self.kernel_size[0]] * 2

    def forward(self, x):
        h, w = x.shape[-2:]

        # clamp at 0 so F.pad never receives a negative (cropping) padding
        extra_h = max((math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1], 0)
        extra_v = max((math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0], 0)

        left = extra_h // 2
        right = extra_h - left
        top = extra_v // 2
        bottom = extra_v - top

        x = F.pad(x, [left, right, top, bottom])

        x = self.conv(x)
        return x


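# With 'SAME' padding the output spatial size is always ceil(input / stride),
# independent of kernel size; that is what keeps the pyramid levels aligned
# on inputs that are not multiples of the stride. Illustrative check:
def _same_padding_shape_demo():
    conv = Conv2dStaticSamePadding(3, 8, kernel_size=3, stride=2)
    out = conv(torch.randn(1, 3, 37, 53))
    return tuple(out.shape)  # (1, 8, 19, 27): ceil(37/2), ceil(53/2)

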
class MaxPool2dStaticSamePadding(nn.Module):
    """
    MaxPool2d with TensorFlow/Keras-style 'SAME' padding, computed in
    forward() from the static kernel size and stride.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pool = nn.MaxPool2d(*args, **kwargs)
        self.stride = self.pool.stride
        self.kernel_size = self.pool.kernel_size

        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        elif len(self.stride) == 1:
            self.stride = [self.stride[0]] * 2

        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        elif len(self.kernel_size) == 1:
            self.kernel_size = [self.kernel_size[0]] * 2

    def forward(self, x):
        h, w = x.shape[-2:]

        # clamp at 0 so F.pad never receives a negative (cropping) padding
        extra_h = max((math.ceil(w / self.stride[1]) - 1) * self.stride[1] - w + self.kernel_size[1], 0)
        extra_v = max((math.ceil(h / self.stride[0]) - 1) * self.stride[0] - h + self.kernel_size[0], 0)

        left = extra_h // 2
        right = extra_h - left
        top = extra_v // 2
        bottom = extra_v - top

        x = F.pad(x, [left, right, top, bottom])

        x = self.pool(x)
        return x


class Activation(nn.Module):

    def __init__(self, name, **params):
        super().__init__()

        if name is None or name == 'identity':
            self.activation = nn.Identity(**params)
        elif name == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif name == 'softmax2d':
            self.activation = nn.Softmax(dim=1, **params)
        elif name == 'softmax':
            self.activation = nn.Softmax(**params)
        elif name == 'logsoftmax':
            self.activation = nn.LogSoftmax(**params)
        elif name == 'tanh':
            self.activation = nn.Tanh()
        # elif name == 'argmax':
        #     self.activation = ArgMax(**params)
        # elif name == 'argmax2d':
        #     self.activation = ArgMax(dim=1, **params)
        # elif name == 'clamp':
        #     self.activation = Clamp(**params)
        elif callable(name):
            self.activation = name(**params)
        else:
            raise ValueError('Activation should be callable/sigmoid/softmax/logsoftmax/tanh/None; got {}'.format(name))

    def forward(self, x):
        return self.activation(x)


class SegmentationHead(nn.Sequential):

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        activation = Activation(activation)
        super().__init__(conv2d, upsampling, activation)


class ClassificationHead(nn.Sequential):

    def __init__(self, in_channels, classes, pooling="avg", dropout=0.2, activation=None):
        if pooling not in ("max", "avg"):
            raise ValueError("Pooling should be one of ('max', 'avg'), got {}.".format(pooling))
        pool = nn.AdaptiveAvgPool2d(1) if pooling == 'avg' else nn.AdaptiveMaxPool2d(1)
        flatten = nn.Flatten()
        dropout = nn.Dropout(p=dropout, inplace=True) if dropout else nn.Identity()
        linear = nn.Linear(in_channels, classes, bias=True)
        activation = Activation(activation)
        super().__init__(pool, flatten, dropout, linear, activation)


if __name__ == '__main__':
    from tensorboardX import SummaryWriter


    def count_parameters(model):
        return sum(p.numel() for p in model.parameters() if p.requires_grad)
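
    # A minimal smoke test; the channel sizes below assume an
    # EfficientNet-B3-like backbone (C3/C4/C5 with 48/136/384 channels) and a
    # 64-channel BiFPN. Both are illustrative, not the training configuration.
    fpn = BiFPN(num_channels=64, conv_channels=[48, 136, 384], first_time=True)
    c3 = torch.randn(1, 48, 64, 64)   # stride 8 feature map
    c4 = torch.randn(1, 136, 32, 32)  # stride 16
    c5 = torch.randn(1, 384, 16, 16)  # stride 32
    outs = fpn((c3, c4, c5))
    print([tuple(p.shape) for p in outs])  # P3..P7, all with 64 channels
    print('BiFPN parameters:', count_parameters(fpn))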