File size: 5,935 Bytes
731f5de
02768a2
1f3a9b6
 
02768a2
 
 
731f5de
02768a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
731f5de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db8dccd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f3a9b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Model definitions"""
import tensorflow as tf
from transformers import TFAutoModel, TFViTModel
from kapre.augmentation import SpecAugment


class FixMatchTune(tf.keras.Model):
    """FixMatch classifier over a pretrained Hugging Face text encoder.

    The encoder's pooled output is perturbed twice with Gaussian noise -- a
    weak view (stddev 0.5) and a strong view (stddev 5) -- and both views are
    scored by the same classification head.

    Args:
        encoder_name: Model id or path passed to
            ``TFAutoModel.from_pretrained``.
        num_classes: Number of output classes (width of the softmax layer).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(
        self,
        encoder_name="readerbench/RoBERT-base",
        num_classes=4,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.bert = TFAutoModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        # Feature-space "augmentations": Keras GaussianNoise only perturbs its
        # input when training is truthy and is the identity at inference.
        # NOTE(review): stddev=5 may swamp the embedding if the encoder's
        # pooler output is tanh-bounded (as in BERT-style poolers) -- confirm
        # this magnitude is intended.
        self.weak_augment = tf.keras.layers.GaussianNoise(stddev=0.5)
        self.strong_augment = tf.keras.layers.GaussianNoise(stddev=5)

        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

    def call(self, inputs, training=None):
        """Return ``(weak_preds, strong_preds)`` softmax probabilities.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to every sub-layer.
                Defaults to ``None`` (Keras convention) so the flag can be
                inferred from the call context and ``call`` can be invoked
                directly without it.
        """
        ids, mask = inputs

        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output

        strongs = self.strong_augment(embeds, training=training)
        weaks = self.weak_augment(embeds, training=training)

        # One shared head scores both views.
        strong_preds = self.cls_head(strongs, training=training)
        weak_preds = self.cls_head(weaks, training=training)

        return weak_preds, strong_preds

class MixMatch(tf.keras.Model):
    """MixMatch classifier over a pretrained Hugging Face text encoder.

    The encoder's pooled output is perturbed with Gaussian noise (stddev 2)
    and then scored by a softmax classification head.

    Args:
        bert_model: Model id or path passed to
            ``TFAutoModel.from_pretrained``.
        num_classes: Number of output classes (width of the softmax layer).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, bert_model="readerbench/RoBERT-base", num_classes=4, **kwargs):
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(bert_model)

        self.num_classes = num_classes

        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

        # Feature-space augmentation; Keras GaussianNoise is the identity
        # unless training is truthy.
        self.augment = tf.keras.layers.GaussianNoise(stddev=2)

    def call(self, inputs, training=None):
        """Return softmax class probabilities for ``inputs``.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to every sub-layer.
                Defaults to ``None`` (Keras convention) so it can be inferred
                from the call context.
        """
        ids, mask = inputs

        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output
        augs = self.augment(embeds, training=training)

        return self.cls_head(augs, training=training)
    
class LPModel(tf.keras.Model):
    """Label-propagation classifier over a pretrained Hugging Face encoder.

    The encoder's pooled output is scored directly by a softmax
    classification head; no augmentation is applied.

    Args:
        bert_model: Model id or path passed to
            ``TFAutoModel.from_pretrained``.
        num_classes: Number of output classes (width of the softmax layer).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, bert_model="readerbench/RoBERT-base", num_classes=4, **kwargs):
        super().__init__(**kwargs)
        self.bert = TFAutoModel.from_pretrained(bert_model)
        self.num_classes = num_classes

        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])

    def call(self, inputs, training=None):
        """Return softmax class probabilities for ``inputs``.

        Args:
            inputs: Pair ``(input_ids, attention_mask)`` fed to the encoder.
            training: Keras training flag, forwarded to every sub-layer.
                Defaults to ``None`` (Keras convention) so it can be inferred
                from the call context.
        """
        ids, mask = inputs

        embeds = self.bert(
            input_ids=ids, attention_mask=mask, training=training
        ).pooler_output

        return self.cls_head(embeds, training=training)
    
class AudioFixMatch(tf.keras.Model):
    """FixMatch classifier over a pretrained ViT applied to spectrogram input.

    Only channel 0 of the input is used. It is augmented twice with kapre
    ``SpecAugment`` (weak: mask params 2, strong: mask params 8; two frequency
    and two time masks each), tiled to three channels, and encoded by the ViT.
    Both pooled embeddings are scored by one shared classification head.

    Args:
        encoder_name: Model id or path passed to
            ``TFViTModel.from_pretrained``.
        num_classes: Number of output classes (width of the softmax layer).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
        super().__init__(**kwargs)
        self.vit = TFViTModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])
        # NOTE(review): presumably SpecAugment is a no-op when training is
        # falsy, making both views identical at inference -- confirm in kapre.
        self.strong_augment = SpecAugment(
            freq_mask_param=8,
            time_mask_param=8,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first"
        )
        self.weak_augment = SpecAugment(
            freq_mask_param=2,
            time_mask_param=2,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first"
        )

    def call(self, inputs, training=None):
        """Return ``(weak_preds, strong_preds)`` softmax probabilities.

        Args:
            inputs: Rank-4, channels-first tensor; only channel 0 is used.
                NOTE(review): presumably (batch, channels, freq, time), with a
                spatial size the ViT's patch/position embeddings accept --
                confirm against the data pipeline.
            training: Keras training flag, forwarded to every sub-layer
                (including the head, so its Dropout follows the same flag even
                when ``call`` is invoked directly).
        """
        # Keep a singleton channel axis: (B, C, H, W) -> (B, 1, H, W).
        first_channel = inputs[:, :1, :, :]

        strong = self.strong_augment(first_channel, training=training)
        weak = self.weak_augment(first_channel, training=training)

        # Tile to three channels for the pretrained ViT (presumably its
        # config.num_channels == 3).
        embeds_strong = self.vit(
            pixel_values=tf.repeat(strong, 3, axis=1), training=training
        ).pooler_output
        embeds_weak = self.vit(
            pixel_values=tf.repeat(weak, 3, axis=1), training=training
        ).pooler_output

        return (
            self.cls_head(embeds_weak, training=training),
            self.cls_head(embeds_strong, training=training),
        )
    
class AudioMixMatch(tf.keras.Model):
    """MixMatch classifier over a pretrained ViT applied to spectrogram input.

    Only channel 0 of the input is used. It is augmented with kapre
    ``SpecAugment`` (mask params 3; two frequency and two time masks), tiled
    to three channels, encoded by the ViT, and scored by a softmax head.

    Args:
        encoder_name: Model id or path passed to
            ``TFViTModel.from_pretrained``.
        num_classes: Number of output classes (width of the softmax layer).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, encoder_name='google/vit-base-patch16-224', num_classes=6, **kwargs):
        super().__init__(**kwargs)
        self.vit = TFViTModel.from_pretrained(encoder_name)
        self.num_classes = num_classes
        self.cls_head = tf.keras.Sequential([
            tf.keras.layers.Dense(256, activation="relu"),
            tf.keras.layers.Dropout(0.2),
            tf.keras.layers.Dense(64, activation="relu"),
            tf.keras.layers.Dense(self.num_classes, activation="softmax"),
        ])
        self.augment = SpecAugment(
            freq_mask_param=3,
            time_mask_param=3,
            n_freq_masks=2,
            n_time_masks=2,
            mask_value=0.0,
            data_format="channels_first"
        )

    def aug_features(self, inputs, training=None):
        """Augment channel 0 of ``inputs`` and return pooled ViT embeddings.

        NOTE(review): presumably called by the MixMatch training step to mix
        embeddings before the head -- confirm against the caller.

        Args:
            inputs: Rank-4, channels-first tensor; only channel 0 is used.
            training: Keras training flag for the augmentation and the ViT.
        """
        # Keep a singleton channel axis: (B, C, H, W) -> (B, 1, H, W).
        first_channel = inputs[:, :1, :, :]
        aug = self.augment(first_channel, training=training)
        # Tile to three channels for the pretrained ViT (presumably its
        # config.num_channels == 3).
        return self.vit(
            pixel_values=tf.repeat(aug, 3, axis=1), training=training
        ).pooler_output

    def call(self, inputs, training=None):
        """Return softmax class probabilities for ``inputs``.

        Reuses ``aug_features`` rather than duplicating its pipeline, and
        forwards ``training`` to the head so its Dropout follows the same flag
        even when ``call`` is invoked directly.
        """
        embeds = self.aug_features(inputs, training=training)
        return self.cls_head(embeds, training=training)