File size: 7,106 Bytes
d4ebf73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 |
from models.segmentation_models.cen import ChannelExchangingNetwork
from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask
from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency
from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency
from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency
from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix
from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency
from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion
from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP
from models.segmentation_models.refinenet import MyRefineNet
from models.segmentation_models.segformer.segformer import SegFormer
from models.segmentation_models.tokenfusion.segformer import WeTr
from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch
from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork
from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop
# Supported ResNet backbone names for the 'dlv3p', 'refinenet' and 'cen' models.
_RESNET_DEPTHS = {"r18": 18, "r50": 50, "r101": 101}


def _resnet_depth(args):
    """Return the ResNet depth (18/50/101) encoded by ``args.base_model``.

    Raises:
        ValueError: if ``args.base_model`` is not 'r18', 'r50' or 'r101'.
    """
    depth = _RESNET_DEPTHS.get(args.base_model)
    if depth is None:
        raise ValueError(f"{args.base_model} not configured")
    return depth


def _pretrained_flag(args):
    """Return ``args.pretrained_init`` if present (default True) and print it."""
    # getattr works for any attribute-style config object, whereas
    # ``"pretrained_init" in args`` requires the object to implement
    # __contains__ (true for argparse.Namespace, not for e.g. SimpleNamespace).
    pretrained = getattr(args, "pretrained_init", True)
    print("Using pretrained SegFormer? ", pretrained)
    return pretrained


def get_model(args, **kwargs):
    """Instantiate the segmentation model selected by ``args.seg_model``.

    Args:
        args: Attribute-style config (e.g. ``argparse.Namespace``). Always read:
            ``seg_model``; most models also read ``base_model`` and
            ``num_classes``. Depending on the model, ``l1_lambda``,
            ``exchange_threshold``, ``cons_lambda``, ``exchange_percent`` and
            the optional ``pretrained_init`` (default True) are read too.
            Most constructors additionally receive ``args`` itself.
        **kwargs: Forwarded only to ``TokenFusionBothMask``.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``args.seg_model`` is unknown, or if 'dlv3p',
            'refinenet' or 'cen' is requested with a ``base_model`` other than
            'r18', 'r50' or 'r101'. ``ValueError`` subclasses ``Exception``,
            so existing ``except Exception`` handlers still catch it.
    """
    seg_model = args.seg_model

    # ResNet-backbone models. The depth is resolved before anything else so an
    # unsupported base_model raises instead of silently returning None (which
    # previously happened for 'refinenet' and 'cen').
    if seg_model == "dlv3p":
        depth = _resnet_depth(args)
        deeplab_cls = {18: DeepLabV3p_r18, 50: DeepLabV3p_r50, 101: DeepLabV3p_r101}[depth]
        return deeplab_cls(args.num_classes, args)
    if seg_model == "refinenet":
        depth = _resnet_depth(args)
        return MyRefineNet(num_layers=depth, num_classes=args.num_classes)
    if seg_model == "cen":
        depth = _resnet_depth(args)
        return ChannelExchangingNetwork(
            num_layers=depth,
            num_classes=args.num_classes,
            num_parallel=2,
            l1_lambda=args.l1_lambda,
            bn_threshold=args.exchange_threshold,
        )

    # Remaining models, all constructed as (base_model, args, ...). Lambdas
    # defer construction (and attribute reads) so only the selected model is
    # built and only the attributes it needs must exist on ``args``.
    builders = {
        "segformer": lambda: SegFormer(
            args.base_model, args, num_classes=args.num_classes
        ),
        "tokenfusion": lambda: WeTr(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "randomfusion": lambda: WeTrRandomFusion(
            args.base_model, args, num_classes=args.num_classes
        ),
        "randomfusiondmlp": lambda: WeTrRandomFusionDMLP(
            args.base_model, args, num_classes=args.num_classes
        ),
        "randomexchangepredconsistency": lambda: RandomExchangePredConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        ),
        "linearfusion": lambda: WeTrLinearFusion(
            args.base_model, args, num_classes=args.num_classes,
            pretrained=_pretrained_flag(args),
        ),
        "linearfusionconsistency": lambda: LinearFusionConsistency(
            args.base_model, args, cons_lambda=args.cons_lambda, num_classes=args.num_classes
        ),
        "linearfusionmaskedcons": lambda: LinearFusionMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes,
            pretrained=_pretrained_flag(args),
        ),
        "linearfusionmaskedconsmixbatch": lambda: LinearFusionMaskedConsistencyMixBatch(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusionsepdecodermaskedcons": lambda: LinearFusionSepDecoderMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusionmaemaskedcons": lambda: LinearFusionMAEMaskedConsistency(
            args.base_model, args, num_classes=args.num_classes
        ),
        "tokenfusionmaskedcons": lambda: TokenFusionMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "tokenfusionmaskedconsmixbatch": lambda: TokenFusionMaskedConsistencyMixBatch(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "tokenfusionbothmask": lambda: TokenFusionBothMask(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes,
            **kwargs,
        ),
        "linearfusebothmask": lambda: LinearFusionBothMask(
            args.base_model, args, num_classes=args.num_classes
        ),
        "linearfusiontokenmix": lambda: LinearFusionTokenMix(
            args.base_model, args, num_classes=args.num_classes,
            exchange_percent=args.exchange_percent,
        ),
        "tokenfusionmaemaskedcons": lambda: TokenFusionMAEMaskedConsistency(
            args.base_model, args, l1_lambda=args.l1_lambda, num_classes=args.num_classes
        ),
        "unifiedrepresentationnetwork": lambda: UnifiedRepresentationNetwork(
            args.base_model, args, num_classes=args.num_classes
        ),
        "unifiedrepresentationnetworkmoddrop": lambda: UnifiedRepresentationNetworkModDrop(
            args.base_model, args, num_classes=args.num_classes
        ),
    }
    builder = builders.get(seg_model)
    if builder is None:
        raise ValueError(f"{seg_model} not configured")
    return builder()