M3L / models /get_model.py
harshm121's picture
Working demo
d4ebf73
raw
history blame
7.11 kB
from models.segmentation_models.cen import ChannelExchangingNetwork
from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask
from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency
from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency
from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency
from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix
from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency
from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion
from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP
from models.segmentation_models.refinenet import MyRefineNet
from models.segmentation_models.segformer.segformer import SegFormer
from models.segmentation_models.tokenfusion.segformer import WeTr
from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch
from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork
from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop
def _resnet_depth(base_model):
    """Return the ResNet depth encoded by a ``base_model`` tag ("r18", "r50" or "r101").

    Raises:
        ValueError: if ``base_model`` is not one of the supported ResNet tags.
    """
    depths = {"r18": 18, "r50": 50, "r101": 101}
    if base_model not in depths:
        raise ValueError(f"{base_model} not configured")
    return depths[base_model]


def _pretrained_flag(args):
    """Return ``args.pretrained_init`` when it is set (default True) and print the choice."""
    pretrained = getattr(args, "pretrained_init", True)
    print("Using pretrained SegFormer? ", pretrained)
    return pretrained


def get_model(args, **kwargs):
    """Build the segmentation model selected by ``args.seg_model``.

    Args:
        args: config object read via attribute access. Every branch reads
            ``seg_model`` and, except for the final ``else``, ``num_classes``;
            most branches also pass ``base_model`` and ``args`` itself to the
            model constructor. Some branches additionally read ``l1_lambda``,
            ``exchange_threshold``, ``cons_lambda``, ``exchange_percent`` or
            the optional ``pretrained_init`` (defaults to True when absent).
        **kwargs: forwarded only to ``TokenFusionBothMask``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``seg_model`` is unknown, or if a ResNet-based model
            ("dlv3p", "refinenet", "cen") is requested with an unsupported
            ``base_model``.
    """
    seg_model = args.seg_model

    if seg_model == "dlv3p":
        if args.base_model == "r18":
            return DeepLabV3p_r18(args.num_classes, args)
        elif args.base_model == "r50":
            return DeepLabV3p_r50(args.num_classes, args)
        elif args.base_model == "r101":
            return DeepLabV3p_r101(args.num_classes, args)
        raise ValueError(f"{args.base_model} not configured")
    elif seg_model == "refinenet":
        # Validate before constructing so an unsupported backbone raises
        # instead of silently returning None.
        depth = _resnet_depth(args.base_model)
        return MyRefineNet(num_layers=depth, num_classes=args.num_classes)
    elif seg_model == "cen":
        depth = _resnet_depth(args.base_model)
        return ChannelExchangingNetwork(
            num_layers=depth,
            num_classes=args.num_classes,
            num_parallel=2,
            l1_lambda=args.l1_lambda,
            bn_threshold=args.exchange_threshold,
        )
    elif seg_model == "segformer":
        return SegFormer(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "tokenfusion":
        return WeTr(args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes)
    elif seg_model == "randomfusion":
        return WeTrRandomFusion(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "randomfusiondmlp":
        return WeTrRandomFusionDMLP(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "randomexchangepredconsistency":
        return RandomExchangePredConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        )
    elif seg_model == "linearfusion":
        pretrained = _pretrained_flag(args)
        return WeTrLinearFusion(args.base_model, args, num_classes=args.num_classes, pretrained=pretrained)
    elif seg_model == "linearfusionconsistency":
        return LinearFusionConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        )
    elif seg_model == "linearfusionmaskedcons":
        pretrained = _pretrained_flag(args)
        return LinearFusionMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes, pretrained=pretrained
        )
    elif seg_model == "linearfusionmaskedconsmixbatch":
        return LinearFusionMaskedConsistencyMixBatch(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "linearfusionsepdecodermaskedcons":
        return LinearFusionSepDecoderMaskedConsistency(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "linearfusionmaemaskedcons":
        return LinearFusionMAEMaskedConsistency(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "tokenfusionmaskedcons":
        return TokenFusionMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )
    elif seg_model == "tokenfusionmaskedconsmixbatch":
        return TokenFusionMaskedConsistencyMixBatch(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )
    elif seg_model == "tokenfusionbothmask":
        return TokenFusionBothMask(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes, **kwargs
        )
    elif seg_model == "linearfusebothmask":
        return LinearFusionBothMask(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "linearfusiontokenmix":
        return LinearFusionTokenMix(
            args.base_model, args, num_classes=args.num_classes, exchange_percent=args.exchange_percent
        )
    elif seg_model == "tokenfusionmaemaskedcons":
        return TokenFusionMAEMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        )
    elif seg_model == "unifiedrepresentationnetwork":
        return UnifiedRepresentationNetwork(args.base_model, args, num_classes=args.num_classes)
    elif seg_model == "unifiedrepresentationnetworkmoddrop":
        return UnifiedRepresentationNetworkModDrop(args.base_model, args, num_classes=args.num_classes)
    else:
        raise ValueError(f"{args.seg_model} not configured")