Spaces:
Runtime error
Runtime error
import functools | |
import torch.utils.model_zoo as model_zoo | |
from .resnet import resnet_encoders | |
from .dpn import dpn_encoders | |
from .vgg import vgg_encoders | |
from .senet import senet_encoders | |
from .densenet import densenet_encoders | |
from .inceptionresnetv2 import inceptionresnetv2_encoders | |
from .inceptionv4 import inceptionv4_encoders | |
from .efficientnet import efficient_net_encoders | |
from .mobilenet import mobilenet_encoders | |
from .xception import xception_encoders | |
from .timm_efficientnet import timm_efficientnet_encoders | |
from .timm_resnest import timm_resnest_encoders | |
from .timm_res2net import timm_res2net_encoders | |
from .timm_regnet import timm_regnet_encoders | |
from .timm_sknet import timm_sknet_encoders | |
from .timm_mobilenetv3 import timm_mobilenetv3_encoders | |
from .timm_gernet import timm_gernet_encoders | |
from .timm_universal import TimmUniversalEncoder | |
from ._preprocessing import preprocess_input | |
# Single registry mapping an encoder name to its entry; ``get_encoder`` reads
# the ``"encoder"``, ``"params"`` and ``"pretrained_settings"`` keys of each
# entry. Tables listed later win on duplicate names, exactly as successive
# ``dict.update`` calls would.
encoders = {
    **resnet_encoders,
    **dpn_encoders,
    **vgg_encoders,
    **senet_encoders,
    **densenet_encoders,
    **inceptionresnetv2_encoders,
    **inceptionv4_encoders,
    **efficient_net_encoders,
    **mobilenet_encoders,
    **xception_encoders,
    **timm_efficientnet_encoders,
    **timm_resnest_encoders,
    **timm_res2net_encoders,
    **timm_regnet_encoders,
    **timm_sknet_encoders,
    **timm_mobilenetv3_encoders,
    **timm_gernet_encoders,
}
def get_encoder(name, in_channels=3, depth=5, weights=None, output_stride=32, **kwargs):
    """Build an encoder by name, optionally loading pretrained weights.

    Names prefixed with ``"tu-"`` are handed (with the prefix stripped) to
    ``TimmUniversalEncoder``; every other name is looked up in the module-level
    ``encoders`` registry.

    Args:
        name: Registry key of the encoder, or ``"tu-<model>"`` for the timm path.
        in_channels: Number of input channels. Passed to the encoder's
            ``set_in_channels`` (registry path) or to ``TimmUniversalEncoder``.
        depth: Encoder depth; overrides any ``depth`` in the registry params.
        weights: Key into the encoder's ``pretrained_settings`` (e.g.
            ``"imagenet"``), or ``None`` for no pretrained weights. On the
            ``"tu-"`` path only ``None`` vs. not-``None`` matters.
        output_stride: On the registry path, any value other than 32 triggers
            ``encoder.make_dilated(output_stride)``; on the ``"tu-"`` path it is
            forwarded as-is.
        **kwargs: Forwarded to ``TimmUniversalEncoder`` only; ignored for
            registry encoders.

    Returns:
        The constructed encoder.

    Raises:
        KeyError: If ``name`` is not in the registry, or ``weights`` is not one
            of the encoder's pretrained settings.
    """
    if name.startswith("tu-"):
        return TimmUniversalEncoder(
            name=name[3:],
            in_channels=in_channels,
            depth=depth,
            output_stride=output_stride,
            pretrained=weights is not None,
            **kwargs,
        )

    try:
        entry = encoders[name]
    except KeyError:
        # ``from None``: the bare lookup error adds nothing to this message.
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(name, list(encoders.keys()))
        ) from None

    # Validate the requested weights before building the model, so a typo
    # fails fast instead of after constructing the whole network.
    settings = None
    if weights is not None:
        try:
            settings = entry["pretrained_settings"][weights]
        except KeyError:
            raise KeyError(
                "Wrong pretrained weights `{}` for encoder `{}`. Available options are: {}".format(
                    weights, name, list(entry["pretrained_settings"].keys()),
                )
            ) from None

    # Copy the registry params rather than updating them in place, so calling
    # this function never mutates the shared module-level ``encoders`` dict.
    params = {**entry["params"], "depth": depth}
    encoder = entry["encoder"](**params)

    if settings is not None:
        encoder.load_state_dict(model_zoo.load_url(settings["url"]))

    encoder.set_in_channels(in_channels, pretrained=weights is not None)
    if output_stride != 32:
        encoder.make_dilated(output_stride)
    return encoder
def get_encoder_names():
    """Return a new list of every key in the ``encoders`` registry, in insertion order."""
    return [registered_name for registered_name in encoders]
def get_preprocessing_params(encoder_name, pretrained="imagenet"):
    """Return the input-normalization settings of a registered encoder.

    Args:
        encoder_name: Key of the encoder in the ``encoders`` registry.
        pretrained: Key into that encoder's ``pretrained_settings``.

    Returns:
        A new dict with exactly the keys ``"input_space"``, ``"input_range"``,
        ``"mean"`` and ``"std"``; a value is ``None`` when the selected
        settings do not define that key.

    Raises:
        KeyError: If ``encoder_name`` is not in the registry.
        ValueError: If ``pretrained`` is not among the encoder's settings.
    """
    if encoder_name not in encoders:
        # Same message style as ``get_encoder`` instead of a bare KeyError('name').
        raise KeyError(
            "Wrong encoder name `{}`, supported encoders: {}".format(
                encoder_name, list(encoders.keys())
            )
        )
    settings = encoders[encoder_name]["pretrained_settings"]
    if pretrained not in settings:
        # ``list(...)`` so the message shows ['a', 'b'] rather than dict_keys([...]).
        raise ValueError("Available pretrained options {}".format(list(settings.keys())))
    selected = settings[pretrained]
    return {key: selected.get(key) for key in ("input_space", "input_range", "mean", "std")}
def get_preprocessing_fn(encoder_name, pretrained="imagenet"):
    """Return ``preprocess_input`` with this encoder's preprocessing settings pre-bound.

    The bound keyword arguments are exactly the dict produced by
    ``get_preprocessing_params(encoder_name, pretrained=pretrained)``, so any
    error raised there propagates from here unchanged.
    """
    bound_kwargs = get_preprocessing_params(encoder_name, pretrained=pretrained)
    preprocess = functools.partial(preprocess_input, **bound_kwargs)
    return preprocess