josedolot commited on
Commit
5b4e116
·
1 Parent(s): f3e1e8a

Upload encoders/__init__.py

Browse files
Files changed (1) hide show
  1. encoders/__init__.py +105 -0
encoders/__init__.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import torch.utils.model_zoo as model_zoo
3
+
4
+ from .resnet import resnet_encoders
5
+ from .dpn import dpn_encoders
6
+ from .vgg import vgg_encoders
7
+ from .senet import senet_encoders
8
+ from .densenet import densenet_encoders
9
+ from .inceptionresnetv2 import inceptionresnetv2_encoders
10
+ from .inceptionv4 import inceptionv4_encoders
11
+ from .efficientnet import efficient_net_encoders
12
+ from .mobilenet import mobilenet_encoders
13
+ from .xception import xception_encoders
14
+ from .timm_efficientnet import timm_efficientnet_encoders
15
+ from .timm_resnest import timm_resnest_encoders
16
+ from .timm_res2net import timm_res2net_encoders
17
+ from .timm_regnet import timm_regnet_encoders
18
+ from .timm_sknet import timm_sknet_encoders
19
+ from .timm_mobilenetv3 import timm_mobilenetv3_encoders
20
+ from .timm_gernet import timm_gernet_encoders
21
+
22
+ from .timm_universal import TimmUniversalEncoder
23
+
24
+ from ._preprocessing import preprocess_input
25
+
26
+ encoders = {}
27
+ encoders.update(resnet_encoders)
28
+ encoders.update(dpn_encoders)
29
+ encoders.update(vgg_encoders)
30
+ encoders.update(senet_encoders)
31
+ encoders.update(densenet_encoders)
32
+ encoders.update(inceptionresnetv2_encoders)
33
+ encoders.update(inceptionv4_encoders)
34
+ encoders.update(efficient_net_encoders)
35
+ encoders.update(mobilenet_encoders)
36
+ encoders.update(xception_encoders)
37
+ encoders.update(timm_efficientnet_encoders)
38
+ encoders.update(timm_resnest_encoders)
39
+ encoders.update(timm_res2net_encoders)
40
+ encoders.update(timm_regnet_encoders)
41
+ encoders.update(timm_sknet_encoders)
42
+ encoders.update(timm_mobilenetv3_encoders)
43
+ encoders.update(timm_gernet_encoders)
44
+
45
+
46
+ def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
47
+
48
+ if name.startswith("tu-"):
49
+ name = name[3:]
50
+ encoder = TimmUniversalEncoder(
51
+ name=name,
52
+ in_channels=in_channels,
53
+ depth=depth,
54
+ output_stride=output_stride,
55
+ pretrained=weights is not None,
56
+ **kwargs
57
+ )
58
+ return encoder
59
+
60
+ try:
61
+ Encoder = encoders[name]["encoder"]
62
+ except KeyError:
63
+ raise KeyError("Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys())))
64
+
65
+ params = encoders[name]["params"]
66
+ params.update(depth=depth)
67
+ encoder = Encoder(**params)
68
+
69
+ if weights is not None:
70
+ try:
71
+ settings = encoders[name]["pretrained_settings"][weights]
72
+ except KeyError:
73
+ raise KeyError("Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
74
+ weights, name, list(encoders[name]["pretrained_settings"].keys()),
75
+ ))
76
+ encoder.load_state_dict(model_zoo.load_url(settings["url"]))
77
+
78
+ encoder.set_in_channels(in_channels, pretrained=weights is not None)
79
+ if output_stride != 32:
80
+ encoder.make_dilated(output_stride)
81
+
82
+ return encoder
83
+
84
+
85
+ def get_encoder_names():
86
+ return list(encoders.keys())
87
+
88
+
89
+ def get_preprocessing_params(encoder_name, pretrained="imagenet"):
90
+ settings = encoders[encoder_name]["pretrained_settings"]
91
+
92
+ if pretrained not in settings.keys():
93
+ raise ValueError("Available pretrained options {}".format(settings.keys()))
94
+
95
+ formatted_settings = {}
96
+ formatted_settings["input_space"] = settings[pretrained].get("input_space")
97
+ formatted_settings["input_range"] = settings[pretrained].get("input_range")
98
+ formatted_settings["mean"] = settings[pretrained].get("mean")
99
+ formatted_settings["std"] = settings[pretrained].get("std")
100
+ return formatted_settings
101
+
102
+
103
+ def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
104
+ params = get_preprocessing_params(encoder_name, pretrained=pretrained)
105
+ return functools.partial(preprocess_input, **params)