Commit 69b3540 · josedolot committed · 1 Parent(s): c731cd7

Upload encoders/dpn.py

Files changed (1)
  1. encoders/dpn.py +170 -0
encoders/dpn.py ADDED
@@ -0,0 +1,170 @@
+ """Each encoder should have the following attributes and methods, and should inherit from `_base.EncoderMixin`.
+
+ Attributes:
+
+     _out_channels (list of int): specifies the number of channels of each encoder feature tensor
+     _depth (int): specifies the number of stages in the decoder (in other words, the number of downsampling operations)
+     _in_channels (int): default number of input channels of the first Conv2d layer of the encoder (usually 3)
+
+ Methods:
+
+     forward(self, x: torch.Tensor)
+         Produces a list of features of different spatial resolutions; each feature is a 4D torch.Tensor of
+         shape NCHW (features should be sorted in descending order of spatial resolution, starting
+         with the same resolution as the input `x` tensor).
+
+         Input: `x` with shape (1, 3, 64, 64)
+         Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
+                 [(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
+                 (1, 512, 4, 4), (1, 1024, 2, 2)] (the C dim may differ)
+
+     The number of returned features should also match the specified depth, e.g. if depth = 5,
+     the number of feature tensors = 6 (one with the same resolution as the input and 5 downsampled);
+     depth = 3 -> number of feature tensors = 4 (one with the same resolution as the input and 3 downsampled).
+ """
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from pretrainedmodels.models.dpn import DPN
+ from pretrainedmodels.models.dpn import pretrained_settings
+
+ from ._base import EncoderMixin
+
+
+ class DPNEncoder(DPN, EncoderMixin):
+     def __init__(self, stage_idxs, out_channels, depth=5, **kwargs):
+         super().__init__(**kwargs)
+         self._stage_idxs = stage_idxs
+         self._depth = depth
+         self._out_channels = out_channels
+         self._in_channels = 3
+
+         # the classification head is not needed for feature extraction
+         del self.last_linear
+
+     def get_stages(self):
+         # split the DPN `features` sequence into stages, one per output resolution
+         return [
+             nn.Identity(),
+             nn.Sequential(self.features[0].conv, self.features[0].bn, self.features[0].act),
+             nn.Sequential(self.features[0].pool, self.features[1 : self._stage_idxs[0]]),
+             self.features[self._stage_idxs[0] : self._stage_idxs[1]],
+             self.features[self._stage_idxs[1] : self._stage_idxs[2]],
+             self.features[self._stage_idxs[2] : self._stage_idxs[3]],
+         ]
+
+     def forward(self, x):
+         stages = self.get_stages()
+
+         features = []
+         for i in range(self._depth + 1):
+             x = stages[i](x)
+             # DPN blocks return a (residual, dense) pair of tensors;
+             # concatenate them into a single feature map before collecting
+             if isinstance(x, (list, tuple)):
+                 features.append(F.relu(torch.cat(x, dim=1), inplace=True))
+             else:
+                 features.append(x)
+
+         return features
+
+     def load_state_dict(self, state_dict, **kwargs):
+         # drop the classifier weights, since `last_linear` was removed in __init__
+         state_dict.pop("last_linear.bias", None)
+         state_dict.pop("last_linear.weight", None)
+         super().load_state_dict(state_dict, **kwargs)
+
+
+ dpn_encoders = {
+     "dpn68": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn68"],
+         "params": {
+             "stage_idxs": (4, 8, 20, 24),
+             "out_channels": (3, 10, 144, 320, 704, 832),
+             "groups": 32,
+             "inc_sec": (16, 32, 32, 64),
+             "k_r": 128,
+             "k_sec": (3, 4, 12, 3),
+             "num_classes": 1000,
+             "num_init_features": 10,
+             "small": True,
+             "test_time_pool": True,
+         },
+     },
+     "dpn68b": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn68b"],
+         "params": {
+             "stage_idxs": (4, 8, 20, 24),
+             "out_channels": (3, 10, 144, 320, 704, 832),
+             "b": True,
+             "groups": 32,
+             "inc_sec": (16, 32, 32, 64),
+             "k_r": 128,
+             "k_sec": (3, 4, 12, 3),
+             "num_classes": 1000,
+             "num_init_features": 10,
+             "small": True,
+             "test_time_pool": True,
+         },
+     },
+     "dpn92": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn92"],
+         "params": {
+             "stage_idxs": (4, 8, 28, 32),
+             "out_channels": (3, 64, 336, 704, 1552, 2688),
+             "groups": 32,
+             "inc_sec": (16, 32, 24, 128),
+             "k_r": 96,
+             "k_sec": (3, 4, 20, 3),
+             "num_classes": 1000,
+             "num_init_features": 64,
+             "test_time_pool": True,
+         },
+     },
+     "dpn98": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn98"],
+         "params": {
+             "stage_idxs": (4, 10, 30, 34),
+             "out_channels": (3, 96, 336, 768, 1728, 2688),
+             "groups": 40,
+             "inc_sec": (16, 32, 32, 128),
+             "k_r": 160,
+             "k_sec": (3, 6, 20, 3),
+             "num_classes": 1000,
+             "num_init_features": 96,
+             "test_time_pool": True,
+         },
+     },
+     "dpn107": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn107"],
+         "params": {
+             "stage_idxs": (5, 13, 33, 37),
+             "out_channels": (3, 128, 376, 1152, 2432, 2688),
+             "groups": 50,
+             "inc_sec": (20, 64, 64, 128),
+             "k_r": 200,
+             "k_sec": (4, 8, 20, 3),
+             "num_classes": 1000,
+             "num_init_features": 128,
+             "test_time_pool": True,
+         },
+     },
+     "dpn131": {
+         "encoder": DPNEncoder,
+         "pretrained_settings": pretrained_settings["dpn131"],
+         "params": {
+             "stage_idxs": (5, 13, 41, 45),
+             "out_channels": (3, 128, 352, 832, 1984, 2688),
+             "groups": 40,
+             "inc_sec": (16, 32, 32, 128),
+             "k_r": 160,
+             "k_sec": (4, 8, 28, 3),
+             "num_classes": 1000,
+             "num_init_features": 128,
+             "test_time_pool": True,
+         },
+     },
+ }
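
For context, here is a minimal usage sketch (not part of the committed file). It assumes this module is importable as `encoders.dpn` inside a package that also provides `_base.EncoderMixin`, and that the `pretrainedmodels` package is installed; the import path is illustrative, weights are randomly initialized, and the optional pretrained-weight loading assumes the URL layout used by `pretrainedmodels`.

import torch

from encoders.dpn import dpn_encoders  # hypothetical import path for this repository

cfg = dpn_encoders["dpn68"]
encoder = cfg["encoder"](depth=5, **cfg["params"])  # DPNEncoder with randomly initialized weights
encoder.eval()

# Optionally load ImageNet weights (assumption: same settings layout as `pretrainedmodels`):
# state_dict = torch.hub.load_state_dict_from_url(cfg["pretrained_settings"]["imagenet"]["url"])
# encoder.load_state_dict(state_dict)

x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    features = encoder(x)

# Expect 6 feature maps (identity stage + 5 downsampled), with channel counts
# matching cfg["params"]["out_channels"] = (3, 10, 144, 320, 704, 832)
for f in features:
    print(tuple(f.shape))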