josedolot commited on
Commit
2537604
·
1 Parent(s): 3e89704

Upload encoders/timm_universal.py

Browse files
Files changed (1) hide show
  1. encoders/timm_universal.py +34 -0
encoders/timm_universal.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import timm
2
+ import torch.nn as nn
3
+
4
+
5
+ class TimmUniversalEncoder(nn.Module):
6
+
7
+ def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
8
+ super().__init__()
9
+ kwargs = dict(
10
+ in_chans=in_channels,
11
+ features_only=True,
12
+ output_stride=output_stride,
13
+ pretrained=pretrained,
14
+ out_indices=tuple(range(depth)),
15
+ )
16
+
17
+ # not all models support output stride argument, drop it by default
18
+ if output_stride == 32:
19
+ kwargs.pop("output_stride")
20
+
21
+ self.model = timm.create_model(name, **kwargs)
22
+
23
+ self._in_channels = in_channels
24
+ self._out_channels = [in_channels, ] + self.model.feature_info.channels()
25
+ self._depth = depth
26
+
27
+ def forward(self, x):
28
+ features = self.model(x)
29
+ features = [x,] + features
30
+ return features
31
+
32
+ @property
33
+ def out_channels(self):
34
+ return self._out_channels