File size: 7,106 Bytes
d4ebf73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
from models.segmentation_models.cen import ChannelExchangingNetwork
from models.segmentation_models.deeplabv3p import DeepLabV3p_r101, DeepLabV3p_r18, DeepLabV3p_r50
from models.segmentation_models.linearfuse.segformer import WeTrLinearFusion
from models.segmentation_models.linearfusebothmask.segformer import LinearFusionBothMask
from models.segmentation_models.linearfusecons.segformer import LinearFusionConsistency
from models.segmentation_models.linearfusemaemaskedcons.segformer import LinearFusionMAEMaskedConsistency
from models.segmentation_models.linearfusemaskedcons.segformer import LinearFusionMaskedConsistency
from models.segmentation_models.linearfusemaskedconsmixbatch.segformer import LinearFusionMaskedConsistencyMixBatch
from models.segmentation_models.linearfusesepdecodermaskedcons.segformer import LinearFusionSepDecoderMaskedConsistency
from models.segmentation_models.linearfusetokenmix.segformer import LinearFusionTokenMix
from models.segmentation_models.randomexchangecons.segformer import RandomExchangePredConsistency
from models.segmentation_models.randomfusion.segformer import WeTrRandomFusion
from models.segmentation_models.randomfusiondmlp.segformer import WeTrRandomFusionDMLP
from models.segmentation_models.refinenet import MyRefineNet
from models.segmentation_models.segformer.segformer import SegFormer
from models.segmentation_models.tokenfusion.segformer import WeTr
from models.segmentation_models.tokenfusionbothmask.segformer import TokenFusionBothMask
from models.segmentation_models.tokenfusionmaemaskedconsistency.segformer import TokenFusionMAEMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistency.segformer import TokenFusionMaskedConsistency
from models.segmentation_models.tokenfusionmaskedconsistencymixbatch.segformer import TokenFusionMaskedConsistencyMixBatch
from models.segmentation_models.unifiedrepresentation.segformer import UnifiedRepresentationNetwork
from models.segmentation_models.unifiedrepresentationmoddrop.segformer import UnifiedRepresentationNetworkModDrop

def _resolve_pretrained(args):
    """Return the pretrained-init flag for SegFormer-based models.

    Defaults to True; honours ``args.pretrained_init`` when the attribute is
    present. Keeps the original ``in args`` membership test so behaviour is
    identical for argparse.Namespace-style configs, and preserves the
    diagnostic print emitted by the original code.
    """
    pretrained = True
    if "pretrained_init" in args:
        pretrained = args.pretrained_init
    print("Using pretrained SegFormer? ", pretrained)
    return pretrained


# Shared mapping from base-model tag to ResNet depth, used by the
# refinenet and cen families.
_RESNET_DEPTHS = {"r18": 18, "r50": 50, "r101": 101}


def get_model(args, **kwargs):
    """Instantiate the segmentation model selected by ``args.seg_model``.

    Args:
        args: Configuration namespace. Must provide ``seg_model``,
            ``base_model`` and ``num_classes``; individual families also
            read ``l1_lambda``, ``cons_lambda``, ``exchange_threshold``,
            ``exchange_percent`` or ``pretrained_init`` as shown below.
        **kwargs: Extra keyword arguments forwarded to models that accept
            them (currently only the ``tokenfusionbothmask`` family).

    Returns:
        The constructed model instance for the requested architecture.

    Raises:
        Exception: If ``args.seg_model`` is not configured, or — for the
            ``dlv3p``, ``refinenet`` and ``cen`` families — if
            ``args.base_model`` is not one of r18/r50/r101.
            (BUGFIX: the refinenet/cen branches previously returned None
            silently on an unknown base_model.)
    """
    if args.seg_model == "dlv3p":
        dlv3p_variants = {
            "r18": DeepLabV3p_r18,
            "r50": DeepLabV3p_r50,
            "r101": DeepLabV3p_r101,
        }
        if args.base_model not in dlv3p_variants:
            raise Exception(f"{args.base_model} not configured")
        return dlv3p_variants[args.base_model](args.num_classes, args)
    elif args.seg_model == 'refinenet':
        if args.base_model not in _RESNET_DEPTHS:
            raise Exception(f"{args.base_model} not configured")
        return MyRefineNet(num_layers = _RESNET_DEPTHS[args.base_model], num_classes = args.num_classes)
    elif args.seg_model == 'cen':
        if args.base_model not in _RESNET_DEPTHS:
            raise Exception(f"{args.base_model} not configured")
        return ChannelExchangingNetwork(num_layers = _RESNET_DEPTHS[args.base_model], num_classes = args.num_classes,  num_parallel = 2, l1_lambda = args.l1_lambda, bn_threshold = args.exchange_threshold)
    elif args.seg_model == 'segformer':
        return SegFormer(args.base_model, args, num_classes=  args.num_classes)
    elif args.seg_model == 'tokenfusion':
        return WeTr(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes)
    elif args.seg_model == 'randomfusion':
        return WeTrRandomFusion(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == 'randomfusiondmlp':
        return WeTrRandomFusionDMLP(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == 'randomexchangepredconsistency':
        return RandomExchangePredConsistency(args.base_model, args, cons_lambda = args.cons_lambda, num_classes = args.num_classes)
    elif args.seg_model == 'linearfusion':
        return WeTrLinearFusion(args.base_model, args, num_classes = args.num_classes, pretrained=_resolve_pretrained(args))
    elif args.seg_model == 'linearfusionconsistency':
        return LinearFusionConsistency(args.base_model, args, cons_lambda = args.cons_lambda, num_classes = args.num_classes)
    elif args.seg_model == 'linearfusionmaskedcons':
        return LinearFusionMaskedConsistency(args.base_model, args, num_classes = args.num_classes, pretrained=_resolve_pretrained(args))
    elif args.seg_model == 'linearfusionmaskedconsmixbatch':
        return LinearFusionMaskedConsistencyMixBatch(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == 'linearfusionsepdecodermaskedcons':
        return LinearFusionSepDecoderMaskedConsistency(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == 'linearfusionmaemaskedcons':
        return LinearFusionMAEMaskedConsistency(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == 'tokenfusionmaskedcons':
        return TokenFusionMaskedConsistency(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes)
    elif args.seg_model == 'tokenfusionmaskedconsmixbatch':
        return TokenFusionMaskedConsistencyMixBatch(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes)
    elif args.seg_model == 'tokenfusionbothmask':
        # Only family that forwards **kwargs to the constructor.
        return TokenFusionBothMask(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes, **kwargs)
    elif args.seg_model == "linearfusebothmask":
        return LinearFusionBothMask(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == "linearfusiontokenmix":
        return LinearFusionTokenMix(args.base_model, args, num_classes = args.num_classes, exchange_percent = args.exchange_percent)
    elif args.seg_model == "tokenfusionmaemaskedcons":
        return TokenFusionMAEMaskedConsistency(args.base_model, args, l1_lambda = args.l1_lambda, num_classes = args.num_classes)
    elif args.seg_model == "unifiedrepresentationnetwork":
        return UnifiedRepresentationNetwork(args.base_model, args, num_classes = args.num_classes)
    elif args.seg_model == "unifiedrepresentationnetworkmoddrop":
        return UnifiedRepresentationNetworkModDrop(args.base_model, args, num_classes = args.num_classes)
    else:
        raise Exception(f"{args.seg_model} not configured")