josedolot commited on
Commit
1ce1d40
·
1 Parent(s): 69b3540

Upload encoders/efficientnet.py

Browse files
Files changed (1) hide show
  1. encoders/efficientnet.py +178 -0
encoders/efficientnet.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2
+
3
+ Attributes:
4
+
5
+ _out_channels (list of int): specify number of channels for each encoder feature tensor
6
+ _depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7
+ _in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8
+
9
+ Methods:
10
+
11
+ forward(self, x: torch.Tensor)
12
+ produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13
+ shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14
+ with resolution same as input `x` tensor).
15
+
16
+ Input: `x` with shape (1, 3, 64, 64)
17
+ Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18
+ [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19
+ (1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20
+
21
+ also should support number of features according to specified depth, e.g. if depth = 5,
22
+ number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23
+ depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24
+ """
25
+ import torch.nn as nn
26
+ from efficientnet_pytorch import EfficientNet
27
+ from efficientnet_pytorch.utils import url_map, url_map_advprop, get_model_params
28
+
29
+ from ._base import EncoderMixin
30
+
31
+
32
+ class EfficientNetEncoder(EfficientNet, EncoderMixin):
33
+ def __init__(self, stage_idxs, out_channels, model_name, depth=5):
34
+
35
+ blocks_args, global_params = get_model_params(model_name, override_params=None)
36
+ super().__init__(blocks_args, global_params)
37
+
38
+ self._stage_idxs = stage_idxs
39
+ self._out_channels = out_channels
40
+ self._depth = depth
41
+ self._in_channels = 3
42
+
43
+ del self._fc
44
+
45
+ def get_stages(self):
46
+ return [
47
+ nn.Identity(),
48
+ nn.Sequential(self._conv_stem, self._bn0, self._swish),
49
+ self._blocks[:self._stage_idxs[0]],
50
+ self._blocks[self._stage_idxs[0]:self._stage_idxs[1]],
51
+ self._blocks[self._stage_idxs[1]:self._stage_idxs[2]],
52
+ self._blocks[self._stage_idxs[2]:],
53
+ ]
54
+
55
+ def forward(self, x):
56
+ stages = self.get_stages()
57
+
58
+ block_number = 0.
59
+ drop_connect_rate = self._global_params.drop_connect_rate
60
+
61
+ features = []
62
+ for i in range(self._depth + 1):
63
+
64
+ # Identity and Sequential stages
65
+ if i < 2:
66
+ x = stages[i](x)
67
+
68
+ # Block stages need drop_connect rate
69
+ else:
70
+ for module in stages[i]:
71
+ drop_connect = drop_connect_rate * block_number / len(self._blocks)
72
+ block_number += 1.
73
+ x = module(x, drop_connect)
74
+
75
+ features.append(x)
76
+
77
+ return features
78
+
79
+ def load_state_dict(self, state_dict, **kwargs):
80
+ state_dict.pop("_fc.bias", None)
81
+ state_dict.pop("_fc.weight", None)
82
+ super().load_state_dict(state_dict, **kwargs)
83
+
84
+
85
+ def _get_pretrained_settings(encoder):
86
+ pretrained_settings = {
87
+ "imagenet": {
88
+ "mean": [0.485, 0.456, 0.406],
89
+ "std": [0.229, 0.224, 0.225],
90
+ "url": url_map[encoder],
91
+ "input_space": "RGB",
92
+ "input_range": [0, 1],
93
+ },
94
+ "advprop": {
95
+ "mean": [0.5, 0.5, 0.5],
96
+ "std": [0.5, 0.5, 0.5],
97
+ "url": url_map_advprop[encoder],
98
+ "input_space": "RGB",
99
+ "input_range": [0, 1],
100
+ }
101
+ }
102
+ return pretrained_settings
103
+
104
+
105
+ efficient_net_encoders = {
106
+ "efficientnet-b0": {
107
+ "encoder": EfficientNetEncoder,
108
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b0"),
109
+ "params": {
110
+ "out_channels": (3, 32, 24, 40, 112, 320),
111
+ "stage_idxs": (3, 5, 9, 16),
112
+ "model_name": "efficientnet-b0",
113
+ },
114
+ },
115
+ "efficientnet-b1": {
116
+ "encoder": EfficientNetEncoder,
117
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b1"),
118
+ "params": {
119
+ "out_channels": (3, 32, 24, 40, 112, 320),
120
+ "stage_idxs": (5, 8, 16, 23),
121
+ "model_name": "efficientnet-b1",
122
+ },
123
+ },
124
+ "efficientnet-b2": {
125
+ "encoder": EfficientNetEncoder,
126
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b2"),
127
+ "params": {
128
+ "out_channels": (3, 32, 24, 48, 120, 352),
129
+ "stage_idxs": (5, 8, 16, 23),
130
+ "model_name": "efficientnet-b2",
131
+ },
132
+ },
133
+ "efficientnet-b3": {
134
+ "encoder": EfficientNetEncoder,
135
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b3"),
136
+ "params": {
137
+ "out_channels": (3, 40, 32, 48, 136, 384),
138
+ "stage_idxs": (5, 8, 18, 26),
139
+ "model_name": "efficientnet-b3",
140
+ },
141
+ },
142
+ "efficientnet-b4": {
143
+ "encoder": EfficientNetEncoder,
144
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b4"),
145
+ "params": {
146
+ "out_channels": (3, 48, 32, 56, 160, 448),
147
+ "stage_idxs": (6, 10, 22, 32),
148
+ "model_name": "efficientnet-b4",
149
+ },
150
+ },
151
+ "efficientnet-b5": {
152
+ "encoder": EfficientNetEncoder,
153
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b5"),
154
+ "params": {
155
+ "out_channels": (3, 48, 40, 64, 176, 512),
156
+ "stage_idxs": (8, 13, 27, 39),
157
+ "model_name": "efficientnet-b5",
158
+ },
159
+ },
160
+ "efficientnet-b6": {
161
+ "encoder": EfficientNetEncoder,
162
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b6"),
163
+ "params": {
164
+ "out_channels": (3, 56, 40, 72, 200, 576),
165
+ "stage_idxs": (9, 15, 31, 45),
166
+ "model_name": "efficientnet-b6",
167
+ },
168
+ },
169
+ "efficientnet-b7": {
170
+ "encoder": EfficientNetEncoder,
171
+ "pretrained_settings": _get_pretrained_settings("efficientnet-b7"),
172
+ "params": {
173
+ "out_channels": (3, 64, 48, 80, 224, 640),
174
+ "stage_idxs": (11, 18, 38, 55),
175
+ "model_name": "efficientnet-b7",
176
+ },
177
+ },
178
+ }