lixc commited on
Commit
755671e
·
1 Parent(s): d9e5c87

add md5 and pytorch weights

Browse files
Files changed (4) hide show
  1. README.md +20 -0
  2. md5.txt +3 -0
  3. swav_imagenet_layer2.pt +3 -0
  4. trace_layer2.py +303 -0
README.md ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ ```python
4
+ import trace_layer2 as models
5
+ import torch
6
+
7
+ x=torch.randn(1, 3, 224, 224)
8
+
9
+ state_dict = torch.load('swav_imagenet_layer2.pt', map_location='cpu')
10
+
11
+ model = models.resnet50w2()
12
+ model.load_state_dict(state_dict)
13
+ model.eval()
14
+ feature = model(x)
15
+
16
+ traced_model = torch.jit.load('traced_swav_imagenet_layer2.pt', map_location='cpu')
17
+ traced_model.eval()
18
+ feature = traced_model(x)
19
+
20
+ ```
md5.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ 2d46caef59dd661c695114df1c161733 swav_imagenet_layer2.pt
2
+ d4c65a44119dcb606e5f2c4efe986847 swav_imagenet_layer2_sim.onnx
3
+ 428ac1fe949cc24c65a7df974b32bc2f traced_swav_imagenet_layer2.pt
swav_imagenet_layer2.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1923d34e543b3120e23f4a62c8ce83d204530329d0c9a19c18008b1d5e9dc25d
3
+ size 23087521
trace_layer2.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+
12
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
13
+ """3x3 convolution with padding"""
14
+ return nn.Conv2d(
15
+ in_planes,
16
+ out_planes,
17
+ kernel_size=3,
18
+ stride=stride,
19
+ padding=dilation,
20
+ groups=groups,
21
+ bias=False,
22
+ dilation=dilation,
23
+ )
24
+
25
+
26
+ def conv1x1(in_planes, out_planes, stride=1):
27
+ """1x1 convolution"""
28
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
29
+
30
+
31
+ class BasicBlock(nn.Module):
32
+ expansion = 1
33
+ __constants__ = ["downsample"]
34
+
35
+ def __init__(
36
+ self,
37
+ inplanes,
38
+ planes,
39
+ stride=1,
40
+ downsample=None,
41
+ groups=1,
42
+ base_width=64,
43
+ dilation=1,
44
+ norm_layer=None,
45
+ ):
46
+ super(BasicBlock, self).__init__()
47
+ if norm_layer is None:
48
+ norm_layer = nn.BatchNorm2d
49
+ if groups != 1 or base_width != 64:
50
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
51
+ if dilation > 1:
52
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
53
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
54
+ self.conv1 = conv3x3(inplanes, planes, stride)
55
+ self.bn1 = norm_layer(planes)
56
+ self.relu = nn.ReLU(inplace=True)
57
+ self.conv2 = conv3x3(planes, planes)
58
+ self.bn2 = norm_layer(planes)
59
+ self.downsample = downsample
60
+ self.stride = stride
61
+
62
+ def forward(self, x):
63
+ identity = x
64
+
65
+ out = self.conv1(x)
66
+ out = self.bn1(out)
67
+ out = self.relu(out)
68
+
69
+ out = self.conv2(out)
70
+ out = self.bn2(out)
71
+
72
+ if self.downsample is not None:
73
+ identity = self.downsample(x)
74
+
75
+ out += identity
76
+ out = self.relu(out)
77
+
78
+ return out
79
+
80
+
81
+ class Bottleneck(nn.Module):
82
+ expansion = 4
83
+ __constants__ = ["downsample"]
84
+
85
+ def __init__(
86
+ self,
87
+ inplanes,
88
+ planes,
89
+ stride=1,
90
+ downsample=None,
91
+ groups=1,
92
+ base_width=64,
93
+ dilation=1,
94
+ norm_layer=None,
95
+ ):
96
+ super(Bottleneck, self).__init__()
97
+ if norm_layer is None:
98
+ norm_layer = nn.BatchNorm2d
99
+ width = int(planes * (base_width / 64.0)) * groups
100
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
101
+ self.conv1 = conv1x1(inplanes, width)
102
+ self.bn1 = norm_layer(width)
103
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
104
+ self.bn2 = norm_layer(width)
105
+ self.conv3 = conv1x1(width, planes * self.expansion)
106
+ self.bn3 = norm_layer(planes * self.expansion)
107
+ self.relu = nn.ReLU(inplace=True)
108
+ self.downsample = downsample
109
+ self.stride = stride
110
+
111
+ def forward(self, x):
112
+ identity = x
113
+
114
+ out = self.conv1(x)
115
+ out = self.bn1(out)
116
+ out = self.relu(out)
117
+
118
+ out = self.conv2(out)
119
+ out = self.bn2(out)
120
+ out = self.relu(out)
121
+
122
+ out = self.conv3(out)
123
+ out = self.bn3(out)
124
+
125
+ if self.downsample is not None:
126
+ identity = self.downsample(x)
127
+
128
+ out += identity
129
+ out = self.relu(out)
130
+
131
+ return out
132
+
133
+
134
+ class ResNet(nn.Module):
135
+ def __init__(
136
+ self,
137
+ block,
138
+ layers,
139
+ num_classes=1000,
140
+ zero_init_residual=False,
141
+ groups=1,
142
+ widen=1,
143
+ width_per_group=64,
144
+ replace_stride_with_dilation=None,
145
+ norm_layer=None,
146
+ ):
147
+ super(ResNet, self).__init__()
148
+ if norm_layer is None:
149
+ norm_layer = nn.BatchNorm2d
150
+ self._norm_layer = norm_layer
151
+
152
+ self.inplanes = width_per_group * widen
153
+ self.dilation = 1
154
+ if replace_stride_with_dilation is None:
155
+ # each element in the tuple indicates if we should replace
156
+ # the 2x2 stride with a dilated convolution instead
157
+ replace_stride_with_dilation = [False, False, False]
158
+ if len(replace_stride_with_dilation) != 3:
159
+ raise ValueError(
160
+ "replace_stride_with_dilation should be None "
161
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation)
162
+ )
163
+ self.groups = groups
164
+ self.base_width = width_per_group
165
+
166
+ num_out_filters = width_per_group * widen
167
+ self.conv1 = nn.Conv2d(
168
+ 3, num_out_filters, kernel_size=7, stride=2, padding=3, bias=False
169
+ )
170
+ self.bn1 = norm_layer(num_out_filters)
171
+ self.relu = nn.ReLU(inplace=True)
172
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
173
+ self.layer1 = self._make_layer(block, num_out_filters, layers[0])
174
+ num_out_filters *= 2
175
+ self.layer2 = self._make_layer(
176
+ block, num_out_filters, layers[1], stride=2, dilate=replace_stride_with_dilation[0]
177
+ )
178
+ #num_out_filters *= 2
179
+ #self.layer3 = self._make_layer(
180
+ # block, num_out_filters, layers[2], stride=2, dilate=replace_stride_with_dilation[1]
181
+ #)
182
+ #num_out_filters *= 2
183
+ #self.layer4 = self._make_layer(
184
+ # block, num_out_filters, layers[3], stride=2, dilate=replace_stride_with_dilation[2]
185
+ #)
186
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
187
+ #self.fc = nn.Linear(512 * block.expansion * widen, num_classes)
188
+
189
+ # Zero-initialize the last BN in each residual branch,
190
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
191
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
192
+ if zero_init_residual:
193
+ for m in self.modules():
194
+ if isinstance(m, Bottleneck):
195
+ nn.init.constant_(m.bn3.weight, 0)
196
+ elif isinstance(m, BasicBlock):
197
+ nn.init.constant_(m.bn2.weight, 0)
198
+
199
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
200
+ norm_layer = self._norm_layer
201
+ downsample = None
202
+ previous_dilation = self.dilation
203
+ if dilate:
204
+ self.dilation *= stride
205
+ stride = 1
206
+ if stride != 1 or self.inplanes != planes * block.expansion:
207
+ downsample = nn.Sequential(
208
+ conv1x1(self.inplanes, planes * block.expansion, stride),
209
+ norm_layer(planes * block.expansion),
210
+ )
211
+
212
+ layers = []
213
+ layers.append(
214
+ block(
215
+ self.inplanes,
216
+ planes,
217
+ stride,
218
+ downsample,
219
+ self.groups,
220
+ self.base_width,
221
+ previous_dilation,
222
+ norm_layer,
223
+ )
224
+ )
225
+ self.inplanes = planes * block.expansion
226
+ for _ in range(1, blocks):
227
+ layers.append(
228
+ block(
229
+ self.inplanes,
230
+ planes,
231
+ groups=self.groups,
232
+ base_width=self.base_width,
233
+ dilation=self.dilation,
234
+ norm_layer=norm_layer,
235
+ )
236
+ )
237
+
238
+ return nn.Sequential(*layers)
239
+
240
+ def forward(self, x):
241
+ x = self.conv1(x)
242
+ x = self.bn1(x)
243
+ x = self.relu(x)
244
+ x = self.maxpool(x)
245
+ x = self.layer1(x)
246
+ x = self.layer2(x)
247
+ #x = self.layer3(x)
248
+ #x = self.layer4(x)
249
+
250
+ x = self.avgpool(x)
251
+ x = torch.flatten(x, 1)
252
+ #x = self.fc(x)
253
+
254
+ return x
255
+
256
+
257
+ def resnet50(**kwargs):
258
+ return ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
259
+
260
+
261
+ def resnet50w2(**kwargs):
262
+ return ResNet(Bottleneck, [3, 4, 6, 3], widen=2, **kwargs)
263
+
264
+
265
+ def resnet50w4(**kwargs):
266
+ return ResNet(Bottleneck, [3, 4, 6, 3], widen=4, **kwargs)
267
+
268
+
269
+ def resnet50w5(**kwargs):
270
+ return ResNet(Bottleneck, [3, 4, 6, 3], widen=5, **kwargs)
271
+
272
+
273
+ if __name__ == '__main__':
274
+ import onnxruntime as ort
275
+ x=torch.rand(1,3,224,224)
276
+ model = resnet50w2()
277
+ model.eval()
278
+
279
+ swav_state_dict = torch.load('/opt/software/github/he-ai/swav_RN50w2_400ep_pretrain.pth.tar')
280
+
281
+ for k in list(swav_state_dict.keys()):
282
+ if k.startswith('module.layer3') or k.startswith('module.layer4') or k.startswith('module.pro'):del swav_state_dict[k]
283
+
284
+ for k in list(swav_state_dict.keys()):
285
+ swav_state_dict[k.replace('module.', '')] = swav_state_dict[k]
286
+ del swav_state_dict[k]
287
+ msg = model.load_state_dict(swav_state_dict, strict=False)
288
+ print(msg)
289
+ torch.save(swav_state_dict, 'swav_imagenet_layer2.pt')
290
+
291
+ traced_script_module = torch.jit.trace(model, x)
292
+ traced_script_module.save("traced_swav_imagenet_layer2.pt")
293
+ traced_feature = traced_script_module(x).detach().cpu().numpy()
294
+ print(traced_feature)
295
+ print(model(x))
296
+
297
+
298
+ dynamic_axes={"x": {0:"batch_size"}, 'feature': {0:'batch_size'}}
299
+ torch.onnx.export(model, x, "swav_imagenet_layer2.onnx", verbose=False, input_names=['x'], output_names=['feature'], dynamic_axes=dynamic_axes, do_constant_folding=True)
300
+ ort_session = ort.InferenceSession("swav_imagenet_layer2.onnx")
301
+ onnx_outputs = ort_session.run(None, {'x':x.numpy()})
302
+ print(onnx_outputs[0])
303
+