josedolot commited on
Commit
883a7f9
·
1 Parent(s): f7dc56f

Upload encoders/timm_res2net.py

Browse files
Files changed (1) hide show
  1. encoders/timm_res2net.py +163 -0
encoders/timm_res2net.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ._base import EncoderMixin
2
+ from timm.models.resnet import ResNet
3
+ from timm.models.res2net import Bottle2neck
4
+ import torch.nn as nn
5
+
6
+
7
+ class Res2NetEncoder(ResNet, EncoderMixin):
8
+ def __init__(self, out_channels, depth=5, **kwargs):
9
+ super().__init__(**kwargs)
10
+ self._depth = depth
11
+ self._out_channels = out_channels
12
+ self._in_channels = 3
13
+
14
+ del self.fc
15
+ del self.global_pool
16
+
17
+ def get_stages(self):
18
+ return [
19
+ nn.Identity(),
20
+ nn.Sequential(self.conv1, self.bn1, self.act1),
21
+ nn.Sequential(self.maxpool, self.layer1),
22
+ self.layer2,
23
+ self.layer3,
24
+ self.layer4,
25
+ ]
26
+
27
+ def make_dilated(self, stage_list, dilation_list):
28
+ raise ValueError("Res2Net encoders do not support dilated mode")
29
+
30
+ def forward(self, x):
31
+ stages = self.get_stages()
32
+
33
+ features = []
34
+ for i in range(self._depth + 1):
35
+ x = stages[i](x)
36
+ features.append(x)
37
+
38
+ return features
39
+
40
+ def load_state_dict(self, state_dict, **kwargs):
41
+ state_dict.pop("fc.bias", None)
42
+ state_dict.pop("fc.weight", None)
43
+ super().load_state_dict(state_dict, **kwargs)
44
+
45
+
46
+ res2net_weights = {
47
+ 'timm-res2net50_26w_4s': {
48
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_4s-06e79181.pth'
49
+ },
50
+ 'timm-res2net50_48w_2s': {
51
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_48w_2s-afed724a.pth'
52
+ },
53
+ 'timm-res2net50_14w_8s': {
54
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_14w_8s-6527dddc.pth',
55
+ },
56
+ 'timm-res2net50_26w_6s': {
57
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_6s-19041792.pth',
58
+ },
59
+ 'timm-res2net50_26w_8s': {
60
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net50_26w_8s-2c7c9f12.pth',
61
+ },
62
+ 'timm-res2net101_26w_4s': {
63
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2net101_26w_4s-02a759a1.pth',
64
+ },
65
+ 'timm-res2next50': {
66
+ 'imagenet': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-res2net/res2next50_4s-6ef7e7bf.pth',
67
+ }
68
+ }
69
+
70
+ pretrained_settings = {}
71
+ for model_name, sources in res2net_weights.items():
72
+ pretrained_settings[model_name] = {}
73
+ for source_name, source_url in sources.items():
74
+ pretrained_settings[model_name][source_name] = {
75
+ "url": source_url,
76
+ 'input_size': [3, 224, 224],
77
+ 'input_range': [0, 1],
78
+ 'mean': [0.485, 0.456, 0.406],
79
+ 'std': [0.229, 0.224, 0.225],
80
+ 'num_classes': 1000
81
+ }
82
+
83
+
84
+ timm_res2net_encoders = {
85
+ 'timm-res2net50_26w_4s': {
86
+ 'encoder': Res2NetEncoder,
87
+ "pretrained_settings": pretrained_settings["timm-res2net50_26w_4s"],
88
+ 'params': {
89
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
90
+ 'block': Bottle2neck,
91
+ 'layers': [3, 4, 6, 3],
92
+ 'base_width': 26,
93
+ 'block_args': {'scale': 4}
94
+ },
95
+ },
96
+ 'timm-res2net101_26w_4s': {
97
+ 'encoder': Res2NetEncoder,
98
+ "pretrained_settings": pretrained_settings["timm-res2net101_26w_4s"],
99
+ 'params': {
100
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
101
+ 'block': Bottle2neck,
102
+ 'layers': [3, 4, 23, 3],
103
+ 'base_width': 26,
104
+ 'block_args': {'scale': 4}
105
+ },
106
+ },
107
+ 'timm-res2net50_26w_6s': {
108
+ 'encoder': Res2NetEncoder,
109
+ "pretrained_settings": pretrained_settings["timm-res2net50_26w_6s"],
110
+ 'params': {
111
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
112
+ 'block': Bottle2neck,
113
+ 'layers': [3, 4, 6, 3],
114
+ 'base_width': 26,
115
+ 'block_args': {'scale': 6}
116
+ },
117
+ },
118
+ 'timm-res2net50_26w_8s': {
119
+ 'encoder': Res2NetEncoder,
120
+ "pretrained_settings": pretrained_settings["timm-res2net50_26w_8s"],
121
+ 'params': {
122
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
123
+ 'block': Bottle2neck,
124
+ 'layers': [3, 4, 6, 3],
125
+ 'base_width': 26,
126
+ 'block_args': {'scale': 8}
127
+ },
128
+ },
129
+ 'timm-res2net50_48w_2s': {
130
+ 'encoder': Res2NetEncoder,
131
+ "pretrained_settings": pretrained_settings["timm-res2net50_48w_2s"],
132
+ 'params': {
133
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
134
+ 'block': Bottle2neck,
135
+ 'layers': [3, 4, 6, 3],
136
+ 'base_width': 48,
137
+ 'block_args': {'scale': 2}
138
+ },
139
+ },
140
+ 'timm-res2net50_14w_8s': {
141
+ 'encoder': Res2NetEncoder,
142
+ "pretrained_settings": pretrained_settings["timm-res2net50_14w_8s"],
143
+ 'params': {
144
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
145
+ 'block': Bottle2neck,
146
+ 'layers': [3, 4, 6, 3],
147
+ 'base_width': 14,
148
+ 'block_args': {'scale': 8}
149
+ },
150
+ },
151
+ 'timm-res2next50': {
152
+ 'encoder': Res2NetEncoder,
153
+ "pretrained_settings": pretrained_settings["timm-res2next50"],
154
+ 'params': {
155
+ 'out_channels': (3, 64, 256, 512, 1024, 2048),
156
+ 'block': Bottle2neck,
157
+ 'layers': [3, 4, 6, 3],
158
+ 'base_width': 4,
159
+ 'cardinality': 8,
160
+ 'block_args': {'scale': 4}
161
+ },
162
+ }
163
+ }