josedolot committed on
Commit
2eb9939
·
1 Parent(s): 883a7f9

Upload encoders/timm_resnest.py

Browse files
Files changed (1) hide show
  1. encoders/timm_resnest.py +208 -0
encoders/timm_resnest.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._base import EncoderMixin
2
+ from timm.models.resnet import ResNet
3
+ from timm.models.resnest import ResNestBottleneck
4
+ import torch.nn as nn
5
+
6
+
7
class ResNestEncoder(ResNet, EncoderMixin):
    """ResNeSt backbone exposed as a multi-scale feature encoder.

    Construction is delegated to ``ResNet``; afterwards the ``fc`` and
    ``global_pool`` attributes are deleted, and ``forward`` returns a list
    of intermediate outputs instead of a single tensor.
    """

    def __init__(self, out_channels, depth=5, **kwargs):
        """Build the backbone and drop its classification head.

        Args:
            out_channels: Stored as ``self._out_channels``. Presumably the
                channel count of each output of ``forward``, consumed by
                ``EncoderMixin`` -- confirm against that mixin.
            depth: Number of stages run by ``forward`` after the identity
                stage; ``forward`` returns ``depth + 1`` outputs.
            **kwargs: Passed through unchanged to ``ResNet.__init__``.
        """
        super().__init__(**kwargs)
        self._depth = depth
        self._out_channels = out_channels
        self._in_channels = 3

        # The classification head is not used by this encoder.
        del self.fc
        del self.global_pool

    def get_stages(self):
        """Return the six stages in the order ``forward`` applies them."""
        return [
            nn.Identity(),
            nn.Sequential(self.conv1, self.bn1, self.act1),
            nn.Sequential(self.maxpool, self.layer1),
            self.layer2,
            self.layer3,
            self.layer4,
        ]

    def make_dilated(self, stage_list, dilation_list):
        """Always raise: dilated mode is not supported by this encoder.

        Raises:
            ValueError: Unconditionally.
        """
        raise ValueError("ResNest encoders do not support dilated mode")

    def forward(self, x):
        """Run ``x`` through the first ``depth + 1`` stages.

        Returns:
            A list holding the output of every stage that was run, starting
            with the identity stage (i.e. the input itself).
        """
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, ignoring the removed classification head.

        Args:
            state_dict: Mapping of parameter names to values. It is not
                modified; filtering happens on a shallow copy.
            **kwargs: Passed through unchanged to the parent
                ``load_state_dict``.

        Returns:
            Whatever the parent ``load_state_dict`` returns, so callers can
            inspect the loading result.
        """
        import copy

        # copy.copy (rather than dict(...)) keeps any instance attributes
        # attached to the mapping object itself.
        state_dict = copy.copy(state_dict)
        state_dict.pop("fc.bias", None)
        state_dict.pop("fc.weight", None)
        return super().load_state_dict(state_dict, **kwargs)
44
+
45
+
46
# (encoder name, checkpoint URL) pairs; every checkpoint is registered under
# the single weight-source key "imagenet".
_resnest_imagenet_urls = (
    (
        "timm-resnest14d",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest14-9c8fe254.pth",
    ),
    (
        "timm-resnest26d",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/gluon_resnest26-50eb607c.pth",
    ),
    (
        "timm-resnest50d",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50-528c19ca.pth",
    ),
    (
        "timm-resnest101e",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest101-22405ba7.pth",
    ),
    (
        "timm-resnest200e",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest200-75117900.pth",
    ),
    (
        "timm-resnest269e",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest269-0cc87c48.pth",
    ),
    (
        "timm-resnest50d_4s2x40d",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_4s2x40d-41d14ed0.pth",
    ),
    (
        "timm-resnest50d_1s4x24d",
        "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-resnest/resnest50_fast_1s4x24d-d4a4f76f.pth",
    ),
)

# Encoder name -> {weight source name -> checkpoint URL}.
resnest_weights = {
    encoder_name: {"imagenet": checkpoint_url}
    for encoder_name, checkpoint_url in _resnest_imagenet_urls
}
72
+
73
# Encoder name -> {weight source name -> settings record}. Every URL listed in
# resnest_weights is expanded into a record with the same preprocessing
# values; each record gets its own freshly built lists.
pretrained_settings = {
    model_name: {
        source_name: {
            "url": source_url,
            "input_size": [3, 224, 224],
            "input_range": [0, 1],
            "mean": [0.485, 0.456, 0.406],
            "std": [0.229, 0.224, 0.225],
            "num_classes": 1000,
        }
        for source_name, source_url in sources.items()
    }
    for model_name, sources in resnest_weights.items()
}
85
+
86
+
87
def _resnest_entry(
    name,
    layers,
    *,
    stem_width=32,
    base_width=64,
    cardinality=1,
    radix=2,
    avd_first=False,
):
    """Build the registry record for the encoder called ``name``.

    Only the values that differ between variants are parameters; everything
    else is the same for every entry in this file.
    """
    return {
        "encoder": ResNestEncoder,
        "pretrained_settings": pretrained_settings[name],
        "params": {
            # In every entry of this registry the second value is twice
            # the stem width (64 for stem 32, 128 for stem 64).
            "out_channels": (3, stem_width * 2, 256, 512, 1024, 2048),
            "block": ResNestBottleneck,
            "layers": layers,
            "stem_type": "deep",
            "stem_width": stem_width,
            "avg_down": True,
            "base_width": base_width,
            "cardinality": cardinality,
            "block_args": {"radix": radix, "avd": True, "avd_first": avd_first},
        },
    }


# Encoder name -> {"encoder": class, "pretrained_settings": ..., "params": ...}.
timm_resnest_encoders = {
    "timm-resnest14d": _resnest_entry("timm-resnest14d", [1, 1, 1, 1]),
    "timm-resnest26d": _resnest_entry("timm-resnest26d", [2, 2, 2, 2]),
    "timm-resnest50d": _resnest_entry("timm-resnest50d", [3, 4, 6, 3]),
    "timm-resnest101e": _resnest_entry(
        "timm-resnest101e", [3, 4, 23, 3], stem_width=64
    ),
    "timm-resnest200e": _resnest_entry(
        "timm-resnest200e", [3, 24, 36, 3], stem_width=64
    ),
    "timm-resnest269e": _resnest_entry(
        "timm-resnest269e", [3, 30, 48, 8], stem_width=64
    ),
    "timm-resnest50d_4s2x40d": _resnest_entry(
        "timm-resnest50d_4s2x40d",
        [3, 4, 6, 3],
        base_width=40,
        cardinality=2,
        radix=4,
        avd_first=True,
    ),
    "timm-resnest50d_1s4x24d": _resnest_entry(
        "timm-resnest50d_1s4x24d",
        [3, 4, 6, 3],
        base_width=24,
        cardinality=4,
        radix=1,
        avd_first=True,
    ),
}