kushal1506 committed on
Commit
f526c1c
·
verified ·
1 Parent(s): 0b1245a

Upload 7 files

Browse files
Files changed (7)
  1. .gitattributes +35 -35
  2. README.md +13 -13
  3. app.py +59 -0
  4. final_model.pth +3 -0
  5. model.py +597 -0
  6. requirements.txt +5 -0
  7. xlsr2_300m.pt +3 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text

+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: Audio Deep Fake
- emoji: 👁
- colorFrom: yellow
- colorTo: purple
- sdk: gradio
- sdk_version: 4.29.0
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

+ ---
+ title: AudioDeepFake
+ emoji: 📊
+ colorFrom: purple
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 4.28.3
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,59 @@
+ import gradio as gr
+ import librosa
+ import numpy as np
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+ from model import Model
+
+ model_path = 'final_model.pth'
+
+ def pad(x, max_len=64600):
+     # Repeat-pad short clips up to max_len samples; truncate longer ones.
+     x_len = x.shape[0]
+     if x_len >= max_len:
+         return x[:max_len]
+     num_repeats = int(max_len / x_len) + 1
+     padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
+     return padded_x
+
+ def load_data(path):
+     # Resample to 16 kHz on load; the XLS-R front end expects 16 kHz input.
+     X, fs = librosa.load(path, sr=16000)
+     X_pad = pad(X, 64600)
+     x_inp = Tensor(X_pad).unsqueeze(0)
+     return x_inp, fs
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = Model(None, device)
+ nb_params = sum(param.numel() for param in model.parameters())
+ print("Number of parameters: {}".format(nb_params))
+ # The checkpoint was saved from an nn.DataParallel wrapper, so wrap before loading.
+ model = nn.DataParallel(model).to(device)
+ model.load_state_dict(torch.load(model_path, map_location=device))
+ print("Model loaded: {}".format(model_path))
+
+ model.eval()
+ prediction_dict = {0: 'Fake', 1: 'Real'}
+
+ def Detection(audio_1):
+     x_inp, fs = load_data(audio_1)
+     print(x_inp.shape)
+     with torch.no_grad():
+         validity_probs = model(x_inp)
+         validity_probs = torch.nn.functional.softmax(validity_probs, dim=1)
+
+     pred_class = torch.argmax(validity_probs).item()
+     validity = prediction_dict[pred_class]
+     # Alternatively, return validity as a dictionary of class probabilities:
+     # validity = {prediction_dict[i]: float(validity_probs[0][i]) for i in range(2)}
+     return validity
+
+ audio_1 = gr.Audio(type="filepath", label="Audio 1")
+ text_output = gr.Textbox(label="Prediction")
+ gr.Interface(
+     fn=Detection,
+     inputs=audio_1,
+     outputs=text_output,
+     title="Audio Deepfake Detection",
+     description="Audio deepfake detection using a model fine-tuned on the for-2seconds dataset.",
+ ).launch()
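A note on the fixed input length: pad() tiles short clips out to 64600 samples (roughly 4 s at 16 kHz) instead of zero-padding, and truncates anything longer. A minimal standalone sketch of that behaviour with toy lengths (no audio files needed):

```python
import numpy as np

def pad(x, max_len=64600):
    # Tile the clip until it covers max_len samples, then truncate.
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]
    num_repeats = int(max_len / x_len) + 1
    return np.tile(x, (1, num_repeats))[:, :max_len][0]

clip = np.arange(5, dtype=np.float32)   # a 5-sample toy "clip"
print(pad(clip, max_len=12))            # [0. 1. 2. 3. 4. 0. 1. 2. 3. 4. 0. 1.]
```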
final_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9b8cbd6c9edd278e22605a9cb58c212405a1364eaa1e145e54aeea1ab06d9ca2
+ size 1271630150
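Since final_model.pth (~1.27 GB, stored as a Git LFS pointer) was saved from the nn.DataParallel wrapper used in app.py, its state-dict keys presumably carry a "module." prefix. A hedged sketch for loading it into an unwrapped Model instead (assumes model.py, xlsr2_300m.pt, and the checkpoint are all on disk; Python >= 3.9 for removeprefix):

```python
import torch
from model import Model

# Assumption: keys are "module."-prefixed, as implied by the
# DataParallel-wrapped load in app.py.
state = torch.load("final_model.pth", map_location="cpu")
state = {k.removeprefix("module."): v for k, v in state.items()}

model = Model(None, "cpu")   # constructor also loads xlsr2_300m.pt via fairseq
model.load_state_dict(state)
model.eval()
```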
model.py ADDED
@@ -0,0 +1,597 @@
+ import random
+ from typing import Union
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch import Tensor
+ import fairseq
+
+
+ __author__ = "Hemlata Tak"
+ __email__ = "[email protected]"
+
+ ############################
+ ## FOR fine-tuned SSL MODEL
+ ############################
+
+
+ class SSLModel(nn.Module):
+     def __init__(self, device):
+         super(SSLModel, self).__init__()
+         cp_path = 'xlsr2_300m.pt'  # Change the pre-trained XLSR model path.
+         model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path])
+         self.model = model[0]
+         self.device = device
+         self.out_dim = 1024
+
+     def extract_feat(self, input_data):
+         # Move the SSL model to the input's device/dtype if it is not there yet.
+         if next(self.model.parameters()).device != input_data.device \
+            or next(self.model.parameters()).dtype != input_data.dtype:
+             self.model.to(input_data.device, dtype=input_data.dtype)
+             self.model.train()
+
+         # Input should be in shape (batch, length); drop a trailing channel dim.
+         if input_data.ndim == 3:
+             input_tmp = input_data[:, :, 0]
+         else:
+             input_tmp = input_data
+
+         # [batch, length, dim]
+         emb = self.model(input_tmp, mask=False, features_only=True)['x']
+         return emb
+
+
+ #---------AASIST back-end------------------------#
+ ''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans.
+ AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks.
+ In Proc. ICASSP 2022, pp: 6367--6371.'''
+
+
+ class GraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, **kwargs):
+         super().__init__()
+
+         # attention map
+         self.att_proj = nn.Linear(in_dim, out_dim)
+         self.att_weight = self._init_new_params(out_dim, 1)
+
+         # project
+         self.proj_with_att = nn.Linear(in_dim, out_dim)
+         self.proj_without_att = nn.Linear(in_dim, out_dim)
+
+         # batch norm
+         self.bn = nn.BatchNorm1d(out_dim)
+
+         # dropout for inputs
+         self.input_drop = nn.Dropout(p=0.2)
+
+         # activation
+         self.act = nn.SELU(inplace=True)
+
+         # temperature
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         '''
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         x = self.act(x)
+         return x
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+         att_map = torch.matmul(att_map, self.att_weight)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class HtrgGraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, **kwargs):
+         super().__init__()
+
+         self.proj_type1 = nn.Linear(in_dim, in_dim)
+         self.proj_type2 = nn.Linear(in_dim, in_dim)
+
+         # attention map
+         self.att_proj = nn.Linear(in_dim, out_dim)
+         self.att_projM = nn.Linear(in_dim, out_dim)
+
+         self.att_weight11 = self._init_new_params(out_dim, 1)
+         self.att_weight22 = self._init_new_params(out_dim, 1)
+         self.att_weight12 = self._init_new_params(out_dim, 1)
+         self.att_weightM = self._init_new_params(out_dim, 1)
+
+         # project
+         self.proj_with_att = nn.Linear(in_dim, out_dim)
+         self.proj_without_att = nn.Linear(in_dim, out_dim)
+
+         self.proj_with_attM = nn.Linear(in_dim, out_dim)
+         self.proj_without_attM = nn.Linear(in_dim, out_dim)
+
+         # batch norm
+         self.bn = nn.BatchNorm1d(out_dim)
+
+         # dropout for inputs
+         self.input_drop = nn.Dropout(p=0.2)
+
+         # activation
+         self.act = nn.SELU(inplace=True)
+
+         # temperature
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x1, x2, master=None):
+         '''
+         x1 :(#bs, #node, #dim)
+         x2 :(#bs, #node, #dim)
+         '''
+         num_type1 = x1.size(1)
+         num_type2 = x2.size(1)
+         x1 = self.proj_type1(x1)
+         x2 = self.proj_type2(x2)
+         x = torch.cat([x1, x2], dim=1)
+
+         if master is None:
+             master = torch.mean(x, dim=1, keepdim=True)
+
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x, num_type1, num_type2)
+
+         # directional edge for master node
+         master = self._update_master(x, master)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         x = self.act(x)
+
+         x1 = x.narrow(1, 0, num_type1)
+         x2 = x.narrow(1, num_type1, num_type2)
+         return x1, x2, master
+
+     def _update_master(self, x, master):
+         att_map = self._derive_att_map_master(x, master)
+         master = self._project_master(x, master, att_map)
+
+         return master
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map_master(self, x, master):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = x * master
+         att_map = torch.tanh(self.att_projM(att_map))
+
+         att_map = torch.matmul(att_map, self.att_weightM)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _derive_att_map(self, x, num_type1, num_type2):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+
+         att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
+
+         # separate attention weights for type1-type1, type2-type2 and cross edges
+         att_board[:, :num_type1, :num_type1, :] = torch.matmul(
+             att_map[:, :num_type1, :num_type1, :], self.att_weight11)
+         att_board[:, num_type1:, num_type1:, :] = torch.matmul(
+             att_map[:, num_type1:, num_type1:, :], self.att_weight22)
+         att_board[:, :num_type1, num_type1:, :] = torch.matmul(
+             att_map[:, :num_type1, num_type1:, :], self.att_weight12)
+         att_board[:, num_type1:, :num_type1, :] = torch.matmul(
+             att_map[:, num_type1:, :num_type1, :], self.att_weight12)
+
+         att_map = att_board
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _project_master(self, x, master, att_map):
+         x1 = self.proj_with_attM(torch.matmul(
+             att_map.squeeze(-1).unsqueeze(1), x))
+         x2 = self.proj_without_attM(master)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class GraphPool(nn.Module):
+     def __init__(self, k: float, in_dim: int, p: Union[float, int]):
+         super().__init__()
+         self.k = k
+         self.sigmoid = nn.Sigmoid()
+         self.proj = nn.Linear(in_dim, 1)
+         self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
+         self.in_dim = in_dim
+
+     def forward(self, h):
+         Z = self.drop(h)
+         weights = self.proj(Z)
+         scores = self.sigmoid(weights)
+         new_h = self.top_k_graph(scores, h, self.k)
+
+         return new_h
+
+     def top_k_graph(self, scores, h, k):
+         """
+         args
+         =====
+         scores: attention-based weights (#bs, #node, 1)
+         h: graph data (#bs, #node, #dim)
+         k: ratio of remaining nodes, (float)
+         returns
+         =====
+         h: graph pool applied data (#bs, #node', #dim)
+         """
+         _, n_nodes, n_feat = h.size()
+         n_nodes = max(int(n_nodes * k), 1)
+         _, idx = torch.topk(scores, n_nodes, dim=1)
+         idx = idx.expand(-1, -1, n_feat)
+
+         h = h * scores
+         h = torch.gather(h, 1, idx)
+
+         return h
+
+
374
+ def __init__(self, nb_filts, first=False):
375
+ super().__init__()
376
+ self.first = first
377
+
378
+ if not self.first:
379
+ self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
380
+ self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
381
+ out_channels=nb_filts[1],
382
+ kernel_size=(2, 3),
383
+ padding=(1, 1),
384
+ stride=1)
385
+ self.selu = nn.SELU(inplace=True)
386
+
387
+ self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
388
+ self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
389
+ out_channels=nb_filts[1],
390
+ kernel_size=(2, 3),
391
+ padding=(0, 1),
392
+ stride=1)
393
+
394
+ if nb_filts[0] != nb_filts[1]:
395
+ self.downsample = True
396
+ self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
397
+ out_channels=nb_filts[1],
398
+ padding=(0, 1),
399
+ kernel_size=(1, 3),
400
+ stride=1)
401
+
402
+ else:
403
+ self.downsample = False
404
+
405
+
406
+ def forward(self, x):
407
+ identity = x
408
+ if not self.first:
409
+ out = self.bn1(x)
410
+ out = self.selu(out)
411
+ else:
412
+ out = x
413
+
414
+ #print('out',out.shape)
415
+ out = self.conv1(x)
416
+
417
+ #print('aft conv1 out',out.shape)
418
+ out = self.bn2(out)
419
+ out = self.selu(out)
420
+ # print('out',out.shape)
421
+ out = self.conv2(out)
422
+ #print('conv2 out',out.shape)
423
+
424
+ if self.downsample:
425
+ identity = self.conv_downsample(identity)
426
+
427
+ out += identity
428
+ #out = self.mp(out)
429
+ return out
430
+
431
+
+ class Model(nn.Module):
+     def __init__(self, args, device):
+         super().__init__()
+         self.device = device
+
+         # AASIST parameters
+         filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]]
+         gat_dims = [64, 32]
+         pool_ratios = [0.5, 0.5, 0.5, 0.5]
+         temperatures = [2.0, 2.0, 100.0, 100.0]
+
+         ####
+         # create network wav2vec 2.0
+         ####
+         self.ssl_model = SSLModel(self.device)
+         self.LL = nn.Linear(self.ssl_model.out_dim, 128)
+
+         self.first_bn = nn.BatchNorm2d(num_features=1)
+         self.first_bn1 = nn.BatchNorm2d(num_features=64)
+         self.drop = nn.Dropout(0.5, inplace=True)
+         self.drop_way = nn.Dropout(0.2, inplace=True)
+         self.selu = nn.SELU(inplace=True)
+
+         # RawNet2 encoder
+         self.encoder = nn.Sequential(
+             nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
+             nn.Sequential(Residual_block(nb_filts=filts[2])),
+             nn.Sequential(Residual_block(nb_filts=filts[3])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])))
+
+         self.attention = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=(1, 1)),
+             nn.SELU(inplace=True),
+             nn.BatchNorm2d(128),
+             nn.Conv2d(128, 64, kernel_size=(1, 1)),
+         )
+         # position encoding
+         self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1]))
+
+         self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+         self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+
+         # Graph module
+         self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[0])
+         self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[1])
+         # HS-GAL layer
+         self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
+             gat_dims[0], gat_dims[1], temperature=temperatures[2])
+         self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
+             gat_dims[1], gat_dims[1], temperature=temperatures[2])
+         self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
+             gat_dims[0], gat_dims[1], temperature=temperatures[2])
+         self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
+             gat_dims[1], gat_dims[1], temperature=temperatures[2])
+
+         # Graph pooling layers
+         self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
+         self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
+         self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+         self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+         self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+         self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
+
+         self.out_layer = nn.Linear(5 * gat_dims[1], 2)
+
+     def forward(self, x):
+         #-------pre-trained Wav2vec model fine-tuning ------------------------##
+         x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1))
+         x = self.LL(x_ssl_feat)  # (bs, frame_number, feat_out_dim)
+
+         # post-processing on front-end features
+         x = x.transpose(1, 2)  # (bs, feat_out_dim, frame_number)
+         x = x.unsqueeze(dim=1)  # add channel
+         x = F.max_pool2d(x, (3, 3))
+         x = self.first_bn(x)
+         x = self.selu(x)
+
+         # RawNet2-based encoder
+         x = self.encoder(x)
+         x = self.first_bn1(x)
+         x = self.selu(x)
+
+         w = self.attention(x)
+
+         #------------SA for spectral feature-------------#
+         w1 = F.softmax(w, dim=-1)
+         m = torch.sum(x * w1, dim=-1)
+         e_S = m.transpose(1, 2) + self.pos_S
+
+         # graph module layer
+         gat_S = self.GAT_layer_S(e_S)
+         out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+         #------------SA for temporal feature-------------#
+         w2 = F.softmax(w, dim=-2)
+         m1 = torch.sum(x * w2, dim=-2)
+
+         e_T = m1.transpose(1, 2)
+
+         # graph module layer
+         gat_T = self.GAT_layer_T(e_T)
+         out_T = self.pool_T(gat_T)
+
+         # learnable master node
+         master1 = self.master1.expand(x.size(0), -1, -1)
+         master2 = self.master2.expand(x.size(0), -1, -1)
+
+         # inference 1
+         # Note: the unexpanded self.master1/self.master2 parameters are passed
+         # below (the expanded copies above are immediately overwritten); kept
+         # as-is to match the behaviour the checkpoint was trained with.
+         out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
+             out_T, out_S, master=self.master1)
+
+         out_S1 = self.pool_hS1(out_S1)
+         out_T1 = self.pool_hT1(out_T1)
+
+         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
+             out_T1, out_S1, master=master1)
+         out_T1 = out_T1 + out_T_aug
+         out_S1 = out_S1 + out_S_aug
+         master1 = master1 + master_aug
+
+         # inference 2
+         out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
+             out_T, out_S, master=self.master2)
+         out_S2 = self.pool_hS2(out_S2)
+         out_T2 = self.pool_hT2(out_T2)
+
+         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
+             out_T2, out_S2, master=master2)
+         out_T2 = out_T2 + out_T_aug
+         out_S2 = out_S2 + out_S_aug
+         master2 = master2 + master_aug
+
+         out_T1 = self.drop_way(out_T1)
+         out_T2 = self.drop_way(out_T2)
+         out_S1 = self.drop_way(out_S1)
+         out_S2 = self.drop_way(out_S2)
+         master1 = self.drop_way(master1)
+         master2 = self.drop_way(master2)
+
+         out_T = torch.max(out_T1, out_T2)
+         out_S = torch.max(out_S1, out_S2)
+         master = torch.max(master1, master2)
+
+         # Readout operation
+         T_max, _ = torch.max(torch.abs(out_T), dim=1)
+         T_avg = torch.mean(out_T, dim=1)
+
+         S_max, _ = torch.max(torch.abs(out_S), dim=1)
+         S_avg = torch.mean(out_S, dim=1)
+
+         last_hidden = torch.cat(
+             [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+         last_hidden = self.drop(last_hidden)
+         output = self.out_layer(last_hidden)
+
+         return output
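Of the blocks above, GraphPool is the easiest to sanity-check in isolation: it scores every node with a learned linear projection, keeps the top k fraction of nodes, and rescales the survivors by their scores. A self-contained sketch mirroring top_k_graph (no model.py or fairseq import needed):

```python
import torch

def top_k_graph(scores, h, k):
    # scores: (bs, n_nodes, 1); h: (bs, n_nodes, dim); keep top k fraction.
    _, n_nodes, n_feat = h.size()
    n_keep = max(int(n_nodes * k), 1)
    _, idx = torch.topk(scores, n_keep, dim=1)   # indices of kept nodes
    idx = idx.expand(-1, -1, n_feat)
    return torch.gather(h * scores, 1, idx)      # rescale, then gather

h = torch.randn(2, 6, 4)
scores = torch.sigmoid(torch.randn(2, 6, 1))
print(top_k_graph(scores, h, 0.5).shape)         # torch.Size([2, 3, 4])
```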
requirements.txt ADDED
@@ -0,0 +1,5 @@
+ librosa
+ numpy
+ torch
+ torchaudio
+ git+https://github.com/KhadgaA/fairseq-a54021305d6b3c4c5959ac9395135f63202db8f1.git
xlsr2_300m.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b08927597f2c9eb2ebd7dcc3ac78ee4b5f6021cbac4b3a6c5a9deec445d80ed9
+ size 3808868242
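xlsr2_300m.pt is the pre-trained wav2vec 2.0 XLS-R (300M) checkpoint that SSLModel.__init__ loads through fairseq. A sketch of that loading path and the feature shape it yields (requires the pinned fairseq fork from requirements.txt; the frame count is approximate):

```python
import torch
import fairseq

# Same call SSLModel uses to load the XLS-R encoder.
models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(["xlsr2_300m.pt"])
ssl = models[0]

wav = torch.randn(1, 64600)                          # a 16 kHz waveform batch
emb = ssl(wav, mask=False, features_only=True)["x"]  # frame-level features
print(emb.shape)                                     # ~ (1, 201, 1024)
```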