josedolot commited on
Commit
f7dc56f
·
1 Parent(s): b459bd6

Upload encoders/timm_regnet.py

Browse files
Files changed (1) hide show
  1. encoders/timm_regnet.py +332 -0
encoders/timm_regnet.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._base import EncoderMixin
2
+ from timm.models.regnet import RegNet
3
+ import torch.nn as nn
4
+
5
+
6
+ class RegNetEncoder(RegNet, EncoderMixin):
7
+ def __init__(self, out_channels, depth=5, **kwargs):
8
+ super().__init__(**kwargs)
9
+ self._depth = depth
10
+ self._out_channels = out_channels
11
+ self._in_channels = 3
12
+
13
+ del self.head
14
+
15
+ def get_stages(self):
16
+ return [
17
+ nn.Identity(),
18
+ self.stem,
19
+ self.s1,
20
+ self.s2,
21
+ self.s3,
22
+ self.s4,
23
+ ]
24
+
25
+ def forward(self, x):
26
+ stages = self.get_stages()
27
+
28
+ features = []
29
+ for i in range(self._depth + 1):
30
+ x = stages[i](x)
31
+ features.append(x)
32
+
33
+ return features
34
+
35
+ def load_state_dict(self, state_dict, **kwargs):
36
+ state_dict.pop("head.fc.weight", None)
37
+ state_dict.pop("head.fc.bias", None)
38
+ super().load_state_dict(state_dict, **kwargs)
39
+
40
+
41
+ regnet_weights = {
42
+ 'timm-regnetx_002': {
43
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_002-e7e85e5c.pth',
44
+ },
45
+ 'timm-regnetx_004': {
46
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_004-7d0e9424.pth',
47
+ },
48
+ 'timm-regnetx_006': {
49
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_006-85ec1baa.pth',
50
+ },
51
+ 'timm-regnetx_008': {
52
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_008-d8b470eb.pth',
53
+ },
54
+ 'timm-regnetx_016': {
55
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_016-65ca972a.pth',
56
+ },
57
+ 'timm-regnetx_032': {
58
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_032-ed0c7f7e.pth',
59
+ },
60
+ 'timm-regnetx_040': {
61
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_040-73c2a654.pth',
62
+ },
63
+ 'timm-regnetx_064': {
64
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_064-29278baa.pth',
65
+ },
66
+ 'timm-regnetx_080': {
67
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_080-7c7fcab1.pth',
68
+ },
69
+ 'timm-regnetx_120': {
70
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth',
71
+ },
72
+ 'timm-regnetx_160': {
73
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth',
74
+ },
75
+ 'timm-regnetx_320': {
76
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth',
77
+ },
78
+ 'timm-regnety_002': {
79
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth',
80
+ },
81
+ 'timm-regnety_004': {
82
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth',
83
+ },
84
+ 'timm-regnety_006': {
85
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth',
86
+ },
87
+ 'timm-regnety_008': {
88
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_008-dc900dbe.pth',
89
+ },
90
+ 'timm-regnety_016': {
91
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_016-54367f74.pth',
92
+ },
93
+ 'timm-regnety_032': {
94
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/regnety_032_ra-7f2439f9.pth'
95
+ },
96
+ 'timm-regnety_040': {
97
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_040-f0d569f9.pth'
98
+ },
99
+ 'timm-regnety_064': {
100
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_064-0a48325c.pth'
101
+ },
102
+ 'timm-regnety_080': {
103
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_080-e7f3eb93.pth',
104
+ },
105
+ 'timm-regnety_120': {
106
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_120-721ba79a.pth',
107
+ },
108
+ 'timm-regnety_160': {
109
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_160-d64013cd.pth',
110
+ },
111
+ 'timm-regnety_320': {
112
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'
113
+ }
114
+ }
115
+
116
+ pretrained_settings = {}
117
+ for model_name, sources in regnet_weights.items():
118
+ pretrained_settings[model_name] = {}
119
+ for source_name, source_url in sources.items():
120
+ pretrained_settings[model_name][source_name] = {
121
+ "url": source_url,
122
+ 'input_size': [3, 224, 224],
123
+ 'input_range': [0, 1],
124
+ 'mean': [0.485, 0.456, 0.406],
125
+ 'std': [0.229, 0.224, 0.225],
126
+ 'num_classes': 1000
127
+ }
128
+
129
+ # at this point I am too lazy to copy configs, so I just used the same configs from timm's repo
130
+
131
+
132
+ def _mcfg(**kwargs):
133
+ cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
134
+ cfg.update(**kwargs)
135
+ return cfg
136
+
137
+
138
+ timm_regnet_encoders = {
139
+ 'timm-regnetx_002': {
140
+ 'encoder': RegNetEncoder,
141
+ "pretrained_settings": pretrained_settings["timm-regnetx_002"],
142
+ 'params': {
143
+ 'out_channels': (3, 32, 24, 56, 152, 368),
144
+ 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13)
145
+ },
146
+ },
147
+ 'timm-regnetx_004': {
148
+ 'encoder': RegNetEncoder,
149
+ "pretrained_settings": pretrained_settings["timm-regnetx_004"],
150
+ 'params': {
151
+ 'out_channels': (3, 32, 32, 64, 160, 384),
152
+ 'cfg': _mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22)
153
+ },
154
+ },
155
+ 'timm-regnetx_006': {
156
+ 'encoder': RegNetEncoder,
157
+ "pretrained_settings": pretrained_settings["timm-regnetx_006"],
158
+ 'params': {
159
+ 'out_channels': (3, 32, 48, 96, 240, 528),
160
+ 'cfg': _mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16)
161
+ },
162
+ },
163
+ 'timm-regnetx_008': {
164
+ 'encoder': RegNetEncoder,
165
+ "pretrained_settings": pretrained_settings["timm-regnetx_008"],
166
+ 'params': {
167
+ 'out_channels': (3, 32, 64, 128, 288, 672),
168
+ 'cfg': _mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16)
169
+ },
170
+ },
171
+ 'timm-regnetx_016': {
172
+ 'encoder': RegNetEncoder,
173
+ "pretrained_settings": pretrained_settings["timm-regnetx_016"],
174
+ 'params': {
175
+ 'out_channels': (3, 32, 72, 168, 408, 912),
176
+ 'cfg': _mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18)
177
+ },
178
+ },
179
+ 'timm-regnetx_032': {
180
+ 'encoder': RegNetEncoder,
181
+ "pretrained_settings": pretrained_settings["timm-regnetx_032"],
182
+ 'params': {
183
+ 'out_channels': (3, 32, 96, 192, 432, 1008),
184
+ 'cfg': _mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25)
185
+ },
186
+ },
187
+ 'timm-regnetx_040': {
188
+ 'encoder': RegNetEncoder,
189
+ "pretrained_settings": pretrained_settings["timm-regnetx_040"],
190
+ 'params': {
191
+ 'out_channels': (3, 32, 80, 240, 560, 1360),
192
+ 'cfg': _mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23)
193
+ },
194
+ },
195
+ 'timm-regnetx_064': {
196
+ 'encoder': RegNetEncoder,
197
+ "pretrained_settings": pretrained_settings["timm-regnetx_064"],
198
+ 'params': {
199
+ 'out_channels': (3, 32, 168, 392, 784, 1624),
200
+ 'cfg': _mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17)
201
+ },
202
+ },
203
+ 'timm-regnetx_080': {
204
+ 'encoder': RegNetEncoder,
205
+ "pretrained_settings": pretrained_settings["timm-regnetx_080"],
206
+ 'params': {
207
+ 'out_channels': (3, 32, 80, 240, 720, 1920),
208
+ 'cfg': _mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23)
209
+ },
210
+ },
211
+ 'timm-regnetx_120': {
212
+ 'encoder': RegNetEncoder,
213
+ "pretrained_settings": pretrained_settings["timm-regnetx_120"],
214
+ 'params': {
215
+ 'out_channels': (3, 32, 224, 448, 896, 2240),
216
+ 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19)
217
+ },
218
+ },
219
+ 'timm-regnetx_160': {
220
+ 'encoder': RegNetEncoder,
221
+ "pretrained_settings": pretrained_settings["timm-regnetx_160"],
222
+ 'params': {
223
+ 'out_channels': (3, 32, 256, 512, 896, 2048),
224
+ 'cfg': _mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22)
225
+ },
226
+ },
227
+ 'timm-regnetx_320': {
228
+ 'encoder': RegNetEncoder,
229
+ "pretrained_settings": pretrained_settings["timm-regnetx_320"],
230
+ 'params': {
231
+ 'out_channels': (3, 32, 336, 672, 1344, 2520),
232
+ 'cfg': _mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23)
233
+ },
234
+ },
235
+ #regnety
236
+ 'timm-regnety_002': {
237
+ 'encoder': RegNetEncoder,
238
+ "pretrained_settings": pretrained_settings["timm-regnety_002"],
239
+ 'params': {
240
+ 'out_channels': (3, 32, 24, 56, 152, 368),
241
+ 'cfg': _mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25)
242
+ },
243
+ },
244
+ 'timm-regnety_004': {
245
+ 'encoder': RegNetEncoder,
246
+ "pretrained_settings": pretrained_settings["timm-regnety_004"],
247
+ 'params': {
248
+ 'out_channels': (3, 32, 48, 104, 208, 440),
249
+ 'cfg': _mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25)
250
+ },
251
+ },
252
+ 'timm-regnety_006': {
253
+ 'encoder': RegNetEncoder,
254
+ "pretrained_settings": pretrained_settings["timm-regnety_006"],
255
+ 'params': {
256
+ 'out_channels': (3, 32, 48, 112, 256, 608),
257
+ 'cfg': _mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25)
258
+ },
259
+ },
260
+ 'timm-regnety_008': {
261
+ 'encoder': RegNetEncoder,
262
+ "pretrained_settings": pretrained_settings["timm-regnety_008"],
263
+ 'params': {
264
+ 'out_channels': (3, 32, 64, 128, 320, 768),
265
+ 'cfg': _mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25)
266
+ },
267
+ },
268
+ 'timm-regnety_016': {
269
+ 'encoder': RegNetEncoder,
270
+ "pretrained_settings": pretrained_settings["timm-regnety_016"],
271
+ 'params': {
272
+ 'out_channels': (3, 32, 48, 120, 336, 888),
273
+ 'cfg': _mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25)
274
+ },
275
+ },
276
+ 'timm-regnety_032': {
277
+ 'encoder': RegNetEncoder,
278
+ "pretrained_settings": pretrained_settings["timm-regnety_032"],
279
+ 'params': {
280
+ 'out_channels': (3, 32, 72, 216, 576, 1512),
281
+ 'cfg': _mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25)
282
+ },
283
+ },
284
+ 'timm-regnety_040': {
285
+ 'encoder': RegNetEncoder,
286
+ "pretrained_settings": pretrained_settings["timm-regnety_040"],
287
+ 'params': {
288
+ 'out_channels': (3, 32, 128, 192, 512, 1088),
289
+ 'cfg': _mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25)
290
+ },
291
+ },
292
+ 'timm-regnety_064': {
293
+ 'encoder': RegNetEncoder,
294
+ "pretrained_settings": pretrained_settings["timm-regnety_064"],
295
+ 'params': {
296
+ 'out_channels': (3, 32, 144, 288, 576, 1296),
297
+ 'cfg': _mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25)
298
+ },
299
+ },
300
+ 'timm-regnety_080': {
301
+ 'encoder': RegNetEncoder,
302
+ "pretrained_settings": pretrained_settings["timm-regnety_080"],
303
+ 'params': {
304
+ 'out_channels': (3, 32, 168, 448, 896, 2016),
305
+ 'cfg': _mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25)
306
+ },
307
+ },
308
+ 'timm-regnety_120': {
309
+ 'encoder': RegNetEncoder,
310
+ "pretrained_settings": pretrained_settings["timm-regnety_120"],
311
+ 'params': {
312
+ 'out_channels': (3, 32, 224, 448, 896, 2240),
313
+ 'cfg': _mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25)
314
+ },
315
+ },
316
+ 'timm-regnety_160': {
317
+ 'encoder': RegNetEncoder,
318
+ "pretrained_settings": pretrained_settings["timm-regnety_160"],
319
+ 'params': {
320
+ 'out_channels': (3, 32, 224, 448, 1232, 3024),
321
+ 'cfg': _mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25)
322
+ },
323
+ },
324
+ 'timm-regnety_320': {
325
+ 'encoder': RegNetEncoder,
326
+ "pretrained_settings": pretrained_settings["timm-regnety_320"],
327
+ 'params': {
328
+ 'out_channels': (3, 32, 232, 696, 1392, 3712),
329
+ 'cfg': _mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25)
330
+ },
331
+ },
332
+ }