Spaces:
Runtime error
Runtime error
File size: 965 Bytes
2537604 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import timm
import torch.nn as nn
class TimmUniversalEncoder(nn.Module):
    """Wrap a ``timm`` backbone so it yields a list of multi-level features.

    The backbone is created with ``features_only=True`` and
    ``out_indices=(0, ..., depth - 1)``. :meth:`forward` prepends the raw
    input to the backbone's output list. :attr:`out_channels` holds the
    matching channel counts: ``in_channels`` first, followed by the
    backbone's ``feature_info.channels()``.

    Args:
        name: Model name passed to ``timm.create_model``.
        pretrained: Forwarded to ``timm.create_model``.
        in_channels: Number of input channels. It is forwarded as
            ``in_chans`` and is also the first entry of ``out_channels``.
        depth: Number of feature levels requested via ``out_indices``.
        output_stride: Forwarded to ``timm.create_model`` only when it differs
            from 32, because not all models support the argument.
    """

    def __init__(self, name, pretrained=True, in_channels=3, depth=5, output_stride=32):
        super().__init__()

        create_kwargs = {
            "in_chans": in_channels,
            "features_only": True,
            "pretrained": pretrained,
            "out_indices": tuple(range(depth)),
        }
        # Only pass output_stride for a non-default value: not every timm
        # model accepts this argument.
        if output_stride != 32:
            create_kwargs["output_stride"] = output_stride

        self.model = timm.create_model(name, **create_kwargs)

        self._in_channels = in_channels
        # Entry 0 describes the raw input that forward() prepends. The rest
        # come from the backbone's feature_info.
        backbone_channels = self.model.feature_info.channels()
        self._out_channels = [in_channels] + backbone_channels
        self._depth = depth

    def forward(self, x):
        """Return a new list: ``x`` followed by the backbone outputs for ``x``."""
        return [x] + self.model(x)

    @property
    def out_channels(self):
        """Channel counts aligned position-by-position with :meth:`forward` output."""
        return self._out_channels
|