|
from models.segmentation_models.cen import ChannelExchangingNetwork |
|
from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50 |
|
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion |
|
from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask |
|
from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency |
|
from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency |
|
from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency |
|
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch |
|
from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency |
|
from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix |
|
from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency |
|
from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion |
|
from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP |
|
from models.segmentation_models.refinenet import MyRefineNet |
|
from models.segmentation_models.segformer.segformer import SegFormer |
|
from models.segmentation_models.tokenfusion.segformer import WeTr |
|
from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask |
|
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency |
|
from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency |
|
from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch |
|
from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork |
|
from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop |
|
|
|
# ResNet depth for each ``base_model`` string accepted by the CNN-based
# models that take a ``num_layers`` argument (``refinenet`` and ``cen``).
_RESNET_DEPTHS = {"r18": 18, "r50": 50, "r101": 101}


def _resnet_depth(base_model):
    """Return the ResNet depth for ``base_model``; raise ValueError if unsupported."""
    try:
        return _RESNET_DEPTHS[base_model]
    except KeyError:
        raise ValueError(f"{base_model} not configured") from None


def _pretrained_init(args):
    """Return ``args.pretrained_init`` if set (default ``True``) and log the choice."""
    pretrained = getattr(args, "pretrained_init", True)
    print("Using pretrained SegFormer? ", pretrained)
    return pretrained


def get_model(args, **kwargs):
    """Build the segmentation network selected by ``args.seg_model``.

    Args:
        args: Configuration object with attribute access (e.g. an argparse
            ``Namespace``). ``seg_model`` is always read; most branches also
            read ``base_model`` and ``num_classes``. Depending on the model,
            ``l1_lambda``, ``exchange_threshold``, ``cons_lambda``,
            ``exchange_percent`` and the optional ``pretrained_init``
            (defaults to ``True``) are read as well.
        **kwargs: Extra keyword arguments, forwarded only to
            ``TokenFusionBothMask``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: If ``seg_model`` is unknown, or if ``base_model`` is not
            one of ``r18``/``r50``/``r101`` for ``dlv3p``, ``refinenet`` or
            ``cen``.
    """
    seg_model = args.seg_model

    if seg_model == "dlv3p":
        if args.base_model == "r18":
            return DeepLabV3p_r18(args.num_classes, args)
        elif args.base_model == "r50":
            return DeepLabV3p_r50(args.num_classes, args)
        elif args.base_model == "r101":
            return DeepLabV3p_r101(args.num_classes, args)
        raise ValueError(f"{args.base_model} not configured")

    elif seg_model == "refinenet":
        # Resolve the depth first so an unsupported backbone raises instead of
        # silently returning None.
        depth = _resnet_depth(args.base_model)
        return MyRefineNet(num_layers=depth, num_classes=args.num_classes)

    elif seg_model == "cen":
        depth = _resnet_depth(args.base_model)
        return ChannelExchangingNetwork(
            num_layers=depth,
            num_classes=args.num_classes,
            num_parallel=2,
            l1_lambda=args.l1_lambda,
            bn_threshold=args.exchange_threshold,
        )

    elif seg_model == "segformer":
        return SegFormer(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "tokenfusion":
        return WeTr(args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes)

    elif seg_model == "randomfusion":
        return WeTrRandomFusion(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "randomfusiondmlp":
        return WeTrRandomFusionDMLP(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "randomexchangepredconsistency":
        return RandomExchangePredConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        )

    elif seg_model == "linearfusion":
        pretrained = _pretrained_init(args)
        return WeTrLinearFusion(args.base_model, args, num_classes=args.num_classes, pretrained=pretrained)

    elif seg_model == "linearfusionconsistency":
        return LinearFusionConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        )

    elif seg_model == "linearfusionmaskedcons":
        pretrained = _pretrained_init(args)
        return LinearFusionMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes, pretrained=pretrained
        )

    elif seg_model == "linearfusionmaskedconsmixbatch":
        return LinearFusionMaskedConsistencyMixBatch(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "linearfusionsepdecodermaskedcons":
        return LinearFusionSepDecoderMaskedConsistency(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "linearfusionmaemaskedcons":
        return LinearFusionMAEMaskedConsistency(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "tokenfusionmaskedcons":
        return TokenFusionMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )

    elif seg_model == "tokenfusionmaskedconsmixbatch":
        return TokenFusionMaskedConsistencyMixBatch(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )

    elif seg_model == "tokenfusionbothmask":
        # Only this model accepts the caller's extra keyword arguments.
        return TokenFusionBothMask(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes, **kwargs
        )

    elif seg_model == "linearfusebothmask":
        return LinearFusionBothMask(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "linearfusiontokenmix":
        return LinearFusionTokenMix(
            args.base_model, args, num_classes=args.num_classes, exchange_percent=args.exchange_percent
        )

    elif seg_model == "tokenfusionmaemaskedcons":
        return TokenFusionMAEMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )

    elif seg_model == "unifiedrepresentationnetwork":
        return UnifiedRepresentationNetwork(args.base_model, args, num_classes=args.num_classes)

    elif seg_model == "unifiedrepresentationnetworkmoddrop":
        return UnifiedRepresentationNetworkModDrop(args.base_model, args, num_classes=args.num_classes)

    raise ValueError(f"{args.seg_model} not configured")